summaryrefslogtreecommitdiff
path: root/packages/tests/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2016-03-15 20:33:32 +0100
committerAlberto Ruiz <aruiz@um.es>2016-03-15 20:33:32 +0100
commitf5e235bbdb4bc342b623676b07245d781a9fb994 (patch)
tree9c3b7331f40b3fa773de7ce1f09460c58c8e272f /packages/tests/src
parent6a0bf038091e453115a3451c040cbe790e770b89 (diff)
parent80e88bbb1fef8b904e5e01d3ca6cc35a97339cda (diff)
Merge pull request #178 from sid-kap/matrix_binary
Add binary instances for static matrix and vector
Diffstat (limited to 'packages/tests/src')
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs10
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs39
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs44
3 files changed, 87 insertions, 6 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index d9cc3b6..043ebf3 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -26,6 +26,7 @@ module Numeric.LinearAlgebra.Tests(
26 utest, 26 utest,
27 runTests, 27 runTests,
28 runBenchmarks 28 runBenchmarks
29 , binaryTests
29-- , findNaN 30-- , findNaN
30--, runBigTests 31--, runBigTests
31) where 32) where
@@ -743,6 +744,15 @@ makeUnitary v | realPart n > 1 = v / scalar n
743 | otherwise = v 744 | otherwise = v
744 where n = sqrt (v `dot` v) 745 where n = sqrt (v `dot` v)
745 746
747binaryTests :: IO ()
748binaryTests = do
749 let test :: forall t . T.Testable t => t -> IO ()
750 test = qCheck 100
751 test vectorBinaryRoundtripProp
752 test staticVectorBinaryRoundtripProp
753 qCheck 30 matrixBinaryRoundtripProp
754 qCheck 30 staticMatrixBinaryRoundtripProp
755
746-- -- | Some additional tests on big matrices. They take a few minutes. 756-- -- | Some additional tests on big matrices. They take a few minutes.
747-- runBigTests :: IO () 757-- runBigTests :: IO ()
748-- runBigTests = undefined 758-- runBigTests = undefined
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
index 3d5441d..37f7da2 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -1,4 +1,4 @@
1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} 1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3{- | 3{- |
4Module : Numeric.LinearAlgebra.Tests.Instances 4Module : Numeric.LinearAlgebra.Tests.Instances
@@ -29,6 +29,10 @@ import Numeric.LinearAlgebra.HMatrix hiding (vector)
29import Control.Monad(replicateM) 29import Control.Monad(replicateM)
30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) 30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink)
31 31
32import GHC.TypeLits
33import Data.Proxy (Proxy(..))
34import qualified Numeric.LinearAlgebra.Static as Static
35
32 36
33shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] 37shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]]
34shrinkListElementwise [] = [] 38shrinkListElementwise [] = []
@@ -40,14 +44,25 @@ shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ]
40 44
41chooseDim = sized $ \m -> choose (1,max 1 m) 45chooseDim = sized $ \m -> choose (1,max 1 m)
42 46
43instance (Field a, Arbitrary a) => Arbitrary (Vector a) where 47instance (Field a, Arbitrary a) => Arbitrary (Vector a) where
44 arbitrary = do m <- chooseDim 48 arbitrary = do m <- chooseDim
45 l <- vector m 49 l <- vector m
46 return $ fromList l 50 return $ fromList l
47 -- shrink any one of the components 51 -- shrink any one of the components
48 shrink = map fromList . shrinkListElementwise . toList 52 shrink = map fromList . shrinkListElementwise . toList
49 53
50instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where 54instance KnownNat n => Arbitrary (Static.R n) where
55 arbitrary = do
56 l <- vector n
57 return (Static.fromList l)
58
59 where
60 n :: Int
61 n = fromIntegral (natVal (Proxy :: Proxy n))
62
63 shrink v = []
64
65instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
51 arbitrary = do 66 arbitrary = do
52 m <- chooseDim 67 m <- chooseDim
53 n <- chooseDim 68 n <- chooseDim
@@ -57,9 +72,23 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
57 -- shrink any one of the components 72 -- shrink any one of the components
58 shrink a = map (rows a >< cols a) 73 shrink a = map (rows a >< cols a)
59 . shrinkListElementwise 74 . shrinkListElementwise
60 . concat . toLists 75 . concat . toLists
61 $ a 76 $ a
62 77
78instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where
79 arbitrary = do
80 l <- vector (m * n)
81 return (Static.fromList l)
82
83 where
84 m :: Int
85 m = fromIntegral (natVal (Proxy :: Proxy m))
86
87 n :: Int
88 n = fromIntegral (natVal (Proxy :: Proxy n))
89
90 shrink mat = []
91
63-- a square matrix 92-- a square matrix
64newtype (Sq a) = Sq (Matrix a) deriving Show 93newtype (Sq a) = Sq (Matrix a) deriving Show
65instance (Element a, Arbitrary a) => Arbitrary (Sq a) where 94instance (Element a, Arbitrary a) => Arbitrary (Sq a) where
@@ -121,7 +150,7 @@ instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where
121 150
122-- a positive definite square matrix (the eigenvalues are between 0 and 100) 151-- a positive definite square matrix (the eigenvalues are between 0 and 100)
123newtype (PosDef a) = PosDef (Matrix a) deriving Show 152newtype (PosDef a) = PosDef (Matrix a) deriving Show
124instance (Numeric a, ArbitraryField a, Num (Vector a)) 153instance (Numeric a, ArbitraryField a, Num (Vector a))
125 => Arbitrary (PosDef a) where 154 => Arbitrary (PosDef a) where
126 arbitrary = do 155 arbitrary = do
127 m <- arbitrary 156 m <- arbitrary
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
index 046644f..0de9f37 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE FlexibleContexts #-} 1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE TypeFamilies #-} 2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE DataKinds #-}
3 4
4----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
5{- | 6{- |
@@ -39,12 +40,25 @@ module Numeric.LinearAlgebra.Tests.Properties (
39 expmDiagProp, 40 expmDiagProp,
40 multProp1, multProp2, 41 multProp1, multProp2,
41 subProp, 42 subProp,
42 linearSolveProp, linearSolvePropH, linearSolveProp2 43 linearSolveProp, linearSolvePropH, linearSolveProp2,
44
45 -- Binary properties
46 vectorBinaryRoundtripProp
47 , staticVectorBinaryRoundtripProp
48 , matrixBinaryRoundtripProp
49 , staticMatrixBinaryRoundtripProp
50 , staticVectorBinaryFailProp
43) where 51) where
44 52
45import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary) 53import Numeric.LinearAlgebra.HMatrix hiding (Testable,unitary)
54import qualified Numeric.LinearAlgebra.Static as Static
46import Test.QuickCheck 55import Test.QuickCheck
47 56
57import Data.Binary
58import Data.Binary.Get (runGet)
59import Data.Either (isLeft)
60import Debug.Trace (traceShowId)
61
48(~=) :: Double -> Double -> Bool 62(~=) :: Double -> Double -> Bool
49a ~= b = abs (a - b) < 1e-10 63a ~= b = abs (a - b) < 1e-10
50 64
@@ -275,3 +289,31 @@ linearSolveProp2 f (a,x) = not wc `trivial` (not wc || a <> f a b |~| b)
275 289
276subProp m = m == (conj . tr . fromColumns . toRows) m 290subProp m = m == (conj . tr . fromColumns . toRows) m
277 291
292------------------------------------------------------------------
293
294vectorBinaryRoundtripProp :: Vector Double -> Bool
295vectorBinaryRoundtripProp vec = decode (encode vec) == vec
296
297staticVectorBinaryRoundtripProp :: Static.R 5 -> Bool
298staticVectorBinaryRoundtripProp vec =
299 let
300 decoded = decode (encode vec) :: Static.R 500
301 in
302 Static.extract decoded == Static.extract vec
303
304matrixBinaryRoundtripProp :: Matrix Double -> Bool
305matrixBinaryRoundtripProp mat = decode (encode mat) == mat
306
307staticMatrixBinaryRoundtripProp :: Static.L 100 200 -> Bool
308staticMatrixBinaryRoundtripProp mat =
309 let
310 decoded = decode (encode mat) :: Static.L 100 200
311 in
312 (Static.extract decoded) == (Static.extract mat)
313
314staticVectorBinaryFailProp :: Static.R 20 -> Bool
315staticVectorBinaryFailProp vec =
316 let
317 decoded = runGet get (encode vec) :: Either String (Static.R 50)
318 in
319 isLeft decoded