summaryrefslogtreecommitdiff
path: root/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs')
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs39
1 files changed, 34 insertions, 5 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
index 3d5441d..37f7da2 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -1,4 +1,4 @@
1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances #-} 1{-# LANGUAGE FlexibleContexts, UndecidableInstances, FlexibleInstances, ScopedTypeVariables #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3{- | 3{- |
4Module : Numeric.LinearAlgebra.Tests.Instances 4Module : Numeric.LinearAlgebra.Tests.Instances
@@ -29,6 +29,10 @@ import Numeric.LinearAlgebra.HMatrix hiding (vector)
29import Control.Monad(replicateM) 29import Control.Monad(replicateM)
30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) 30import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink)
31 31
32import GHC.TypeLits
33import Data.Proxy (Proxy(..))
34import qualified Numeric.LinearAlgebra.Static as Static
35
32 36
33shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] 37shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]]
34shrinkListElementwise [] = [] 38shrinkListElementwise [] = []
@@ -40,14 +44,25 @@ shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ]
40 44
41chooseDim = sized $ \m -> choose (1,max 1 m) 45chooseDim = sized $ \m -> choose (1,max 1 m)
42 46
43instance (Field a, Arbitrary a) => Arbitrary (Vector a) where 47instance (Field a, Arbitrary a) => Arbitrary (Vector a) where
44 arbitrary = do m <- chooseDim 48 arbitrary = do m <- chooseDim
45 l <- vector m 49 l <- vector m
46 return $ fromList l 50 return $ fromList l
47 -- shrink any one of the components 51 -- shrink any one of the components
48 shrink = map fromList . shrinkListElementwise . toList 52 shrink = map fromList . shrinkListElementwise . toList
49 53
50instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where 54instance KnownNat n => Arbitrary (Static.R n) where
55 arbitrary = do
56 l <- vector n
57 return (Static.fromList l)
58
59 where
60 n :: Int
61 n = fromIntegral (natVal (Proxy :: Proxy n))
62
63 shrink v = []
64
65instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
51 arbitrary = do 66 arbitrary = do
52 m <- chooseDim 67 m <- chooseDim
53 n <- chooseDim 68 n <- chooseDim
@@ -57,9 +72,23 @@ instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
57 -- shrink any one of the components 72 -- shrink any one of the components
58 shrink a = map (rows a >< cols a) 73 shrink a = map (rows a >< cols a)
59 . shrinkListElementwise 74 . shrinkListElementwise
60 . concat . toLists 75 . concat . toLists
61 $ a 76 $ a
62 77
78instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where
79 arbitrary = do
80 l <- vector (m * n)
81 return (Static.fromList l)
82
83 where
84 m :: Int
85 m = fromIntegral (natVal (Proxy :: Proxy m))
86
87 n :: Int
88 n = fromIntegral (natVal (Proxy :: Proxy n))
89
90 shrink mat = []
91
63-- a square matrix 92-- a square matrix
64newtype (Sq a) = Sq (Matrix a) deriving Show 93newtype (Sq a) = Sq (Matrix a) deriving Show
65instance (Element a, Arbitrary a) => Arbitrary (Sq a) where 94instance (Element a, Arbitrary a) => Arbitrary (Sq a) where
@@ -121,7 +150,7 @@ instance (ArbitraryField a, Numeric a) => Arbitrary (SqWC a) where
121 150
122-- a positive definite square matrix (the eigenvalues are between 0 and 100) 151-- a positive definite square matrix (the eigenvalues are between 0 and 100)
123newtype (PosDef a) = PosDef (Matrix a) deriving Show 152newtype (PosDef a) = PosDef (Matrix a) deriving Show
124instance (Numeric a, ArbitraryField a, Num (Vector a)) 153instance (Numeric a, ArbitraryField a, Num (Vector a))
125 => Arbitrary (PosDef a) where 154 => Arbitrary (PosDef a) where
126 arbitrary = do 155 arbitrary = do
127 m <- arbitrary 156 m <- arbitrary