summaryrefslogtreecommitdiff
path: root/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs')
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs46
1 files changed, 41 insertions, 5 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
index 3d5441d..23d7e6f 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -1,4 +1,4 @@
1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} 1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3{- | 3{- |
4Module : Numeric.LinearAlgebra.Tests.Instances 4Module : Numeric.LinearAlgebra.Tests.Instances
@@ -29,6 +29,10 @@ import Numeric.LinearAlgebra.HMatrix hiding (vector)
29import Control.Monad(replicateM) 29import Control.Monad(replicateM)
30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) 30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink)
31 31
32import GHC.TypeLits
33import Data.Proxy (Proxy)
34import qualified Numeric.LinearAlgebra.Static as Static
35
32 36
33shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] 37shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]]
34shrinkListElementwise [] = [] 38shrinkListElementwise [] = []
@@ -40,14 +44,27 @@ shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ]
40 44
41chooseDim = sized $ \m -> choose (1,max 1 m) 45chooseDim = sized $ \m -> choose (1,max 1 m)
42 46
43instance (Field a, Arbitrary a) => Arbitrary (Vector a) where 47instance (Field a, Arbitrary a) => Arbitrary (Vector a) where
44 arbitrary = do m <- chooseDim 48 arbitrary = do m <- chooseDim
45 l <- vector m 49 l <- vector m
46 return $ fromList l 50 return $ fromList l
47 -- shrink any one of the components 51 -- shrink any one of the components
48 shrink = map fromList . shrinkListElementwise . toList 52 shrink = map fromList . shrinkListElementwise . toList
49 53
50instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where 54instance KnownNat n => Arbitrary (Static.R n) where
55 arbitrary = do
56 l <- vector n
57 return (Static.fromList l)
58
59 where proxy :: Proxy n
60 proxy = proxy
61
62 n :: Int
63 n = fromIntegral (natVal proxy)
64
65 shrink v = []
66
67instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
51 arbitrary = do 68 arbitrary = do
52 m <- chooseDim 69 m <- chooseDim
53 n <- chooseDim 70 n <- chooseDim
@@ -57,9 +74,28 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
57 -- shrink any one of the components 74 -- shrink any one of the components
58 shrink a = map (rows a >< cols a) 75 shrink a = map (rows a >< cols a)
59 . shrinkListElementwise 76 . shrinkListElementwise
60 . concat . toLists 77 . concat . toLists
61 $ a 78 $ a
62 79
80instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where
81 arbitrary = do
82 l <- vector (m * n)
83 return (Static.fromList l)
84
85 where proxyM :: Proxy m
86 proxyM = proxyM
87
88 proxyN :: Proxy n
89 proxyN = proxyN
90
91 m :: Int
92 m = fromIntegral (natVal proxyM)
93
94 n :: Int
95 n = fromIntegral (natVal proxyN)
96
97 shrink mat = []
98
63-- a square matrix 99-- a square matrix
64newtype (Sq a) = Sq (Matrix a) deriving Show 100newtype (Sq a) = Sq (Matrix a) deriving Show
65instance (Element a, Arbitrary a) => Arbitrary (Sq a) where 101instance (Element a, Arbitrary a) => Arbitrary (Sq a) where
@@ -121,7 +157,7 @@ instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where
121 157
122-- a positive definite square matrix (the eigenvalues are between 0 and 100) 158-- a positive definite square matrix (the eigenvalues are between 0 and 100)
123newtype (PosDef a) = PosDef (Matrix a) deriving Show 159newtype (PosDef a) = PosDef (Matrix a) deriving Show
124instance (Numeric a, ArbitraryField a, Num (Vector a)) 160instance (Numeric a, ArbitraryField a, Num (Vector a))
125 => Arbitrary (PosDef a) where 161 => Arbitrary (PosDef a) where
126 arbitrary = do 162 arbitrary = do
127 m <- arbitrary 163 m <- arbitrary