summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/tests.hs490
1 files changed, 11 insertions, 479 deletions
diff --git a/examples/tests.hs b/examples/tests.hs
index b6c9a36..cd923cd 100644
--- a/examples/tests.hs
+++ b/examples/tests.hs
@@ -1,138 +1,13 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2
3module Main where 1module Main where
4 2
5import Numeric.GSL hiding (sin,cos,exp,choose)
6import Numeric.LinearAlgebra 3import Numeric.LinearAlgebra
7import Numeric.LinearAlgebra.LAPACK 4import Numeric.LinearAlgebra.Tests
8import qualified Numeric.GSL.Matrix as GSL
9import Test.QuickCheck hiding (test)
10import Test.HUnit hiding ((~:),test)
11import System.Random(randomRs,mkStdGen) 5import System.Random(randomRs,mkStdGen)
12import System.Info 6import Test.HUnit hiding (test)
13import Data.List(foldl1', transpose)
14import System(getArgs) 7import System(getArgs)
15import Debug.Trace(trace)
16
17debug x = trace (show x) x
18
19type RM = Matrix Double
20type CM = Matrix (Complex Double)
21
22-- relative error
23dist :: (Normed t, Num t) => t -> t -> Double
24dist a b = r
25 where norm = pnorm Infinity
26 na = norm a
27 nb = norm b
28 nab = norm (a-b)
29 mx = max na nb
30 mn = min na nb
31 r = if mn < eps
32 then mx
33 else nab/mx
34
35infixl 4 |~|
36a |~| b = a :~10~: b
37
38data Aprox a = (:~) a Int
39
40(~:) :: (Normed a, Num a) => Aprox a -> a -> Bool
41a :~n~: b = dist a b < 10^^(-n)
42
43
44maxdim = 10
45
46instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where
47 arbitrary = do
48 r <- arbitrary
49 i <- arbitrary
50 return (r:+i)
51 coarbitrary = undefined
52
53instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
54 arbitrary = do --m <- sized $ \max -> choose (1,1+3*max)
55 m <- choose (1,maxdim)
56 n <- choose (1,maxdim)
57 l <- vector (m*n)
58 ctype <- arbitrary
59 let h = if ctype then (m><n) else (m>|<n)
60 trMode <- arbitrary
61 let tr = if trMode then trans else id
62 return $ tr (h l)
63 coarbitrary = undefined
64
65data PairM a = PairM (Matrix a) (Matrix a) deriving Show
66instance (Num a, Element a, Arbitrary a) => Arbitrary (PairM a) where
67 arbitrary = do
68 a <- choose (1,maxdim)
69 b <- choose (1,maxdim)
70 c <- choose (1,maxdim)
71 l1 <- vector (a*b)
72 l2 <- vector (b*c)
73 return $ PairM ((a><b) (map fromIntegral (l1::[Int]))) ((b><c) (map fromIntegral (l2::[Int])))
74 --return $ PairM ((a><b) l1) ((b><c) l2)
75 coarbitrary = undefined
76
77data SqM a = SqM (Matrix a) deriving Show
78sqm (SqM a) = a
79instance (Element a, Arbitrary a) => Arbitrary (SqM a) where
80 arbitrary = do
81 n <- choose (1,maxdim)
82 l <- vector (n*n)
83 return $ SqM $ (n><n) l
84 coarbitrary = undefined
85
86data Sym a = Sym (Matrix a) deriving Show
87sym (Sym a) = a
88instance (Linear Vector a, Arbitrary a) => Arbitrary (Sym a) where
89 arbitrary = do
90 SqM m <- arbitrary
91 return $ Sym (m + trans m)
92 coarbitrary = undefined
93 8
94data Her = Her (Matrix (Complex Double)) deriving Show
95her (Her a) = a
96instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where
97 arbitrary = do
98 SqM m <- arbitrary
99 return $ Her (m + ctrans m)
100 coarbitrary = undefined
101 9
102data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show 10pseudorandomR seed (n,m) = reshape m $ fromList $ take (n*m) $ randomRs (-100,100) $ mkStdGen seed
103instance (Num a, Field a, Arbitrary a) => Arbitrary (PairSM a) where
104 arbitrary = do
105 a <- choose (1,maxdim)
106 c <- choose (1,maxdim)
107 l1 <- vector (a*a)
108 l2 <- vector (a*c)
109 return $ PairSM ((a><a) (map fromIntegral (l1::[Int]))) ((a><c) (map fromIntegral (l2::[Int])))
110 --return $ PairSM ((a><a) l1) ((a><c) l2)
111 coarbitrary = undefined
112
113instance (Field a, Arbitrary a) => Arbitrary (Vector a) where
114 arbitrary = do --m <- sized $ \max -> choose (1,1+3*max)
115 m <- choose (1,maxdim^2)
116 l <- vector m
117 return $ fromList l
118 coarbitrary = undefined
119
120data PairV a = PairV (Vector a) (Vector a)
121instance (Field a, Arbitrary a) => Arbitrary (PairV a) where
122 arbitrary = do --m <- sized $ \max -> choose (1,1+3*max)
123 m <- choose (1,maxdim^2)
124 l1 <- vector m
125 l2 <- vector m
126 return $ PairV (fromList l1) (fromList l2)
127 coarbitrary = undefined
128
129----------------------------------------------------------------------
130
131test str b = TestCase $ assertBool str b
132
133----------------------------------------------------------------------
134
135pseudorandomR seed (n,m) = reshape m $ fromList $ take (n*m) $ randomRs (-100,100) $ mkStdGen seed
136 11
137pseudorandomC seed (n,m) = toComplex (pseudorandomR seed (n,m), pseudorandomR (seed+1) (n,m)) 12pseudorandomC seed (n,m) = toComplex (pseudorandomR seed (n,m), pseudorandomR (seed+1) (n,m))
138 13
@@ -141,366 +16,23 @@ bigmat = m + trans m :: RM
141bigmatc = mc + ctrans mc ::CM 16bigmatc = mc + ctrans mc ::CM
142 where mc = pseudorandomC 19 (1000,1000) 17 where mc = pseudorandomC 19 (1000,1000)
143 18
144---------------------------------------------------------------------- 19utest str b = TestCase $ assertBool str b
145
146
147m = (3><3)
148 [ 1, 2, 3
149 , 4, 5, 7
150 , 2, 8, 4 :: Double
151 ]
152
153mc = (3><3)
154 [ 1, 2, 3
155 , 4, 5, 7
156 , 2, 8, i
157 ]
158
159
160mr = (3><4)
161 [ 1, 2, 3, 4,
162 2, 4, 6, 8,
163 1, 1, 1, 2:: Double
164 ]
165
166mrc = (3><4)
167 [ 1, 2, 3, 4,
168 2, 4, 6, 8,
169 i, i, i, 2
170 ]
171
172a = (3><4)
173 [ 1, 0, 0, 0
174 , 0, 2, 0, 0
175 , 0, 0, 0, 0 :: Double
176 ]
177
178b = (3><4)
179 [ 1, 0, 0, 0
180 , 0, 2, 3, 0
181 , 0, 0, 4, 0 :: Double
182 ]
183
184ac = (2><3) [1 .. 6::Double]
185bc = (3><4) [7 .. 18::Double]
186
187af = (2>|<3) [1,4,2,5,3,6::Double]
188bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double]
189
190-------------------------------------------------------
191 20
192feye n = flipud (ident n) :: Matrix Double 21feye n = flipud (ident n) :: Matrix Double
193 22
194
195luTest1 m = m |~| p <> l <> u
196 where (l,u,p,_) = lu m
197
198detTest1 = det m == 26
199 && det mc == 38 :+ (-3)
200 && det (feye 2) == -1
201
202detTest2 m = s d1 |~| s d2
203 where d1 = det m
204 d2 = det' m * det q
205 det' m = product $ toList $ takeDiag r
206 (q,r) = qr m
207 s x = fromList [x]
208
209invTest m = degenerate m || m <> inv m |~| ident (rows m)
210
211pinvTest m = m <> p <> m |~| m
212 && p <> m <> p |~| p
213 && hermitian (m<>p)
214 && hermitian (p<>m)
215 where p = pinv m
216
217square m = rows m == cols m
218
219unitary m = square m && m <> ctrans m |~| ident (rows m)
220
221hermitian m = m |~| ctrans m
222
223upperTriang m = rows m == 1 || down == z
224 where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m))
225 z = constant 0 (dim down)
226
227upperHessenberg m = rows m < 3 || down == z
228 where down = fromList $ concat $ zipWith drop [2..] (toLists (ctrans m))
229 z = constant 0 (dim down)
230
231svdTest svd m = u <> real d <> trans v |~| m
232 && unitary u && unitary v
233 where (u,d,v) = full svd m
234
235svdTest' svd m = m |~| 0 || u <> real (diag s) <> trans v |~| m
236 where (u,s,v) = economy svd m
237
238eigTest m = complex m <> v |~| v <> diag s
239 where (s, v) = eig m
240
241eigTestSH m = m <> v |~| v <> real (diag s)
242 && unitary v
243 && m |~| v <> real (diag s) <> ctrans v
244 where (s, v) = eigSH m
245
246zeros (r,c) = reshape c (constant 0 (r*c))
247
248ones (r,c) = zeros (r,c) + 1
249
250degenerate m = rank m < min (rows m) (cols m)
251
252prec = 1E-15
253
254singular m = s1 < prec || s2/s1 < prec
255 where (_,ss,_) = svd m
256 s = toList ss
257 s1 = maximum s
258 s2 = minimum s
259
260nullspaceTest m = null nl || m <> n |~| zeros (r,c) -- 0
261 where nl = nullspacePrec 1 m
262 n = fromColumns nl
263 r = rows m
264 c = cols m - rank m
265
266--------------------------------------------------------------------
267
268polyEval cs x = foldr (\c ac->ac*x+c) 0 cs
269
270polySolveTest' p = length p <2 || last p == 0|| 1E-8 > maximum (map magnitude $ map (polyEval (map (:+0) p)) (polySolve p))
271
272
273polySolveTest = test "polySolve" (polySolveTest' [1,2,3,4])
274
275---------------------------------------------------------------------
276
277quad f a b = fst $ integrateQAGS 1E-9 100 f a b
278
279-- A multiple integral can be easily defined using partial application
280quad2 f a b g1 g2 = quad h a b
281 where h x = quad (f x) (g1 x) (g2 x)
282
283volSphere r = 8 * quad2 (\x y -> sqrt (r*r-x*x-y*y))
284 0 r (const 0) (\x->sqrt (r*r-x*x))
285
286epsTol = 1E-8::Double
287
288integrateTest = test "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < epsTol)
289
290---------------------------------------------------------------------
291
292besselTest = test "bessel_J0_e" ( abs (r-expected) < e )
293 where (r,e) = bessel_J0_e 5.0
294 expected = -0.17759677131433830434739701
295
296exponentialTest = test "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 )
297 where (v,e,err) = exp_e10_e 30.0
298 expected = exp 30.0
299
300gammaTest = test "gamma" (gamma 5 == 24.0)
301
302---------------------------------------------------------------------
303
304cholRTest = chol ((2><2) [1,2,2,9::Double]) == (2><2) [1,2,0,2.23606797749979]
305cholCTest = chol ((2><2) [1,2,2,9::Complex Double]) == (2><2) [1,2,0,2.23606797749979]
306
307---------------------------------------------------------------------
308
309qrTest qr m = q <> r |~| m && unitary q && upperTriang r
310 where (q,r) = qr m
311
312---------------------------------------------------------------------
313
314hessTest m = m |~| p <> h <> ctrans p && unitary p && upperHessenberg h
315 where (p,h) = hess m
316
317---------------------------------------------------------------------
318
319schurTest1 m = m |~| u <> s <> ctrans u && unitary u && upperTriang s
320 where (u,s) = schur m
321
322schurTest2 m = m |~| u <> s <> ctrans u && unitary u && upperHessenberg s -- fixme
323 where (u,s) = schur m
324
325---------------------------------------------------------------------
326
327nd1 = (3><3) [ 1/2, 1/4, 1/4
328 , 0/1, 1/2, 1/4
329 , 1/2, 1/4, 1/2 :: Double]
330
331nd2 = (2><2) [1, 0, 1, 1:: Complex Double]
332
333expmTest1 = expm nd1 :~14~: (3><3)
334 [ 1.762110887278176
335 , 0.478085470590435
336 , 0.478085470590435
337 , 0.104719410945666
338 , 1.709751181805343
339 , 0.425725765117601
340 , 0.851451530235203
341 , 0.530445176063267
342 , 1.814470592751009 ]
343
344expmTest2 = expm nd2 :~15~: (2><2)
345 [ 2.718281828459045
346 , 0.000000000000000
347 , 2.718281828459045
348 , 2.718281828459045 ]
349
350expmTestDiag m = expm (logm m) |~| complex m
351 where logm m = matFunc Prelude.log m
352
353
354
355---------------------------------------------------------------------
356
357asFortran m = (rows m >|< cols m) $ toList (flatten $ trans m)
358asC m = (rows m >< cols m) $ toList (flatten m)
359
360mulC a b = a <> b
361mulF a b = trans $ trans b <> trans a
362
363-------------------------------------------------------------------------
364
365multiplyG a b = reshape (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b)
366 where multiplyL a b = [[dotL x y | y <- transpose b] | x <- a]
367 dotL a b = sum (zipWith (*) a b)
368
369r >|< c = f where
370 f l | dim v == r*c = reshapeF r v
371 | otherwise = error "(>|<)"
372 where v = fromList l
373 reshapeF r = trans . reshape r
374
375---------------------------------------------------------------------
376
377rot :: Double -> Matrix Double
378rot a = (3><3) [ c,0,s
379 , 0,1,0
380 ,-s,0,c ]
381 where c = cos a
382 s = sin a
383
384fun n = foldl1' (<>) (map rot angles)
385 where angles = toList $ linspace n (0,1)
386
387rotTest = fun (10^5) :~12~: rot 5E4
388
389---------------------------------------------------------------------
390
391tests = do
392 setErrorHandlerOff
393 putStrLn "--------- internal -----"
394 quickCheck ((\m -> m == trans m).sym :: Sym Double -> Bool)
395 quickCheck ((\m -> m == trans m).sym :: Sym (Complex Double) -> Bool)
396 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [Double])
397 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [Complex Double])
398 quickCheck $ \m -> m == asC (m :: RM)
399 quickCheck $ \m -> m == asC (m :: CM)
400 quickCheck $ \m -> m == asFortran (m :: RM)
401 quickCheck $ \m -> m == asFortran (m :: CM)
402 quickCheck $ \m -> m == (asC . asFortran) (m :: RM)
403 quickCheck $ \m -> m == (asC . asFortran) (m :: CM)
404 runTestTT $ TestList
405 [ test "1E5 rots" rotTest
406 ]
407 putStrLn "--------- multiply ----"
408 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == mulF m1 (m2 :: RM)
409 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == mulF m1 (m2 :: CM)
410 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == trans (mulF (trans m2) (trans m1 :: RM))
411 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == trans (mulF (trans m2) (trans m1 :: CM))
412 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == multiplyG m1 (m2 :: RM)
413 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == multiplyG m1 (m2 :: CM)
414 putStrLn "--------- lu ---------"
415 quickCheck (luTest1 :: RM->Bool)
416 quickCheck (luTest1 :: CM->Bool)
417 quickCheck (detTest2 . sqm :: SqM Double -> Bool)
418 quickCheck (detTest2 . sqm :: SqM (Complex Double) -> Bool)
419 runTestTT $ TestList
420 [ test "det1" detTest1
421 ]
422 putStrLn "--------- svd ---------"
423 quickCheck (svdTest svdR)
424 quickCheck (svdTest svdRdd)
425 quickCheck (svdTest svdC)
426 quickCheck (svdTest' svdR)
427 quickCheck (svdTest' svdRdd)
428 quickCheck (svdTest' svdC)
429 quickCheck (svdTest' GSL.svdg)
430 putStrLn "--------- eig ---------"
431 quickCheck (eigTest . sqm :: SqM Double -> Bool)
432 quickCheck (eigTest . sqm :: SqM (Complex Double) -> Bool)
433 quickCheck (eigTestSH . sym :: Sym Double -> Bool)
434 quickCheck (eigTestSH . her :: Her -> Bool)
435 putStrLn "--------- inv ------"
436 quickCheck (invTest . sqm :: SqM Double -> Bool)
437 quickCheck (invTest . sqm :: SqM (Complex Double) -> Bool)
438 putStrLn "--------- pinv ------"
439 quickCheck (pinvTest ::RM->Bool)
440 if os == "mingw32"
441 then putStrLn "complex pinvTest skipped in this OS"
442 else quickCheck (pinvTest ::CM->Bool)
443 putStrLn "--------- chol ------"
444 runTestTT $ TestList
445 [ test "cholR" cholRTest
446 , test "cholC" cholCTest
447 ]
448 putStrLn "--------- qr ---------"
449 quickCheck (qrTest GSL.qr)
450 quickCheck (qrTest (GSL.unpackQR . GSL.qrPacked))
451 quickCheck (qrTest ( unpackQR . GSL.qrPacked))
452 quickCheck (qrTest qr ::RM->Bool)
453 quickCheck (qrTest qr ::CM->Bool)
454 putStrLn "--------- hess --------"
455 quickCheck (hessTest . sqm ::SqM Double->Bool)
456 quickCheck (hessTest . sqm ::SqM (Complex Double) -> Bool)
457 putStrLn "--------- schur --------"
458 quickCheck (schurTest2 . sqm ::SqM Double->Bool)
459 if os == "mingw32"
460 then putStrLn "complex schur skipped in this OS"
461 else quickCheck (schurTest1 . sqm ::SqM (Complex Double) -> Bool)
462 putStrLn "--------- expm --------"
463 runTestTT $ TestList
464 [ test "expmd" (expmTestDiag $ (2><2) [1,2,3,5 :: Double])
465 , test "expm1" (expmTest1)
466 , test "expm2" (expmTest2)
467 ]
468 putStrLn "--------- nullspace ------"
469 quickCheck (nullspaceTest :: RM -> Bool)
470 quickCheck (nullspaceTest :: CM -> Bool)
471 putStrLn "--------- vector operations ------"
472 quickCheck $ (\u -> sin u ^ 2 + cos u ^ 2 |~| (1::RM))
473 quickCheck $ (\u -> sin u ** 2 + cos u ** 2 |~| (1::RM))
474 quickCheck $ (\u -> cos u * tan u |~| sin (u::RM))
475 quickCheck $ (\u -> (cos u * tan u) :~6~: sin (u::CM))
476 runTestTT $ TestList
477 [ test "arith1" $ ((ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| (49 :: RM)
478 , test "arith2" $ (((1+i) .* ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| ( (140*i-51).*1 :: CM)
479 , test "arith3" $ exp (i.*ones(10,10)*pi) + 1 |~| 0
480 , test "<\\>" $ (3><2) [2,0,0,3,1,1::Double] <\> 3|>[4,9,5] |~| 2|>[2,3]
481 ]
482 putStrLn "--------- GSL ------"
483 quickCheck $ \v -> ifft (fft v) |~| v
484 runTestTT $ TestList
485 [ gammaTest
486 , besselTest
487 , exponentialTest
488 , integrateTest
489 , polySolveTest
490 ]
491
492bigtests = do 23bigtests = do
493 putStrLn "--------- big matrices -----" 24 putStrLn "--------- big matrices -----"
494 runTestTT $ TestList 25 runTestTT $ TestList
495 [ test "eigS" $ eigTestSH bigmat 26 [ utest "eigS" $ eigSHProp bigmat
496 , test "eigH" $ eigTestSH bigmatc 27 , utest "eigH" $ eigSHProp bigmatc
497 , test "eigR" $ eigTest bigmat 28 , utest "eigR" $ eigProp bigmat
498 , test "eigC" $ eigTest bigmatc 29 , utest "eigC" $ eigProp bigmatc
499 , test "det" $ det (feye 1000) == 1 && det (feye 1002) == -1 30 , utest "det" $ det (feye 1000) == 1 && det (feye 1002) == -1
500 ] 31 ]
32 return ()
501 33
502main = do 34main = do
503 args <- getArgs 35 args <- getArgs
504 if "--big" `elem` args 36 if "--big" `elem` args
505 then bigtests 37 then bigtests
506 else tests 38 else runTests 20