diff options
author | Alberto Ruiz <aruiz@um.es> | 2016-03-15 20:33:32 +0100 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2016-03-15 20:33:32 +0100 |
commit | f5e235bbdb4bc342b623676b07245d781a9fb994 (patch) | |
tree | 9c3b7331f40b3fa773de7ce1f09460c58c8e272f /packages/tests/src/Numeric/LinearAlgebra/Tests | |
parent | 6a0bf038091e453115a3451c040cbe790e770b89 (diff) | |
parent | 80e88bbb1fef8b904e5e01d3ca6cc35a97339cda (diff) |
Merge pull request #178 from sid-kap/matrix_binary
Add binary instances for static matrix and vector
Diffstat (limited to 'packages/tests/src/Numeric/LinearAlgebra/Tests')
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | 39 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs | 44 |
2 files changed, 77 insertions, 6 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index 3d5441d..37f7da2 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} | 1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | {- | | 3 | {- | |
4 | Module : Numeric.LinearAlgebra.Tests.Instances | 4 | Module : Numeric.LinearAlgebra.Tests.Instances |
@@ -29,6 +29,10 @@ import Numeric.LinearAlgebra.HMatrix hiding (vector) | |||
29 | import Control.Monad(replicateM) | 29 | import Control.Monad(replicateM) |
30 | import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) | 30 | import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) |
31 | 31 | ||
32 | import GHC.TypeLits | ||
33 | import Data.Proxy (Proxy(..)) | ||
34 | import qualified Numeric.LinearAlgebra.Static as Static | ||
35 | |||
32 | 36 | ||
33 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] | 37 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] |
34 | shrinkListElementwise [] = [] | 38 | shrinkListElementwise [] = [] |
@@ -40,14 +44,25 @@ shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ] | |||
40 | 44 | ||
41 | chooseDim = sized $ \m -> choose (1,max 1 m) | 45 | chooseDim = sized $ \m -> choose (1,max 1 m) |
42 | 46 | ||
43 | instance (Field a, Arbitrary a) => Arbitrary (Vector a) where | 47 | instance (Field a, Arbitrary a) => Arbitrary (Vector a) where |
44 | arbitrary = do m <- chooseDim | 48 | arbitrary = do m <- chooseDim |
45 | l <- vector m | 49 | l <- vector m |
46 | return $ fromList l | 50 | return $ fromList l |
47 | -- shrink any one of the components | 51 | -- shrink any one of the components |
48 | shrink = map fromList . shrinkListElementwise . toList | 52 | shrink = map fromList . shrinkListElementwise . toList |
49 | 53 | ||
50 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | 54 | instance KnownNat n => Arbitrary (Static.R n) where |
55 | arbitrary = do | ||
56 | l <- vector n | ||
57 | return (Static.fromList l) | ||
58 | |||
59 | where | ||
60 | n :: Int | ||
61 | n = fromIntegral (natVal (Proxy :: Proxy n)) | ||
62 | |||
63 | shrink v = [] | ||
64 | |||
65 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | ||
51 | arbitrary = do | 66 | arbitrary = do |
52 | m <- chooseDim | 67 | m <- chooseDim |
53 | n <- chooseDim | 68 | n <- chooseDim |
@@ -57,9 +72,23 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | |||
57 | -- shrink any one of the components | 72 | -- shrink any one of the components |
58 | shrink a = map (rows a >< cols a) | 73 | shrink a = map (rows a >< cols a) |
59 | . shrinkListElementwise | 74 | . shrinkListElementwise |
60 | . concat . toLists | 75 | . concat . toLists |
61 | $ a | 76 | $ a |
62 | 77 | ||
78 | instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where | ||
79 | arbitrary = do | ||
80 | l <- vector (m * n) | ||
81 | return (Static.fromList l) | ||
82 | |||
83 | where | ||
84 | m :: Int | ||
85 | m = fromIntegral (natVal (Proxy :: Proxy m)) | ||
86 | |||
87 | n :: Int | ||
88 | n = fromIntegral (natVal (Proxy :: Proxy n)) | ||
89 | |||
90 | shrink mat = [] | ||
91 | |||
63 | -- a square matrix | 92 | -- a square matrix |
64 | newtype (Sq a) = Sq (Matrix a) deriving Show | 93 | newtype (Sq a) = Sq (Matrix a) deriving Show |
65 | instance (Element a, Arbitrary a) => Arbitrary (Sq a) where | 94 | instance (Element a, Arbitrary a) => Arbitrary (Sq a) where |
@@ -121,7 +150,7 @@ instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where | |||
121 | 150 | ||
122 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) | 151 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) |
123 | newtype (PosDef a) = PosDef (Matrix a) deriving Show | 152 | newtype (PosDef a) = PosDef (Matrix a) deriving Show |
124 | instance (Numeric a, ArbitraryField a, Num (Vector a)) | 153 | instance (Numeric a, ArbitraryField a, Num (Vector a)) |
125 | => Arbitrary (PosDef a) where | 154 | => Arbitrary (PosDef a) where |
126 | arbitrary = do | 155 | arbitrary = do |
127 | m <- arbitrary | 156 | m <- arbitrary |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs index 046644f..0de9f37 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -1,5 +1,6 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | {-# LANGUAGE TypeFamilies #-} | 2 | {-# LANGUAGE TypeFamilies #-} |
3 | {-# LANGUAGE DataKinds #-} | ||
3 | 4 | ||
4 | ----------------------------------------------------------------------------- | 5 | ----------------------------------------------------------------------------- |
5 | {- | | 6 | {- | |
@@ -39,12 +40,25 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
39 | expmDiagProp, | 40 | expmDiagProp, |
40 | multProp1, multProp2, | 41 | multProp1, multProp2, |
41 | subProp, | 42 | subProp, |
42 | linearSolveProp, linearSolvePropH, linearSolveProp2 | 43 | linearSolveProp, linearSolvePropH, linearSolveProp2, |
44 | |||
45 | -- Binary properties | ||
46 | vectorBinaryRoundtripProp | ||
47 | , staticVectorBinaryRoundtripProp | ||
48 | , matrixBinaryRoundtripProp | ||
49 | , staticMatrixBinaryRoundtripProp | ||
50 | , staticVectorBinaryFailProp | ||
43 | ) where | 51 | ) where |
44 | 52 | ||
45 | import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) | 53 | import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) |
54 | import qualified Numeric.LinearAlgebra.Static as Static | ||
46 | import Test.QuickCheck | 55 | import Test.QuickCheck |
47 | 56 | ||
57 | import Data.Binary | ||
58 | import Data.Binary.Get (runGet) | ||
59 | import Data.Either (isLeft) | ||
60 | import Debug.Trace (traceShowId) | ||
61 | |||
48 | (~=) :: Double -> Double -> Bool | 62 | (~=) :: Double -> Double -> Bool |
49 | a ~= b = abs (a - b) < 1e-10 | 63 | a ~= b = abs (a - b) < 1e-10 |
50 | 64 | ||
@@ -275,3 +289,31 @@ linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b) | |||
275 | 289 | ||
276 | subProp m = m == (conj . tr . fromColumns . toRows) m | 290 | subProp m = m == (conj . tr . fromColumns . toRows) m |
277 | 291 | ||
292 | ------------------------------------------------------------------ | ||
293 | |||
294 | vectorBinaryRoundtripProp :: Vector Double -> Bool | ||
295 | vectorBinaryRoundtripProp vec = decode (encode vec) == vec | ||
296 | |||
297 | staticVectorBinaryRoundtripProp :: Static.R 5 -> Bool | ||
298 | staticVectorBinaryRoundtripProp vec = | ||
299 | let | ||
300 | decoded = decode (encode vec) :: Static.R 500 | ||
301 | in | ||
302 | Static.extract decoded == Static.extract vec | ||
303 | |||
304 | matrixBinaryRoundtripProp :: Matrix Double -> Bool | ||
305 | matrixBinaryRoundtripProp mat = decode (encode mat) == mat | ||
306 | |||
307 | staticMatrixBinaryRoundtripProp :: Static.L 100 200 -> Bool | ||
308 | staticMatrixBinaryRoundtripProp mat = | ||
309 | let | ||
310 | decoded = decode (encode mat) :: Static.L 100 200 | ||
311 | in | ||
312 | (Static.extract decoded) == (Static.extract mat) | ||
313 | |||
314 | staticVectorBinaryFailProp :: Static.R 20 -> Bool | ||
315 | staticVectorBinaryFailProp vec = | ||
316 | let | ||
317 | decoded = runGet get (encode vec) :: Either String (Static.R 50) | ||
318 | in | ||
319 | isLeft decoded | ||