diff options
author | Alberto Ruiz <aruiz@um.es> | 2011-12-14 13:08:43 +0100 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2011-12-14 13:08:43 +0100 |
commit | 77552d080e88fc70312f55fd3303fac3464ab46e (patch) | |
tree | 1dc87dd22ce0da0f1807765568fbc04285bf3621 /packages/tests/src/Numeric/LinearAlgebra | |
parent | c3bda2d38c432fb53ce456cba295b097fd4d6ad1 (diff) |
new package hmatrix-tests
Diffstat (limited to 'packages/tests/src/Numeric/LinearAlgebra')
3 files changed, 1254 insertions, 0 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs new file mode 100644 index 0000000..69ef1b3 --- /dev/null +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -0,0 +1,731 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# OPTIONS_GHC -fno-warn-unused-imports -fno-warn-incomplete-patterns #-} | ||
3 | ----------------------------------------------------------------------------- | ||
4 | {- | | ||
5 | Module : Numeric.LinearAlgebra.Tests | ||
6 | Copyright : (c) Alberto Ruiz 2007-11 | ||
7 | License : GPL-style | ||
8 | |||
9 | Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
10 | Stability : provisional | ||
11 | Portability : portable | ||
12 | |||
13 | Some tests. | ||
14 | |||
15 | -} | ||
16 | |||
17 | module Numeric.LinearAlgebra.Tests( | ||
18 | -- module Numeric.LinearAlgebra.Tests.Instances, | ||
19 | -- module Numeric.LinearAlgebra.Tests.Properties, | ||
20 | -- qCheck, | ||
21 | runTests, | ||
22 | runBenchmarks | ||
23 | -- , findNaN | ||
24 | --, runBigTests | ||
25 | ) where | ||
26 | |||
27 | --import Data.Packed.Random | ||
28 | import Numeric.LinearAlgebra | ||
29 | import Numeric.LinearAlgebra.LAPACK | ||
30 | import Numeric.LinearAlgebra.Tests.Instances | ||
31 | import Numeric.LinearAlgebra.Tests.Properties | ||
32 | import Test.HUnit hiding ((~:),test,Testable,State) | ||
33 | import System.Info | ||
34 | import Data.List(foldl1') | ||
35 | import Numeric.GSL | ||
36 | import Prelude hiding ((^)) | ||
37 | import qualified Prelude | ||
38 | import System.CPUTime | ||
39 | import Text.Printf | ||
40 | import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) | ||
41 | import Control.Arrow((***)) | ||
42 | import Debug.Trace | ||
43 | |||
44 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | ||
45 | ,sized,classify,Testable,Property | ||
46 | ,quickCheckWith,maxSize,stdArgs,shrink) | ||
47 | |||
48 | qCheck n = quickCheckWith stdArgs {maxSize = n} | ||
49 | |||
50 | a ^ b = a Prelude.^ (b :: Int) | ||
51 | |||
52 | utest str b = TestCase $ assertBool str b | ||
53 | |||
54 | a ~~ b = fromList a |~| fromList b | ||
55 | |||
56 | feye n = flipud (ident n) :: Matrix Double | ||
57 | |||
58 | ----------------------------------------------------------- | ||
59 | |||
60 | detTest1 = det m == 26 | ||
61 | && det mc == 38 :+ (-3) | ||
62 | && det (feye 2) == -1 | ||
63 | where | ||
64 | m = (3><3) | ||
65 | [ 1, 2, 3 | ||
66 | , 4, 5, 7 | ||
67 | , 2, 8, 4 :: Double | ||
68 | ] | ||
69 | mc = (3><3) | ||
70 | [ 1, 2, 3 | ||
71 | , 4, 5, 7 | ||
72 | , 2, 8, i | ||
73 | ] | ||
74 | |||
75 | detTest2 = inv1 |~| inv2 && [det1] ~~ [det2] | ||
76 | where | ||
77 | m = complex (feye 6) | ||
78 | inv1 = inv m | ||
79 | det1 = det m | ||
80 | (inv2,(lda,sa)) = invlndet m | ||
81 | det2 = sa * exp lda | ||
82 | |||
83 | -------------------------------------------------------------------- | ||
84 | |||
85 | polyEval cs x = foldr (\c ac->ac*x+c) 0 cs | ||
86 | |||
87 | polySolveProp p = length p <2 || last p == 0|| 1E-8 > maximum (map magnitude $ map (polyEval (map (:+0) p)) (polySolve p)) | ||
88 | |||
89 | --------------------------------------------------------------------- | ||
90 | |||
91 | quad f a b = fst $ integrateQAGS 1E-9 100 f a b | ||
92 | |||
93 | -- A multiple integral can be easily defined using partial application | ||
94 | quad2 f a b g1 g2 = quad h a b | ||
95 | where h x = quad (f x) (g1 x) (g2 x) | ||
96 | |||
97 | volSphere r = 8 * quad2 (\x y -> sqrt (r*r-x*x-y*y)) | ||
98 | 0 r (const 0) (\x->sqrt (r*r-x*x)) | ||
99 | |||
100 | --------------------------------------------------------------------- | ||
101 | |||
102 | derivTest = abs (d (\x-> x * d (\y-> x+y) 1) 1 - 1) < 1E-10 | ||
103 | where d f x = fst $ derivCentral 0.01 f x | ||
104 | |||
105 | --------------------------------------------------------------------- | ||
106 | |||
107 | -- besselTest = utest "bessel_J0_e" ( abs (r-expected) < e ) | ||
108 | -- where (r,e) = bessel_J0_e 5.0 | ||
109 | -- expected = -0.17759677131433830434739701 | ||
110 | |||
111 | -- exponentialTest = utest "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 ) | ||
112 | -- where (v,e,_err) = exp_e10_e 30.0 | ||
113 | -- expected = exp 30.0 | ||
114 | |||
115 | --------------------------------------------------------------------- | ||
116 | |||
117 | nd1 = (3><3) [ 1/2, 1/4, 1/4 | ||
118 | , 0/1, 1/2, 1/4 | ||
119 | , 1/2, 1/4, 1/2 :: Double] | ||
120 | |||
121 | nd2 = (2><2) [1, 0, 1, 1:: Complex Double] | ||
122 | |||
123 | expmTest1 = expm nd1 :~14~: (3><3) | ||
124 | [ 1.762110887278176 | ||
125 | , 0.478085470590435 | ||
126 | , 0.478085470590435 | ||
127 | , 0.104719410945666 | ||
128 | , 1.709751181805343 | ||
129 | , 0.425725765117601 | ||
130 | , 0.851451530235203 | ||
131 | , 0.530445176063267 | ||
132 | , 1.814470592751009 ] | ||
133 | |||
134 | expmTest2 = expm nd2 :~15~: (2><2) | ||
135 | [ 2.718281828459045 | ||
136 | , 0.000000000000000 | ||
137 | , 2.718281828459045 | ||
138 | , 2.718281828459045 ] | ||
139 | |||
140 | --------------------------------------------------------------------- | ||
141 | |||
142 | minimizationTest = TestList | ||
143 | [ utest "minimization conjugatefr" (minim1 f df [5,7] ~~ [1,2]) | ||
144 | , utest "minimization nmsimplex2" (minim2 f [5,7] `elem` [24,25]) | ||
145 | ] | ||
146 | where f [x,y] = 10*(x-1)^2 + 20*(y-2)^2 + 30 | ||
147 | df [x,y] = [20*(x-1), 40*(y-2)] | ||
148 | minim1 g dg ini = fst $ minimizeD ConjugateFR 1E-3 30 1E-2 1E-4 g dg ini | ||
149 | minim2 g ini = rows $ snd $ minimize NMSimplex2 1E-2 30 [1,1] g ini | ||
150 | |||
151 | --------------------------------------------------------------------- | ||
152 | |||
153 | rootFindingTest = TestList [ utest "root Hybrids" (fst sol1 ~~ [1,1]) | ||
154 | , utest "root Newton" (rows (snd sol2) == 2) | ||
155 | ] | ||
156 | where sol1 = root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5] | ||
157 | sol2 = rootJ Newton 1E-7 30 (rosenbrock 1 10) (jacobian 1 10) [-10,-5] | ||
158 | rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ] | ||
159 | jacobian a b [x,_y] = [ [-a , 0] | ||
160 | , [-2*b*x, b] ] | ||
161 | |||
162 | --------------------------------------------------------------------- | ||
163 | |||
164 | odeTest = utest "ode" (last (toLists sol) ~~ [-1.7588880332411019, 8.364348908711941e-2]) | ||
165 | where sol = odeSolveV RK8pd 1E-6 1E-6 0 (l2v $ vanderpol 10) Nothing (fromList [1,0]) ts | ||
166 | ts = linspace 101 (0,100) | ||
167 | l2v f = \t -> fromList . f t . toList | ||
168 | vanderpol mu _t [x,y] = [y, -x + mu * y * (1-x^2) ] | ||
169 | |||
170 | --------------------------------------------------------------------- | ||
171 | |||
172 | fittingTest = utest "levmar" (ok1 && ok2) | ||
173 | where | ||
174 | xs = map return [0 .. 39] | ||
175 | sigma = 0.1 | ||
176 | ys = map return $ toList $ fromList (map (head . expModel [5,0.1,1]) xs) | ||
177 | + scalar sigma * (randomVector 0 Gaussian 40) | ||
178 | dats = zip xs (zip ys (repeat sigma)) | ||
179 | dat = zip xs ys | ||
180 | |||
181 | expModel [a,lambda,b] [t] = [a * exp (-lambda * t) + b] | ||
182 | expModelDer [a,lambda,_b] [t] = [[exp (-lambda * t), -t * a * exp(-lambda*t) , 1]] | ||
183 | |||
184 | sols = fst $ fitModelScaled 1E-4 1E-4 20 (expModel, expModelDer) dats [1,0,0] | ||
185 | sol = fst $ fitModel 1E-4 1E-4 20 (expModel, expModelDer) dat [1,0,0] | ||
186 | |||
187 | ok1 = and (zipWith f sols [5,0.1,1]) where f (x,d) r = abs (x-r)<2*d | ||
188 | ok2 = norm2 (fromList (map fst sols) - fromList sol) < 1E-5 | ||
189 | |||
190 | ----------------------------------------------------- | ||
191 | |||
192 | mbCholTest = utest "mbCholTest" (ok1 && ok2) where | ||
193 | m1 = (2><2) [2,5,5,8 :: Double] | ||
194 | m2 = (2><2) [3,5,5,9 :: Complex Double] | ||
195 | ok1 = mbCholSH m1 == Nothing | ||
196 | ok2 = mbCholSH m2 == Just (chol m2) | ||
197 | |||
198 | --------------------------------------------------------------------- | ||
199 | |||
200 | randomTestGaussian = c :~1~: snd (meanCov dat) where | ||
201 | a = (3><3) [1,2,3, | ||
202 | 2,4,0, | ||
203 | -2,2,1] | ||
204 | m = 3 |> [1,2,3] | ||
205 | c = a <> trans a | ||
206 | dat = gaussianSample 7 (10^6) m c | ||
207 | |||
208 | randomTestUniform = c :~1~: snd (meanCov dat) where | ||
209 | c = diag $ 3 |> map ((/12).(^2)) [1,2,3] | ||
210 | dat = uniformSample 7 (10^6) [(0,1),(1,3),(3,6)] | ||
211 | |||
212 | --------------------------------------------------------------------- | ||
213 | |||
214 | rot :: Double -> Matrix Double | ||
215 | rot a = (3><3) [ c,0,s | ||
216 | , 0,1,0 | ||
217 | ,-s,0,c ] | ||
218 | where c = cos a | ||
219 | s = sin a | ||
220 | |||
221 | rotTest = fun (10^5) :~11~: rot 5E4 | ||
222 | where fun n = foldl1' (<>) (map rot angles) | ||
223 | where angles = toList $ linspace n (0,1) | ||
224 | |||
225 | --------------------------------------------------------------------- | ||
226 | -- vector <= 0.6.0.2 bug discovered by Patrick Perry | ||
227 | -- http://trac.haskell.org/vector/ticket/31 | ||
228 | |||
229 | offsetTest = y == y' where | ||
230 | x = fromList [0..3 :: Double] | ||
231 | y = subVector 1 3 x | ||
232 | (f,o,n) = unsafeToForeignPtr y | ||
233 | y' = unsafeFromForeignPtr f o n | ||
234 | |||
235 | --------------------------------------------------------------------- | ||
236 | |||
237 | normsVTest = TestList [ | ||
238 | utest "normv2CD" $ norm2PropC v | ||
239 | , utest "normv2CF" $ norm2PropC (single v) | ||
240 | #ifndef NONORMVTEST | ||
241 | , utest "normv2D" $ norm2PropR x | ||
242 | , utest "normv2F" $ norm2PropR (single x) | ||
243 | #endif | ||
244 | , utest "normv1CD" $ norm1 v == 8 | ||
245 | , utest "normv1CF" $ norm1 (single v) == 8 | ||
246 | , utest "normv1D" $ norm1 x == 6 | ||
247 | , utest "normv1F" $ norm1 (single x) == 6 | ||
248 | |||
249 | , utest "normvInfCD" $ normInf v == 5 | ||
250 | , utest "normvInfCF" $ normInf (single v) == 5 | ||
251 | , utest "normvInfD" $ normInf x == 3 | ||
252 | , utest "normvInfF" $ normInf (single x) == 3 | ||
253 | |||
254 | ] where v = fromList [1,-2,3:+4] :: Vector (Complex Double) | ||
255 | x = fromList [1,2,-3] :: Vector Double | ||
256 | #ifndef NONORMVTEST | ||
257 | norm2PropR a = norm2 a =~= sqrt (dot a a) | ||
258 | #endif | ||
259 | norm2PropC a = norm2 a =~= realPart (sqrt (dot a (conj a))) | ||
260 | a =~= b = fromList [a] |~| fromList [b] | ||
261 | |||
262 | normsMTest = TestList [ | ||
263 | utest "norm2mCD" $ pnorm PNorm2 v =~= 8.86164970498005 | ||
264 | , utest "norm2mCF" $ pnorm PNorm2 (single v) =~= 8.86164970498005 | ||
265 | , utest "norm2mD" $ pnorm PNorm2 x =~= 5.96667765076216 | ||
266 | , utest "norm2mF" $ pnorm PNorm2 (single x) =~= 5.96667765076216 | ||
267 | |||
268 | , utest "norm1mCD" $ pnorm PNorm1 v == 9 | ||
269 | , utest "norm1mCF" $ pnorm PNorm1 (single v) == 9 | ||
270 | , utest "norm1mD" $ pnorm PNorm1 x == 7 | ||
271 | , utest "norm1mF" $ pnorm PNorm1 (single x) == 7 | ||
272 | |||
273 | , utest "normmInfCD" $ pnorm Infinity v == 12 | ||
274 | , utest "normmInfCF" $ pnorm Infinity (single v) == 12 | ||
275 | , utest "normmInfD" $ pnorm Infinity x == 8 | ||
276 | , utest "normmInfF" $ pnorm Infinity (single x) == 8 | ||
277 | |||
278 | , utest "normmFroCD" $ pnorm Frobenius v =~= 8.88819441731559 | ||
279 | , utest "normmFroCF" $ pnorm Frobenius (single v) =~~= 8.88819441731559 | ||
280 | , utest "normmFroD" $ pnorm Frobenius x =~= 6.24499799839840 | ||
281 | , utest "normmFroF" $ pnorm Frobenius (single x) =~~= 6.24499799839840 | ||
282 | |||
283 | ] where v = (2><2) [1,-2*i,3:+4,7] :: Matrix (Complex Double) | ||
284 | x = (2><2) [1,2,-3,5] :: Matrix Double | ||
285 | a =~= b = fromList [a] :~10~: fromList [b] | ||
286 | a =~~= b = fromList [a] :~5~: fromList [b] | ||
287 | |||
288 | --------------------------------------------------------------------- | ||
289 | |||
290 | sumprodTest = TestList [ | ||
291 | utest "sumCD" $ sumElements z == 6 | ||
292 | , utest "sumCF" $ sumElements (single z) == 6 | ||
293 | , utest "sumD" $ sumElements v == 6 | ||
294 | , utest "sumF" $ sumElements (single v) == 6 | ||
295 | |||
296 | , utest "prodCD" $ prodProp z | ||
297 | , utest "prodCF" $ prodProp (single z) | ||
298 | , utest "prodD" $ prodProp v | ||
299 | , utest "prodF" $ prodProp (single v) | ||
300 | ] where v = fromList [1,2,3] :: Vector Double | ||
301 | z = fromList [1,2-i,3+i] | ||
302 | prodProp x = prodElements x == product (toList x) | ||
303 | |||
304 | --------------------------------------------------------------------- | ||
305 | |||
306 | chainTest = utest "chain" $ foldl1' (<>) ms |~| optimiseMult ms where | ||
307 | ms = [ diag (fromList [1,2,3 :: Double]) | ||
308 | , konst 3 (3,5) | ||
309 | , (5><10) [1 .. ] | ||
310 | , konst 5 (10,2) | ||
311 | ] | ||
312 | |||
313 | --------------------------------------------------------------------- | ||
314 | |||
315 | conjuTest m = mapVector conjugate (flatten (trans m)) == flatten (ctrans m) | ||
316 | |||
317 | --------------------------------------------------------------------- | ||
318 | |||
319 | newtype State s a = State { runState :: s -> (a,s) } | ||
320 | |||
321 | instance Monad (State s) where | ||
322 | return a = State $ \s -> (a,s) | ||
323 | m >>= f = State $ \s -> let (a,s') = runState m s | ||
324 | in runState (f a) s' | ||
325 | |||
326 | state_get :: State s s | ||
327 | state_get = State $ \s -> (s,s) | ||
328 | |||
329 | state_put :: s -> State s () | ||
330 | state_put s = State $ \_ -> ((),s) | ||
331 | |||
332 | evalState :: State s a -> s -> a | ||
333 | evalState m s = let (a,s') = runState m s | ||
334 | in seq s' a | ||
335 | |||
336 | newtype MaybeT m a = MaybeT { runMaybeT :: m (Maybe a) } | ||
337 | |||
338 | instance Monad m => Monad (MaybeT m) where | ||
339 | return a = MaybeT $ return $ Just a | ||
340 | m >>= f = MaybeT $ do | ||
341 | res <- runMaybeT m | ||
342 | case res of | ||
343 | Nothing -> return Nothing | ||
344 | Just r -> runMaybeT (f r) | ||
345 | fail _ = MaybeT $ return Nothing | ||
346 | |||
347 | lift_maybe m = MaybeT $ do | ||
348 | res <- m | ||
349 | return $ Just res | ||
350 | |||
351 | -- apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs | ||
352 | --successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool | ||
353 | successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ stp (subVector 1 (dim v - 1) v))) (v @> 0) | ||
354 | where stp e = do | ||
355 | ep <- lift_maybe $ state_get | ||
356 | if t e ep | ||
357 | then lift_maybe $ state_put e | ||
358 | else (fail "successive_ test failed") | ||
359 | |||
360 | -- operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input | ||
361 | --successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b | ||
362 | successive f v = evalState (mapVectorM stp (subVector 1 (dim v - 1) v)) (v @> 0) | ||
363 | where stp e = do | ||
364 | ep <- state_get | ||
365 | state_put e | ||
366 | return $ f ep e | ||
367 | |||
368 | |||
369 | succTest = utest "successive" $ | ||
370 | successive_ (>) (fromList [1 :: Double,2,3,4]) == True | ||
371 | && successive_ (>) (fromList [1 :: Double,3,2,4]) == False | ||
372 | && successive (+) (fromList [1..10 :: Double]) == 9 |> [3,5,7,9,11,13,15,17,19] | ||
373 | |||
374 | --------------------------------------------------------------------- | ||
375 | |||
376 | findAssocTest = utest "findAssoc" ok | ||
377 | where | ||
378 | ok = m1 == m2 | ||
379 | m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double | ||
380 | m2 = diagRect 7 (fromList[10..14]) 6 6 | ||
381 | |||
382 | --------------------------------------------------------------------- | ||
383 | |||
384 | condTest = utest "cond" ok | ||
385 | where | ||
386 | ok = step v * v == cond v 0 0 0 v | ||
387 | v = fromList [-7 .. 7 ] :: Vector Float | ||
388 | |||
389 | --------------------------------------------------------------------- | ||
390 | |||
391 | conformTest = utest "conform" ok | ||
392 | where | ||
393 | ok = 1 + row [1,2,3] + col [10,20,30,40] + (4><3) [1..] | ||
394 | == (4><3) [13,15,17 | ||
395 | ,26,28,30 | ||
396 | ,39,41,43 | ||
397 | ,52,54,56] | ||
398 | row = asRow . fromList | ||
399 | col = asColumn . fromList :: [Double] -> Matrix Double | ||
400 | |||
401 | --------------------------------------------------------------------- | ||
402 | |||
403 | accumTest = utest "accum" ok | ||
404 | where | ||
405 | x = ident 3 :: Matrix Double | ||
406 | ok = accum x (+) [((1,2),7), ((2,2),3)] | ||
407 | == (3><3) [1,0,0 | ||
408 | ,0,1,7 | ||
409 | ,0,0,4] | ||
410 | && | ||
411 | toList (flatten x) == [1,0,0,0,1,0,0,0,1] | ||
412 | |||
413 | --------------------------------------------------------------------- | ||
414 | |||
415 | -- | All tests must pass with a maximum dimension of about 20 | ||
416 | -- (some tests may fail with bigger sizes due to precision loss). | ||
417 | runTests :: Int -- ^ maximum dimension | ||
418 | -> IO () | ||
419 | runTests n = do | ||
420 | setErrorHandlerOff | ||
421 | let test p = qCheck n p | ||
422 | putStrLn "------ mult Double" | ||
423 | test (multProp1 10 . rConsist) | ||
424 | test (multProp1 10 . cConsist) | ||
425 | test (multProp2 10 . rConsist) | ||
426 | test (multProp2 10 . cConsist) | ||
427 | putStrLn "------ mult Float" | ||
428 | test (multProp1 6 . (single *** single) . rConsist) | ||
429 | test (multProp1 6 . (single *** single) . cConsist) | ||
430 | test (multProp2 6 . (single *** single) . rConsist) | ||
431 | test (multProp2 6 . (single *** single) . cConsist) | ||
432 | putStrLn "------ sub-trans" | ||
433 | test (subProp . rM) | ||
434 | test (subProp . cM) | ||
435 | putStrLn "------ ctrans" | ||
436 | test (conjuTest . cM) | ||
437 | test (conjuTest . zM) | ||
438 | putStrLn "------ lu" | ||
439 | test (luProp . rM) | ||
440 | test (luProp . cM) | ||
441 | putStrLn "------ inv (linearSolve)" | ||
442 | test (invProp . rSqWC) | ||
443 | test (invProp . cSqWC) | ||
444 | putStrLn "------ luSolve" | ||
445 | test (linearSolveProp (luSolve.luPacked) . rSqWC) | ||
446 | test (linearSolveProp (luSolve.luPacked) . cSqWC) | ||
447 | putStrLn "------ cholSolve" | ||
448 | test (linearSolveProp (cholSolve.chol) . rPosDef) | ||
449 | test (linearSolveProp (cholSolve.chol) . cPosDef) | ||
450 | putStrLn "------ luSolveLS" | ||
451 | test (linearSolveProp linearSolveLS . rSqWC) | ||
452 | test (linearSolveProp linearSolveLS . cSqWC) | ||
453 | test (linearSolveProp2 linearSolveLS . rConsist) | ||
454 | test (linearSolveProp2 linearSolveLS . cConsist) | ||
455 | putStrLn "------ pinv (linearSolveSVD)" | ||
456 | test (pinvProp . rM) | ||
457 | test (pinvProp . cM) | ||
458 | putStrLn "------ det" | ||
459 | test (detProp . rSqWC) | ||
460 | test (detProp . cSqWC) | ||
461 | putStrLn "------ svd" | ||
462 | test (svdProp1 . rM) | ||
463 | test (svdProp1 . cM) | ||
464 | test (svdProp1a svdR) | ||
465 | test (svdProp1a svdC) | ||
466 | test (svdProp1a svdRd) | ||
467 | test (svdProp1b svdR) | ||
468 | test (svdProp1b svdC) | ||
469 | test (svdProp1b svdRd) | ||
470 | test (svdProp2 thinSVDR) | ||
471 | test (svdProp2 thinSVDC) | ||
472 | test (svdProp2 thinSVDRd) | ||
473 | test (svdProp2 thinSVDCd) | ||
474 | test (svdProp3 . rM) | ||
475 | test (svdProp3 . cM) | ||
476 | test (svdProp4 . rM) | ||
477 | test (svdProp4 . cM) | ||
478 | test (svdProp5a) | ||
479 | test (svdProp5b) | ||
480 | test (svdProp6a) | ||
481 | test (svdProp6b) | ||
482 | test (svdProp7 . rM) | ||
483 | test (svdProp7 . cM) | ||
484 | putStrLn "------ svdCd" | ||
485 | #ifdef NOZGESDD | ||
486 | putStrLn "Omitted" | ||
487 | #else | ||
488 | test (svdProp1a svdCd) | ||
489 | test (svdProp1b svdCd) | ||
490 | #endif | ||
491 | putStrLn "------ eig" | ||
492 | test (eigSHProp . rHer) | ||
493 | test (eigSHProp . cHer) | ||
494 | test (eigProp . rSq) | ||
495 | test (eigProp . cSq) | ||
496 | test (eigSHProp2 . rHer) | ||
497 | test (eigSHProp2 . cHer) | ||
498 | test (eigProp2 . rSq) | ||
499 | test (eigProp2 . cSq) | ||
500 | putStrLn "------ nullSpace" | ||
501 | test (nullspaceProp . rM) | ||
502 | test (nullspaceProp . cM) | ||
503 | putStrLn "------ qr" | ||
504 | test (qrProp . rM) | ||
505 | test (qrProp . cM) | ||
506 | test (rqProp . rM) | ||
507 | test (rqProp . cM) | ||
508 | test (rqProp1 . cM) | ||
509 | test (rqProp2 . cM) | ||
510 | test (rqProp3 . cM) | ||
511 | putStrLn "------ hess" | ||
512 | test (hessProp . rSq) | ||
513 | test (hessProp . cSq) | ||
514 | putStrLn "------ schur" | ||
515 | test (schurProp2 . rSq) | ||
516 | test (schurProp1 . cSq) | ||
517 | putStrLn "------ chol" | ||
518 | test (cholProp . rPosDef) | ||
519 | test (cholProp . cPosDef) | ||
520 | test (exactProp . rPosDef) | ||
521 | test (exactProp . cPosDef) | ||
522 | putStrLn "------ expm" | ||
523 | test (expmDiagProp . complex. rSqWC) | ||
524 | test (expmDiagProp . cSqWC) | ||
525 | putStrLn "------ fft" | ||
526 | test (\v -> ifft (fft v) |~| v) | ||
527 | putStrLn "------ vector operations - Double" | ||
528 | test (\u -> sin u ^ 2 + cos u ^ 2 |~| (1::RM)) | ||
529 | test $ (\u -> sin u ^ 2 + cos u ^ 2 |~| (1::CM)) . liftMatrix makeUnitary | ||
530 | test (\u -> sin u ** 2 + cos u ** 2 |~| (1::RM)) | ||
531 | test (\u -> cos u * tan u |~| sin (u::RM)) | ||
532 | test $ (\u -> cos u * tan u |~| sin (u::CM)) . liftMatrix makeUnitary | ||
533 | putStrLn "------ vector operations - Float" | ||
534 | test (\u -> sin u ^ 2 + cos u ^ 2 |~~| (1::FM)) | ||
535 | test $ (\u -> sin u ^ 2 + cos u ^ 2 |~~| (1::ZM)) . liftMatrix makeUnitary | ||
536 | test (\u -> sin u ** 2 + cos u ** 2 |~~| (1::FM)) | ||
537 | test (\u -> cos u * tan u |~~| sin (u::FM)) | ||
538 | test $ (\u -> cos u * tan u |~~| sin (u::ZM)) . liftMatrix makeUnitary | ||
539 | putStrLn "------ read . show" | ||
540 | test (\m -> (m::RM) == read (show m)) | ||
541 | test (\m -> (m::CM) == read (show m)) | ||
542 | test (\m -> toRows (m::RM) == read (show (toRows m))) | ||
543 | test (\m -> toRows (m::CM) == read (show (toRows m))) | ||
544 | test (\m -> (m::FM) == read (show m)) | ||
545 | test (\m -> (m::ZM) == read (show m)) | ||
546 | test (\m -> toRows (m::FM) == read (show (toRows m))) | ||
547 | test (\m -> toRows (m::ZM) == read (show (toRows m))) | ||
548 | putStrLn "------ some unit tests" | ||
549 | _ <- runTestTT $ TestList | ||
550 | [ utest "1E5 rots" rotTest | ||
551 | , utest "det1" detTest1 | ||
552 | , utest "invlndet" detTest2 | ||
553 | , utest "expm1" (expmTest1) | ||
554 | , utest "expm2" (expmTest2) | ||
555 | , utest "arith1" $ ((ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| (49 :: RM) | ||
556 | , utest "arith2" $ ((scalar (1+i) * ones (100,100) * 5 + 2)/0.5 - 7)**2 |~| ( scalar (140*i-51) :: CM) | ||
557 | , utest "arith3" $ exp (scalar i * ones(10,10)*pi) + 1 |~| 0 | ||
558 | , utest "<\\>" $ (3><2) [2,0,0,3,1,1::Double] <\> 3|>[4,9,5] |~| 2|>[2,3] | ||
559 | -- , utest "gamma" (gamma 5 == 24.0) | ||
560 | -- , besselTest | ||
561 | -- , exponentialTest | ||
562 | , utest "deriv" derivTest | ||
563 | , utest "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < 1E-8) | ||
564 | , utest "polySolve" (polySolveProp [1,2,3,4]) | ||
565 | , minimizationTest | ||
566 | , rootFindingTest | ||
567 | , utest "randomGaussian" randomTestGaussian | ||
568 | , utest "randomUniform" randomTestUniform | ||
569 | , utest "buildVector/Matrix" $ | ||
570 | complex (10 |> [0::Double ..]) == buildVector 10 fromIntegral | ||
571 | && ident 5 == buildMatrix 5 5 (\(r,c) -> if r==c then 1::Double else 0) | ||
572 | , utest "rank" $ rank ((2><3)[1,0,0,1,6*eps,0]) == 1 | ||
573 | && rank ((2><3)[1,0,0,1,7*eps,0]) == 2 | ||
574 | , utest "block" $ fromBlocks [[ident 3,0],[0,ident 4]] == (ident 7 :: CM) | ||
575 | , odeTest | ||
576 | , fittingTest | ||
577 | , mbCholTest | ||
578 | , utest "offset" offsetTest | ||
579 | , normsVTest | ||
580 | , normsMTest | ||
581 | , sumprodTest | ||
582 | , chainTest | ||
583 | , succTest | ||
584 | , findAssocTest | ||
585 | , condTest | ||
586 | , conformTest | ||
587 | , accumTest | ||
588 | ] | ||
589 | return () | ||
590 | |||
591 | |||
592 | -- single precision approximate equality | ||
593 | infixl 4 |~~| | ||
594 | a |~~| b = a :~6~: b | ||
595 | |||
596 | makeUnitary v | realPart n > 1 = v / scalar n | ||
597 | | otherwise = v | ||
598 | where n = sqrt (conj v <.> v) | ||
599 | |||
600 | -- -- | Some additional tests on big matrices. They take a few minutes. | ||
601 | -- runBigTests :: IO () | ||
602 | -- runBigTests = undefined | ||
603 | |||
604 | {- | ||
605 | -- | testcase for nonempty fpu stack | ||
606 | findNaN :: Int -> Bool | ||
607 | findNaN n = all (bugProp . eye) (take n $ cycle [1..20]) | ||
608 | where eye m = ident m :: Matrix ( Double) | ||
609 | -} | ||
610 | |||
611 | -------------------------------------------------------------------------------- | ||
612 | |||
613 | -- | Performance measurements. | ||
614 | runBenchmarks :: IO () | ||
615 | runBenchmarks = do | ||
616 | solveBench | ||
617 | subBench | ||
618 | multBench | ||
619 | cholBench | ||
620 | svdBench | ||
621 | eigBench | ||
622 | putStrLn "" | ||
623 | |||
624 | -------------------------------- | ||
625 | |||
626 | time msg act = do | ||
627 | putStr (msg++" ") | ||
628 | t0 <- getCPUTime | ||
629 | act `seq` putStr " " | ||
630 | t1 <- getCPUTime | ||
631 | printf "%6.2f s CPU\n" $ (fromIntegral (t1 - t0) / (10^12 :: Double)) :: IO () | ||
632 | return () | ||
633 | |||
634 | -------------------------------- | ||
635 | |||
636 | manymult n = foldl1' (<>) (map rot2 angles) where | ||
637 | angles = toList $ linspace n (0,1) | ||
638 | rot2 :: Double -> Matrix Double | ||
639 | rot2 a = (3><3) [ c,0,s | ||
640 | , 0,1,0 | ||
641 | ,-s,0,c ] | ||
642 | where c = cos a | ||
643 | s = sin a | ||
644 | |||
645 | multb n = foldl1' (<>) (replicate (10^6) (ident n :: Matrix Double)) | ||
646 | |||
647 | -------------------------------- | ||
648 | |||
649 | subBench = do | ||
650 | putStrLn "" | ||
651 | let g = foldl1' (.) (replicate (10^5) (\v -> subVector 1 (dim v -1) v)) | ||
652 | time "0.1M subVector " (g (constant 1 (1+10^5) :: Vector Double) @> 0) | ||
653 | let f = foldl1' (.) (replicate (10^5) (fromRows.toRows)) | ||
654 | time "subVector-join 3" (f (ident 3 :: Matrix Double) @@>(0,0)) | ||
655 | time "subVector-join 10" (f (ident 10 :: Matrix Double) @@>(0,0)) | ||
656 | |||
657 | -------------------------------- | ||
658 | |||
659 | multBench = do | ||
660 | let a = ident 1000 :: Matrix Double | ||
661 | let b = ident 2000 :: Matrix Double | ||
662 | a `seq` b `seq` putStrLn "" | ||
663 | time "product of 1M different 3x3 matrices" (manymult (10^6)) | ||
664 | putStrLn "" | ||
665 | time "product of 1M constant 1x1 matrices" (multb 1) | ||
666 | time "product of 1M constant 3x3 matrices" (multb 3) | ||
667 | --time "product of 1M constant 5x5 matrices" (multb 5) | ||
668 | time "product of 1M const. 10x10 matrices" (multb 10) | ||
669 | --time "product of 1M const. 15x15 matrices" (multb 15) | ||
670 | time "product of 1M const. 20x20 matrices" (multb 20) | ||
671 | --time "product of 1M const. 25x25 matrices" (multb 25) | ||
672 | putStrLn "" | ||
673 | time "product (1000 x 1000)<>(1000 x 1000)" (a<>a) | ||
674 | time "product (2000 x 2000)<>(2000 x 2000)" (b<>b) | ||
675 | |||
676 | -------------------------------- | ||
677 | |||
678 | eigBench = do | ||
679 | let m = reshape 1000 (randomVector 777 Uniform (1000*1000)) | ||
680 | s = m + trans m | ||
681 | m `seq` s `seq` putStrLn "" | ||
682 | time "eigenvalues symmetric 1000x1000" (eigenvaluesSH' m) | ||
683 | time "eigenvectors symmetric 1000x1000" (snd $ eigSH' m) | ||
684 | time "eigenvalues general 1000x1000" (eigenvalues m) | ||
685 | time "eigenvectors general 1000x1000" (snd $ eig m) | ||
686 | |||
687 | -------------------------------- | ||
688 | |||
689 | svdBench = do | ||
690 | let a = reshape 500 (randomVector 777 Uniform (3000*500)) | ||
691 | b = reshape 1000 (randomVector 777 Uniform (1000*1000)) | ||
692 | fv (_,_,v) = v@@>(0,0) | ||
693 | a `seq` b `seq` putStrLn "" | ||
694 | time "singular values 3000x500" (singularValues a) | ||
695 | time "thin svd 3000x500" (fv $ thinSVD a) | ||
696 | time "full svd 3000x500" (fv $ svd a) | ||
697 | time "singular values 1000x1000" (singularValues b) | ||
698 | time "full svd 1000x1000" (fv $ svd b) | ||
699 | |||
700 | -------------------------------- | ||
701 | |||
702 | solveBenchN n = do | ||
703 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) | ||
704 | a = trans x <> x | ||
705 | b = asColumn $ randomVector 666 Uniform n | ||
706 | a `seq` b `seq` putStrLn "" | ||
707 | time ("svd solve " ++ show n) (linearSolveSVD a b) | ||
708 | time (" ls solve " ++ show n) (linearSolveLS a b) | ||
709 | time (" solve " ++ show n) (linearSolve a b) | ||
710 | time ("cholSolve " ++ show n) (cholSolve (chol a) b) | ||
711 | |||
712 | solveBench = do | ||
713 | solveBenchN 500 | ||
714 | solveBenchN 1000 | ||
715 | -- solveBenchN 1500 | ||
716 | |||
717 | -------------------------------- | ||
718 | |||
719 | cholBenchN n = do | ||
720 | let x = uniformSample 777 (2*n) (replicate n (-1,1)) | ||
721 | a = trans x <> x | ||
722 | a `seq` putStr "" | ||
723 | time ("chol " ++ show n) (chol a) | ||
724 | |||
725 | cholBench = do | ||
726 | putStrLn "" | ||
727 | cholBenchN 1200 | ||
728 | cholBenchN 600 | ||
729 | cholBenchN 300 | ||
730 | -- cholBenchN 150 | ||
731 | -- cholBenchN 50 | ||
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs new file mode 100644 index 0000000..647a06c --- /dev/null +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -0,0 +1,251 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, CPP, FlexibleInstances #-} | ||
2 | {-# OPTIONS_GHC -fno-warn-unused-imports #-} | ||
3 | ----------------------------------------------------------------------------- | ||
4 | {- | | ||
5 | Module : Numeric.LinearAlgebra.Tests.Instances | ||
6 | Copyright : (c) Alberto Ruiz 2008 | ||
7 | License : GPL-style | ||
8 | |||
9 | Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
10 | Stability : provisional | ||
11 | Portability : portable | ||
12 | |||
13 | Arbitrary instances for vectors, matrices. | ||
14 | |||
15 | -} | ||
16 | |||
17 | module Numeric.LinearAlgebra.Tests.Instances( | ||
18 | Sq(..), rSq,cSq, | ||
19 | Rot(..), rRot,cRot, | ||
20 | Her(..), rHer,cHer, | ||
21 | WC(..), rWC,cWC, | ||
22 | SqWC(..), rSqWC, cSqWC, | ||
23 | PosDef(..), rPosDef, cPosDef, | ||
24 | Consistent(..), rConsist, cConsist, | ||
25 | RM,CM, rM,cM, | ||
26 | FM,ZM, fM,zM | ||
27 | ) where | ||
28 | |||
29 | import System.Random | ||
30 | |||
31 | import Numeric.LinearAlgebra | ||
32 | import Control.Monad(replicateM) | ||
33 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | ||
34 | ,sized,classify,Testable,Property | ||
35 | ,quickCheckWith,maxSize,stdArgs,shrink) | ||
36 | |||
37 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
38 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] | ||
39 | shrinkListElementwise [] = [] | ||
40 | shrinkListElementwise (x:xs) = [ y:xs | y <- shrink x ] | ||
41 | ++ [ x:ys | ys <- shrinkListElementwise xs ] | ||
42 | |||
43 | shrinkPair :: (Arbitrary a, Arbitrary b) => (a,b) -> [(a,b)] | ||
44 | shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ] | ||
45 | #endif | ||
46 | |||
47 | #if MIN_VERSION_QuickCheck(2,1,1) | ||
48 | #else | ||
49 | instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where | ||
50 | arbitrary = do | ||
51 | re <- arbitrary | ||
52 | im <- arbitrary | ||
53 | return (re :+ im) | ||
54 | |||
55 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
56 | shrink (re :+ im) = | ||
57 | [ u :+ v | (u,v) <- shrinkPair (re,im) ] | ||
58 | #else | ||
59 | -- this has been moved to the 'Coarbitrary' class in QuickCheck 2 | ||
60 | coarbitrary = undefined | ||
61 | #endif | ||
62 | |||
63 | #endif | ||
64 | |||
65 | chooseDim = sized $ \m -> choose (1,max 1 m) | ||
66 | |||
67 | instance (Field a, Arbitrary a) => Arbitrary (Vector a) where | ||
68 | arbitrary = do m <- chooseDim | ||
69 | l <- vector m | ||
70 | return $ fromList l | ||
71 | |||
72 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
73 | -- shrink any one of the components | ||
74 | shrink = map fromList . shrinkListElementwise . toList | ||
75 | |||
76 | #else | ||
77 | coarbitrary = undefined | ||
78 | #endif | ||
79 | |||
80 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | ||
81 | arbitrary = do | ||
82 | m <- chooseDim | ||
83 | n <- chooseDim | ||
84 | l <- vector (m*n) | ||
85 | return $ (m><n) l | ||
86 | |||
87 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
88 | -- shrink any one of the components | ||
89 | shrink a = map (rows a >< cols a) | ||
90 | . shrinkListElementwise | ||
91 | . concat . toLists | ||
92 | $ a | ||
93 | #else | ||
94 | coarbitrary = undefined | ||
95 | #endif | ||
96 | |||
97 | |||
98 | -- a square matrix | ||
99 | newtype (Sq a) = Sq (Matrix a) deriving Show | ||
100 | instance (Element a, Arbitrary a) => Arbitrary (Sq a) where | ||
101 | arbitrary = do | ||
102 | n <- chooseDim | ||
103 | l <- vector (n*n) | ||
104 | return $ Sq $ (n><n) l | ||
105 | |||
106 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
107 | shrink (Sq a) = [ Sq b | b <- shrink a ] | ||
108 | #else | ||
109 | coarbitrary = undefined | ||
110 | #endif | ||
111 | |||
112 | |||
113 | -- a unitary matrix | ||
114 | newtype (Rot a) = Rot (Matrix a) deriving Show | ||
115 | instance (Field a, Arbitrary a) => Arbitrary (Rot a) where | ||
116 | arbitrary = do | ||
117 | Sq m <- arbitrary | ||
118 | let (q,_) = qr m | ||
119 | return (Rot q) | ||
120 | |||
121 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
122 | #else | ||
123 | coarbitrary = undefined | ||
124 | #endif | ||
125 | |||
126 | |||
127 | -- a complex hermitian or real symmetric matrix | ||
128 | newtype (Her a) = Her (Matrix a) deriving Show | ||
129 | instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Her a) where | ||
130 | arbitrary = do | ||
131 | Sq m <- arbitrary | ||
132 | let m' = m/2 | ||
133 | return $ Her (m' + ctrans m') | ||
134 | |||
135 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
136 | #else | ||
137 | coarbitrary = undefined | ||
138 | #endif | ||
139 | |||
140 | class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a | ||
141 | instance ArbitraryField Double | ||
142 | instance ArbitraryField (Complex Double) | ||
143 | |||
144 | |||
145 | -- a well-conditioned general matrix (the singular values are between 1 and 100) | ||
146 | newtype (WC a) = WC (Matrix a) deriving Show | ||
147 | instance (ArbitraryField a) => Arbitrary (WC a) where | ||
148 | arbitrary = do | ||
149 | m <- arbitrary | ||
150 | let (u,_,v) = svd m | ||
151 | r = rows m | ||
152 | c = cols m | ||
153 | n = min r c | ||
154 | sv' <- replicateM n (choose (1,100)) | ||
155 | let s = diagRect 0 (fromList sv') r c | ||
156 | return $ WC (u <> real s <> trans v) | ||
157 | |||
158 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
159 | #else | ||
160 | coarbitrary = undefined | ||
161 | #endif | ||
162 | |||
163 | |||
164 | -- a well-conditioned square matrix (the singular values are between 1 and 100) | ||
165 | newtype (SqWC a) = SqWC (Matrix a) deriving Show | ||
166 | instance (ArbitraryField a) => Arbitrary (SqWC a) where | ||
167 | arbitrary = do | ||
168 | Sq m <- arbitrary | ||
169 | let (u,_,v) = svd m | ||
170 | n = rows m | ||
171 | sv' <- replicateM n (choose (1,100)) | ||
172 | let s = diag (fromList sv') | ||
173 | return $ SqWC (u <> real s <> trans v) | ||
174 | |||
175 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
176 | #else | ||
177 | coarbitrary = undefined | ||
178 | #endif | ||
179 | |||
180 | |||
181 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) | ||
182 | newtype (PosDef a) = PosDef (Matrix a) deriving Show | ||
183 | instance (ArbitraryField a, Num (Vector a)) | ||
184 | => Arbitrary (PosDef a) where | ||
185 | arbitrary = do | ||
186 | Her m <- arbitrary | ||
187 | let (_,v) = eigSH m | ||
188 | n = rows m | ||
189 | l <- replicateM n (choose (0,100)) | ||
190 | let s = diag (fromList l) | ||
191 | p = v <> real s <> ctrans v | ||
192 | return $ PosDef (0.5 * p + 0.5 * ctrans p) | ||
193 | |||
194 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
195 | #else | ||
196 | coarbitrary = undefined | ||
197 | #endif | ||
198 | |||
199 | |||
200 | -- a pair of matrices that can be multiplied | ||
201 | newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show | ||
202 | instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where | ||
203 | arbitrary = do | ||
204 | n <- chooseDim | ||
205 | k <- chooseDim | ||
206 | m <- chooseDim | ||
207 | la <- vector (n*k) | ||
208 | lb <- vector (k*m) | ||
209 | return $ Consistent ((n><k) la, (k><m) lb) | ||
210 | |||
211 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
212 | shrink (Consistent (x,y)) = [ Consistent (u,v) | (u,v) <- shrinkPair (x,y) ] | ||
213 | #else | ||
214 | coarbitrary = undefined | ||
215 | #endif | ||
216 | |||
217 | |||
218 | |||
219 | type RM = Matrix Double | ||
220 | type CM = Matrix (Complex Double) | ||
221 | type FM = Matrix Float | ||
222 | type ZM = Matrix (Complex Float) | ||
223 | |||
224 | |||
225 | rM m = m :: RM | ||
226 | cM m = m :: CM | ||
227 | fM m = m :: FM | ||
228 | zM m = m :: ZM | ||
229 | |||
230 | |||
231 | rHer (Her m) = m :: RM | ||
232 | cHer (Her m) = m :: CM | ||
233 | |||
234 | rRot (Rot m) = m :: RM | ||
235 | cRot (Rot m) = m :: CM | ||
236 | |||
237 | rSq (Sq m) = m :: RM | ||
238 | cSq (Sq m) = m :: CM | ||
239 | |||
240 | rWC (WC m) = m :: RM | ||
241 | cWC (WC m) = m :: CM | ||
242 | |||
243 | rSqWC (SqWC m) = m :: RM | ||
244 | cSqWC (SqWC m) = m :: CM | ||
245 | |||
246 | rPosDef (PosDef m) = m :: RM | ||
247 | cPosDef (PosDef m) = m :: CM | ||
248 | |||
249 | rConsist (Consistent (a,b)) = (a,b::RM) | ||
250 | cConsist (Consistent (a,b)) = (a,b::CM) | ||
251 | |||
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs new file mode 100644 index 0000000..c96d3de --- /dev/null +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -0,0 +1,272 @@ | |||
1 | {-# LANGUAGE CPP, FlexibleContexts #-} | ||
2 | {-# OPTIONS_GHC -fno-warn-unused-imports #-} | ||
3 | ----------------------------------------------------------------------------- | ||
4 | {- | | ||
5 | Module : Numeric.LinearAlgebra.Tests.Properties | ||
6 | Copyright : (c) Alberto Ruiz 2008 | ||
7 | License : GPL-style | ||
8 | |||
9 | Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
10 | Stability : provisional | ||
11 | Portability : portable | ||
12 | |||
13 | Testing properties. | ||
14 | |||
15 | -} | ||
16 | |||
17 | module Numeric.LinearAlgebra.Tests.Properties ( | ||
18 | dist, (|~|), (~:), Aprox((:~)), | ||
19 | zeros, ones, | ||
20 | square, | ||
21 | unitary, | ||
22 | hermitian, | ||
23 | wellCond, | ||
24 | positiveDefinite, | ||
25 | upperTriang, | ||
26 | upperHessenberg, | ||
27 | luProp, | ||
28 | invProp, | ||
29 | pinvProp, | ||
30 | detProp, | ||
31 | nullspaceProp, | ||
32 | bugProp, | ||
33 | svdProp1, svdProp1a, svdProp1b, svdProp2, svdProp3, svdProp4, | ||
34 | svdProp5a, svdProp5b, svdProp6a, svdProp6b, svdProp7, | ||
35 | eigProp, eigSHProp, eigProp2, eigSHProp2, | ||
36 | qrProp, rqProp, rqProp1, rqProp2, rqProp3, | ||
37 | hessProp, | ||
38 | schurProp1, schurProp2, | ||
39 | cholProp, exactProp, | ||
40 | expmDiagProp, | ||
41 | multProp1, multProp2, | ||
42 | subProp, | ||
43 | linearSolveProp, linearSolveProp2 | ||
44 | ) where | ||
45 | |||
46 | import Numeric.LinearAlgebra --hiding (real,complex) | ||
47 | import Numeric.LinearAlgebra.LAPACK | ||
48 | import Debug.Trace | ||
49 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | ||
50 | ,sized,classify,Testable,Property | ||
51 | ,quickCheckWith,maxSize,stdArgs,shrink) | ||
52 | |||
53 | trivial :: Testable a => Bool -> a -> Property | ||
54 | trivial = (`classify` "trivial") | ||
55 | |||
56 | |||
57 | -- relative error | ||
58 | dist :: (Normed c t, Num (c t)) => c t -> c t -> Double | ||
59 | dist a b = realToFrac r | ||
60 | where norm = pnorm Infinity | ||
61 | na = norm a | ||
62 | nb = norm b | ||
63 | nab = norm (a-b) | ||
64 | mx = max na nb | ||
65 | mn = min na nb | ||
66 | r = if mn < peps | ||
67 | then mx | ||
68 | else nab/mx | ||
69 | |||
70 | infixl 4 |~| | ||
71 | a |~| b = a :~10~: b | ||
72 | --a |~| b = dist a b < 10^^(-10) | ||
73 | |||
74 | data Aprox a = (:~) a Int | ||
75 | -- (~:) :: (Normed a, Num a) => Aprox a -> a -> Bool | ||
76 | a :~n~: b = dist a b < 10^^(-n) | ||
77 | |||
78 | ------------------------------------------------------ | ||
79 | |||
80 | square m = rows m == cols m | ||
81 | |||
82 | -- orthonormal columns | ||
83 | orthonormal m = ctrans m <> m |~| ident (cols m) | ||
84 | |||
85 | unitary m = square m && orthonormal m | ||
86 | |||
87 | hermitian m = square m && m |~| ctrans m | ||
88 | |||
89 | wellCond m = rcond m > 1/100 | ||
90 | |||
91 | positiveDefinite m = minimum (toList e) > 0 | ||
92 | where (e,_v) = eigSH m | ||
93 | |||
94 | upperTriang m = rows m == 1 || down == z | ||
95 | where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m)) | ||
96 | z = constant 0 (dim down) | ||
97 | |||
98 | upperHessenberg m = rows m < 3 || down == z | ||
99 | where down = fromList $ concat $ zipWith drop [2..] (toLists (ctrans m)) | ||
100 | z = constant 0 (dim down) | ||
101 | |||
102 | zeros (r,c) = reshape c (constant 0 (r*c)) | ||
103 | |||
104 | ones (r,c) = zeros (r,c) + 1 | ||
105 | |||
106 | ----------------------------------------------------- | ||
107 | |||
108 | luProp m = m |~| p <> l <> u && f (det p) |~| f s | ||
109 | where (l,u,p,s) = lu m | ||
110 | f x = fromList [x] | ||
111 | |||
112 | invProp m = m <> inv m |~| ident (rows m) | ||
113 | |||
114 | pinvProp m = m <> p <> m |~| m | ||
115 | && p <> m <> p |~| p | ||
116 | && hermitian (m<>p) | ||
117 | && hermitian (p<>m) | ||
118 | where p = pinv m | ||
119 | |||
120 | detProp m = s d1 |~| s d2 | ||
121 | where d1 = det m | ||
122 | d2 = det' * det q | ||
123 | det' = product $ toList $ takeDiag r | ||
124 | (q,r) = qr m | ||
125 | s x = fromList [x] | ||
126 | |||
127 | nullspaceProp m = null nl `trivial` (null nl || m <> n |~| zeros (r,c) | ||
128 | && orthonormal (fromColumns nl)) | ||
129 | where nl = nullspacePrec 1 m | ||
130 | n = fromColumns nl | ||
131 | r = rows m | ||
132 | c = cols m - rank m | ||
133 | |||
134 | ------------------------------------------------------------------ | ||
135 | |||
136 | -- testcase for nonempty fpu stack | ||
137 | -- uncommenting unitary' signature eliminates the problem | ||
138 | bugProp m = m |~| u <> real d <> trans v && unitary' u && unitary' v | ||
139 | where (u,d,v) = fullSVD m | ||
140 | -- unitary' :: (Num (Vector t), Field t) => Matrix t -> Bool | ||
141 | unitary' a = unitary a | ||
142 | |||
143 | ------------------------------------------------------------------ | ||
144 | |||
145 | -- fullSVD | ||
146 | svdProp1 m = m |~| u <> real d <> trans v && unitary u && unitary v | ||
147 | where (u,d,v) = fullSVD m | ||
148 | |||
149 | svdProp1a svdfun m = m |~| u <> real d <> trans v && unitary u && unitary v where | ||
150 | (u,s,v) = svdfun m | ||
151 | d = diagRect 0 s (rows m) (cols m) | ||
152 | |||
153 | svdProp1b svdfun m = unitary u && unitary v where | ||
154 | (u,_,v) = svdfun m | ||
155 | |||
156 | -- thinSVD | ||
157 | svdProp2 thinSVDfun m = m |~| u <> diag (real s) <> trans v && orthonormal u && orthonormal v && dim s == min (rows m) (cols m) | ||
158 | where (u,s,v) = thinSVDfun m | ||
159 | |||
160 | -- compactSVD | ||
161 | svdProp3 m = (m |~| u <> real (diag s) <> trans v | ||
162 | && orthonormal u && orthonormal v) | ||
163 | where (u,s,v) = compactSVD m | ||
164 | |||
165 | svdProp4 m' = m |~| u <> real (diag s) <> trans v | ||
166 | && orthonormal u && orthonormal v | ||
167 | && (dim s == r || r == 0 && dim s == 1) | ||
168 | where (u,s,v) = compactSVD m | ||
169 | m = fromBlocks [[m'],[m']] | ||
170 | r = rank m' | ||
171 | |||
172 | svdProp5a m = all (s1|~|) [s2,s3,s4,s5,s6] where | ||
173 | s1 = svR m | ||
174 | s2 = svRd m | ||
175 | (_,s3,_) = svdR m | ||
176 | (_,s4,_) = svdRd m | ||
177 | (_,s5,_) = thinSVDR m | ||
178 | (_,s6,_) = thinSVDRd m | ||
179 | |||
180 | svdProp5b m = all (s1|~|) [s2,s3,s4,s5,s6] where | ||
181 | s1 = svC m | ||
182 | s2 = svCd m | ||
183 | (_,s3,_) = svdC m | ||
184 | (_,s4,_) = svdCd m | ||
185 | (_,s5,_) = thinSVDC m | ||
186 | (_,s6,_) = thinSVDCd m | ||
187 | |||
188 | svdProp6a m = s |~| s' && v |~| v' && s |~| s'' && u |~| u' | ||
189 | where (u,s,v) = svdR m | ||
190 | (s',v') = rightSVR m | ||
191 | (u',s'') = leftSVR m | ||
192 | |||
193 | svdProp6b m = s |~| s' && v |~| v' && s |~| s'' && u |~| u' | ||
194 | where (u,s,v) = svdC m | ||
195 | (s',v') = rightSVC m | ||
196 | (u',s'') = leftSVC m | ||
197 | |||
198 | svdProp7 m = s |~| s' && u |~| u' && v |~| v' && s |~| s''' | ||
199 | where (u,s,v) = svd m | ||
200 | (s',v') = rightSV m | ||
201 | (u',_s'') = leftSV m | ||
202 | s''' = singularValues m | ||
203 | |||
204 | ------------------------------------------------------------------ | ||
205 | |||
206 | eigProp m = complex m <> v |~| v <> diag s | ||
207 | where (s, v) = eig m | ||
208 | |||
209 | eigSHProp m = m <> v |~| v <> real (diag s) | ||
210 | && unitary v | ||
211 | && m |~| v <> real (diag s) <> ctrans v | ||
212 | where (s, v) = eigSH m | ||
213 | |||
214 | eigProp2 m = fst (eig m) |~| eigenvalues m | ||
215 | |||
216 | eigSHProp2 m = fst (eigSH m) |~| eigenvaluesSH m | ||
217 | |||
218 | ------------------------------------------------------------------ | ||
219 | |||
220 | qrProp m = q <> r |~| m && unitary q && upperTriang r | ||
221 | where (q,r) = qr m | ||
222 | |||
223 | rqProp m = r <> q |~| m && unitary q && upperTriang' r | ||
224 | where (r,q) = rq m | ||
225 | |||
226 | rqProp1 m = r <> q |~| m | ||
227 | where (r,q) = rq m | ||
228 | |||
229 | rqProp2 m = unitary q | ||
230 | where (_r,q) = rq m | ||
231 | |||
232 | rqProp3 m = upperTriang' r | ||
233 | where (r,_q) = rq m | ||
234 | |||
235 | upperTriang' r = upptr (rows r) (cols r) * r |~| r | ||
236 | where upptr f c = buildMatrix f c $ \(r',c') -> if r'-t > c' then 0 else 1 | ||
237 | where t = f-c | ||
238 | |||
239 | hessProp m = m |~| p <> h <> ctrans p && unitary p && upperHessenberg h | ||
240 | where (p,h) = hess m | ||
241 | |||
242 | schurProp1 m = m |~| u <> s <> ctrans u && unitary u && upperTriang s | ||
243 | where (u,s) = schur m | ||
244 | |||
245 | schurProp2 m = m |~| u <> s <> ctrans u && unitary u && upperHessenberg s -- fixme | ||
246 | where (u,s) = schur m | ||
247 | |||
248 | cholProp m = m |~| ctrans c <> c && upperTriang c | ||
249 | where c = chol m | ||
250 | |||
251 | exactProp m = chol m == chol (m+0) | ||
252 | |||
253 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m | ||
254 | where logm = matFunc log | ||
255 | |||
256 | -- reference multiply | ||
257 | mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] | ||
258 | where doth u v = sum $ zipWith (*) (toList u) (toList v) | ||
259 | |||
260 | multProp1 p (a,b) = (a <> b) :~p~: (mulH a b) | ||
261 | |||
262 | multProp2 p (a,b) = (ctrans (a <> b)) :~p~: (ctrans b <> ctrans a) | ||
263 | |||
264 | linearSolveProp f m = f m m |~| ident (rows m) | ||
265 | |||
266 | linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b) | ||
267 | where q = min (rows a) (cols a) | ||
268 | b = a <> x | ||
269 | wc = rank a == q | ||
270 | |||
271 | subProp m = m == (trans . fromColumns . toRows) m | ||
272 | |||