summaryrefslogtreecommitdiff
path: root/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs')
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs251
1 files changed, 251 insertions, 0 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
new file mode 100644
index 0000000..647a06c
--- /dev/null
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -0,0 +1,251 @@
1{-# LANGUAGE FlexibleContexts, UndecidableInstances, CPP, FlexibleInstances #-}
2{-# OPTIONS_GHC -fno-warn-unused-imports #-}
3-----------------------------------------------------------------------------
4{- |
5Module : Numeric.LinearAlgebra.Tests.Instances
6Copyright : (c) Alberto Ruiz 2008
7License : GPL-style
8
9Maintainer : Alberto Ruiz (aruiz at um dot es)
10Stability : provisional
11Portability : portable
12
13Arbitrary instances for vectors, matrices.
14
15-}
16
17module Numeric.LinearAlgebra.Tests.Instances(
18 Sq(..), rSq,cSq,
19 Rot(..), rRot,cRot,
20 Her(..), rHer,cHer,
21 WC(..), rWC,cWC,
22 SqWC(..), rSqWC, cSqWC,
23 PosDef(..), rPosDef, cPosDef,
24 Consistent(..), rConsist, cConsist,
25 RM,CM, rM,cM,
26 FM,ZM, fM,zM
27) where
28
29import System.Random
30
31import Numeric.LinearAlgebra
32import Control.Monad(replicateM)
33import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector
34 ,sized,classify,Testable,Property
35 ,quickCheckWith,maxSize,stdArgs,shrink)
36
37#if MIN_VERSION_QuickCheck(2,0,0)
38shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]]
39shrinkListElementwise [] = []
40shrinkListElementwise (x:xs) = [ y:xs | y <- shrink x ]
41 ++ [ x:ys | ys <- shrinkListElementwise xs ]
42
43shrinkPair :: (Arbitrary a, Arbitrary b) => (a,b) -> [(a,b)]
44shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ]
45#endif
46
47#if MIN_VERSION_QuickCheck(2,1,1)
48#else
49instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where
50 arbitrary = do
51 re <- arbitrary
52 im <- arbitrary
53 return (re :+ im)
54
55#if MIN_VERSION_QuickCheck(2,0,0)
56 shrink (re :+ im) =
57 [ u :+ v | (u,v) <- shrinkPair (re,im) ]
58#else
59 -- this has been moved to the 'Coarbitrary' class in QuickCheck 2
60 coarbitrary = undefined
61#endif
62
63#endif
64
65chooseDim = sized $ \m -> choose (1,max 1 m)
66
67instance (Field a, Arbitrary a) => Arbitrary (Vector a) where
68 arbitrary = do m <- chooseDim
69 l <- vector m
70 return $ fromList l
71
72#if MIN_VERSION_QuickCheck(2,0,0)
73 -- shrink any one of the components
74 shrink = map fromList . shrinkListElementwise . toList
75
76#else
77 coarbitrary = undefined
78#endif
79
80instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where
81 arbitrary = do
82 m <- chooseDim
83 n <- chooseDim
84 l <- vector (m*n)
85 return $ (m><n) l
86
87#if MIN_VERSION_QuickCheck(2,0,0)
88 -- shrink any one of the components
89 shrink a = map (rows a >< cols a)
90 . shrinkListElementwise
91 . concat . toLists
92 $ a
93#else
94 coarbitrary = undefined
95#endif
96
97
98-- a square matrix
99newtype (Sq a) = Sq (Matrix a) deriving Show
100instance (Element a, Arbitrary a) => Arbitrary (Sq a) where
101 arbitrary = do
102 n <- chooseDim
103 l <- vector (n*n)
104 return $ Sq $ (n><n) l
105
106#if MIN_VERSION_QuickCheck(2,0,0)
107 shrink (Sq a) = [ Sq b | b <- shrink a ]
108#else
109 coarbitrary = undefined
110#endif
111
112
113-- a unitary matrix
114newtype (Rot a) = Rot (Matrix a) deriving Show
115instance (Field a, Arbitrary a) => Arbitrary (Rot a) where
116 arbitrary = do
117 Sq m <- arbitrary
118 let (q,_) = qr m
119 return (Rot q)
120
121#if MIN_VERSION_QuickCheck(2,0,0)
122#else
123 coarbitrary = undefined
124#endif
125
126
127-- a complex hermitian or real symmetric matrix
128newtype (Her a) = Her (Matrix a) deriving Show
129instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Her a) where
130 arbitrary = do
131 Sq m <- arbitrary
132 let m' = m/2
133 return $ Her (m' + ctrans m')
134
135#if MIN_VERSION_QuickCheck(2,0,0)
136#else
137 coarbitrary = undefined
138#endif
139
140class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a
141instance ArbitraryField Double
142instance ArbitraryField (Complex Double)
143
144
145-- a well-conditioned general matrix (the singular values are between 1 and 100)
146newtype (WC a) = WC (Matrix a) deriving Show
147instance (ArbitraryField a) => Arbitrary (WC a) where
148 arbitrary = do
149 m <- arbitrary
150 let (u,_,v) = svd m
151 r = rows m
152 c = cols m
153 n = min r c
154 sv' <- replicateM n (choose (1,100))
155 let s = diagRect 0 (fromList sv') r c
156 return $ WC (u <> real s <> trans v)
157
158#if MIN_VERSION_QuickCheck(2,0,0)
159#else
160 coarbitrary = undefined
161#endif
162
163
164-- a well-conditioned square matrix (the singular values are between 1 and 100)
165newtype (SqWC a) = SqWC (Matrix a) deriving Show
166instance (ArbitraryField a) => Arbitrary (SqWC a) where
167 arbitrary = do
168 Sq m <- arbitrary
169 let (u,_,v) = svd m
170 n = rows m
171 sv' <- replicateM n (choose (1,100))
172 let s = diag (fromList sv')
173 return $ SqWC (u <> real s <> trans v)
174
175#if MIN_VERSION_QuickCheck(2,0,0)
176#else
177 coarbitrary = undefined
178#endif
179
180
181-- a positive definite square matrix (the eigenvalues are between 0 and 100)
182newtype (PosDef a) = PosDef (Matrix a) deriving Show
183instance (ArbitraryField a, Num (Vector a))
184 => Arbitrary (PosDef a) where
185 arbitrary = do
186 Her m <- arbitrary
187 let (_,v) = eigSH m
188 n = rows m
189 l <- replicateM n (choose (0,100))
190 let s = diag (fromList l)
191 p = v <> real s <> ctrans v
192 return $ PosDef (0.5 * p + 0.5 * ctrans p)
193
194#if MIN_VERSION_QuickCheck(2,0,0)
195#else
196 coarbitrary = undefined
197#endif
198
199
200-- a pair of matrices that can be multiplied
201newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show
202instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where
203 arbitrary = do
204 n <- chooseDim
205 k <- chooseDim
206 m <- chooseDim
207 la <- vector (n*k)
208 lb <- vector (k*m)
209 return $ Consistent ((n><k) la, (k><m) lb)
210
211#if MIN_VERSION_QuickCheck(2,0,0)
212 shrink (Consistent (x,y)) = [ Consistent (u,v) | (u,v) <- shrinkPair (x,y) ]
213#else
214 coarbitrary = undefined
215#endif
216
217
218
219type RM = Matrix Double
220type CM = Matrix (Complex Double)
221type FM = Matrix Float
222type ZM = Matrix (Complex Float)
223
224
225rM m = m :: RM
226cM m = m :: CM
227fM m = m :: FM
228zM m = m :: ZM
229
230
231rHer (Her m) = m :: RM
232cHer (Her m) = m :: CM
233
234rRot (Rot m) = m :: RM
235cRot (Rot m) = m :: CM
236
237rSq (Sq m) = m :: RM
238cSq (Sq m) = m :: CM
239
240rWC (WC m) = m :: RM
241cWC (WC m) = m :: CM
242
243rSqWC (SqWC m) = m :: RM
244cSqWC (SqWC m) = m :: CM
245
246rPosDef (PosDef m) = m :: RM
247cPosDef (PosDef m) = m :: CM
248
249rConsist (Consistent (a,b)) = (a,b::RM)
250cConsist (Consistent (a,b)) = (a,b::CM)
251