diff options
Diffstat (limited to 'packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs')
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | 46 |
1 files changed, 41 insertions, 5 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index 3d5441d..23d7e6f 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} | 1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | {- | | 3 | {- | |
4 | Module : Numeric.LinearAlgebra.Tests.Instances | 4 | Module : Numeric.LinearAlgebra.Tests.Instances |
@@ -29,6 +29,10 @@ import Numeric.LinearAlgebra.HMatrix hiding (vector) | |||
29 | import Control.Monad(replicateM) | 29 | import Control.Monad(replicateM) |
30 | import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) | 30 | import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) |
31 | 31 | ||
32 | import GHC.TypeLits | ||
33 | import Data.Proxy (Proxy) | ||
34 | import qualified Numeric.LinearAlgebra.Static as Static | ||
35 | |||
32 | 36 | ||
33 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] | 37 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] |
34 | shrinkListElementwise [] = [] | 38 | shrinkListElementwise [] = [] |
@@ -40,14 +44,27 @@ shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ] | |||
40 | 44 | ||
41 | chooseDim = sized $ \m -> choose (1,max 1 m) | 45 | chooseDim = sized $ \m -> choose (1,max 1 m) |
42 | 46 | ||
43 | instance (Field a, Arbitrary a) => Arbitrary (Vector a) where | 47 | instance (Field a, Arbitrary a) => Arbitrary (Vector a) where |
44 | arbitrary = do m <- chooseDim | 48 | arbitrary = do m <- chooseDim |
45 | l <- vector m | 49 | l <- vector m |
46 | return $ fromList l | 50 | return $ fromList l |
47 | -- shrink any one of the components | 51 | -- shrink any one of the components |
48 | shrink = map fromList . shrinkListElementwise . toList | 52 | shrink = map fromList . shrinkListElementwise . toList |
49 | 53 | ||
50 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | 54 | instance KnownNat n => Arbitrary (Static.R n) where |
55 | arbitrary = do | ||
56 | l <- vector n | ||
57 | return (Static.fromList l) | ||
58 | |||
59 | where proxy :: Proxy n | ||
60 | proxy = proxy | ||
61 | |||
62 | n :: Int | ||
63 | n = fromIntegral (natVal proxy) | ||
64 | |||
65 | shrink v = [] | ||
66 | |||
67 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | ||
51 | arbitrary = do | 68 | arbitrary = do |
52 | m <- chooseDim | 69 | m <- chooseDim |
53 | n <- chooseDim | 70 | n <- chooseDim |
@@ -57,9 +74,28 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | |||
57 | -- shrink any one of the components | 74 | -- shrink any one of the components |
58 | shrink a = map (rows a >< cols a) | 75 | shrink a = map (rows a >< cols a) |
59 | . shrinkListElementwise | 76 | . shrinkListElementwise |
60 | . concat . toLists | 77 | . concat . toLists |
61 | $ a | 78 | $ a |
62 | 79 | ||
80 | instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where | ||
81 | arbitrary = do | ||
82 | l <- vector (m * n) | ||
83 | return (Static.fromList l) | ||
84 | |||
85 | where proxyM :: Proxy m | ||
86 | proxyM = proxyM | ||
87 | |||
88 | proxyN :: Proxy n | ||
89 | proxyN = proxyN | ||
90 | |||
91 | m :: Int | ||
92 | m = fromIntegral (natVal proxyM) | ||
93 | |||
94 | n :: Int | ||
95 | n = fromIntegral (natVal proxyN) | ||
96 | |||
97 | shrink mat = [] | ||
98 | |||
63 | -- a square matrix | 99 | -- a square matrix |
64 | newtype (Sq a) = Sq (Matrix a) deriving Show | 100 | newtype (Sq a) = Sq (Matrix a) deriving Show |
65 | instance (Element a, Arbitrary a) => Arbitrary (Sq a) where | 101 | instance (Element a, Arbitrary a) => Arbitrary (Sq a) where |
@@ -121,7 +157,7 @@ instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where | |||
121 | 157 | ||
122 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) | 158 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) |
123 | newtype (PosDef a) = PosDef (Matrix a) deriving Show | 159 | newtype (PosDef a) = PosDef (Matrix a) deriving Show |
124 | instance (Numeric a, ArbitraryField a, Num (Vector a)) | 160 | instance (Numeric a, ArbitraryField a, Num (Vector a)) |
125 | => Arbitrary (PosDef a) where | 161 | => Arbitrary (PosDef a) where |
126 | arbitrary = do | 162 | arbitrary = do |
127 | m <- arbitrary | 163 | m <- arbitrary |