diff options
Diffstat (limited to 'packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs')
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs new file mode 100644 index 0000000..647a06c --- /dev/null +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -0,0 +1,251 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, UndecidableInstances, CPP, FlexibleInstances #-} | ||
2 | {-# OPTIONS_GHC -fno-warn-unused-imports #-} | ||
3 | ----------------------------------------------------------------------------- | ||
4 | {- | | ||
5 | Module : Numeric.LinearAlgebra.Tests.Instances | ||
6 | Copyright : (c) Alberto Ruiz 2008 | ||
7 | License : GPL-style | ||
8 | |||
9 | Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
10 | Stability : provisional | ||
11 | Portability : portable | ||
12 | |||
13 | Arbitrary instances for vectors, matrices. | ||
14 | |||
15 | -} | ||
16 | |||
17 | module Numeric.LinearAlgebra.Tests.Instances( | ||
18 | Sq(..), rSq,cSq, | ||
19 | Rot(..), rRot,cRot, | ||
20 | Her(..), rHer,cHer, | ||
21 | WC(..), rWC,cWC, | ||
22 | SqWC(..), rSqWC, cSqWC, | ||
23 | PosDef(..), rPosDef, cPosDef, | ||
24 | Consistent(..), rConsist, cConsist, | ||
25 | RM,CM, rM,cM, | ||
26 | FM,ZM, fM,zM | ||
27 | ) where | ||
28 | |||
29 | import System.Random | ||
30 | |||
31 | import Numeric.LinearAlgebra | ||
32 | import Control.Monad(replicateM) | ||
33 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | ||
34 | ,sized,classify,Testable,Property | ||
35 | ,quickCheckWith,maxSize,stdArgs,shrink) | ||
36 | |||
37 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
38 | shrinkListElementwise :: (Arbitrary a) => [a] -> [[a]] | ||
39 | shrinkListElementwise [] = [] | ||
40 | shrinkListElementwise (x:xs) = [ y:xs | y <- shrink x ] | ||
41 | ++ [ x:ys | ys <- shrinkListElementwise xs ] | ||
42 | |||
43 | shrinkPair :: (Arbitrary a, Arbitrary b) => (a,b) -> [(a,b)] | ||
44 | shrinkPair (a,b) = [ (a,x) | x <- shrink b ] ++ [ (x,b) | x <- shrink a ] | ||
45 | #endif | ||
46 | |||
47 | #if MIN_VERSION_QuickCheck(2,1,1) | ||
48 | #else | ||
49 | instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where | ||
50 | arbitrary = do | ||
51 | re <- arbitrary | ||
52 | im <- arbitrary | ||
53 | return (re :+ im) | ||
54 | |||
55 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
56 | shrink (re :+ im) = | ||
57 | [ u :+ v | (u,v) <- shrinkPair (re,im) ] | ||
58 | #else | ||
59 | -- this has been moved to the 'Coarbitrary' class in QuickCheck 2 | ||
60 | coarbitrary = undefined | ||
61 | #endif | ||
62 | |||
63 | #endif | ||
64 | |||
65 | chooseDim = sized $ \m -> choose (1,max 1 m) | ||
66 | |||
67 | instance (Field a, Arbitrary a) => Arbitrary (Vector a) where | ||
68 | arbitrary = do m <- chooseDim | ||
69 | l <- vector m | ||
70 | return $ fromList l | ||
71 | |||
72 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
73 | -- shrink any one of the components | ||
74 | shrink = map fromList . shrinkListElementwise . toList | ||
75 | |||
76 | #else | ||
77 | coarbitrary = undefined | ||
78 | #endif | ||
79 | |||
80 | instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where | ||
81 | arbitrary = do | ||
82 | m <- chooseDim | ||
83 | n <- chooseDim | ||
84 | l <- vector (m*n) | ||
85 | return $ (m><n) l | ||
86 | |||
87 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
88 | -- shrink any one of the components | ||
89 | shrink a = map (rows a >< cols a) | ||
90 | . shrinkListElementwise | ||
91 | . concat . toLists | ||
92 | $ a | ||
93 | #else | ||
94 | coarbitrary = undefined | ||
95 | #endif | ||
96 | |||
97 | |||
98 | -- a square matrix | ||
99 | newtype (Sq a) = Sq (Matrix a) deriving Show | ||
100 | instance (Element a, Arbitrary a) => Arbitrary (Sq a) where | ||
101 | arbitrary = do | ||
102 | n <- chooseDim | ||
103 | l <- vector (n*n) | ||
104 | return $ Sq $ (n><n) l | ||
105 | |||
106 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
107 | shrink (Sq a) = [ Sq b | b <- shrink a ] | ||
108 | #else | ||
109 | coarbitrary = undefined | ||
110 | #endif | ||
111 | |||
112 | |||
113 | -- a unitary matrix | ||
114 | newtype (Rot a) = Rot (Matrix a) deriving Show | ||
115 | instance (Field a, Arbitrary a) => Arbitrary (Rot a) where | ||
116 | arbitrary = do | ||
117 | Sq m <- arbitrary | ||
118 | let (q,_) = qr m | ||
119 | return (Rot q) | ||
120 | |||
121 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
122 | #else | ||
123 | coarbitrary = undefined | ||
124 | #endif | ||
125 | |||
126 | |||
127 | -- a complex hermitian or real symmetric matrix | ||
128 | newtype (Her a) = Her (Matrix a) deriving Show | ||
129 | instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Her a) where | ||
130 | arbitrary = do | ||
131 | Sq m <- arbitrary | ||
132 | let m' = m/2 | ||
133 | return $ Her (m' + ctrans m') | ||
134 | |||
135 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
136 | #else | ||
137 | coarbitrary = undefined | ||
138 | #endif | ||
139 | |||
140 | class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a | ||
141 | instance ArbitraryField Double | ||
142 | instance ArbitraryField (Complex Double) | ||
143 | |||
144 | |||
145 | -- a well-conditioned general matrix (the singular values are between 1 and 100) | ||
146 | newtype (WC a) = WC (Matrix a) deriving Show | ||
147 | instance (ArbitraryField a) => Arbitrary (WC a) where | ||
148 | arbitrary = do | ||
149 | m <- arbitrary | ||
150 | let (u,_,v) = svd m | ||
151 | r = rows m | ||
152 | c = cols m | ||
153 | n = min r c | ||
154 | sv' <- replicateM n (choose (1,100)) | ||
155 | let s = diagRect 0 (fromList sv') r c | ||
156 | return $ WC (u <> real s <> trans v) | ||
157 | |||
158 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
159 | #else | ||
160 | coarbitrary = undefined | ||
161 | #endif | ||
162 | |||
163 | |||
164 | -- a well-conditioned square matrix (the singular values are between 1 and 100) | ||
165 | newtype (SqWC a) = SqWC (Matrix a) deriving Show | ||
166 | instance (ArbitraryField a) => Arbitrary (SqWC a) where | ||
167 | arbitrary = do | ||
168 | Sq m <- arbitrary | ||
169 | let (u,_,v) = svd m | ||
170 | n = rows m | ||
171 | sv' <- replicateM n (choose (1,100)) | ||
172 | let s = diag (fromList sv') | ||
173 | return $ SqWC (u <> real s <> trans v) | ||
174 | |||
175 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
176 | #else | ||
177 | coarbitrary = undefined | ||
178 | #endif | ||
179 | |||
180 | |||
181 | -- a positive definite square matrix (the eigenvalues are between 0 and 100) | ||
182 | newtype (PosDef a) = PosDef (Matrix a) deriving Show | ||
183 | instance (ArbitraryField a, Num (Vector a)) | ||
184 | => Arbitrary (PosDef a) where | ||
185 | arbitrary = do | ||
186 | Her m <- arbitrary | ||
187 | let (_,v) = eigSH m | ||
188 | n = rows m | ||
189 | l <- replicateM n (choose (0,100)) | ||
190 | let s = diag (fromList l) | ||
191 | p = v <> real s <> ctrans v | ||
192 | return $ PosDef (0.5 * p + 0.5 * ctrans p) | ||
193 | |||
194 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
195 | #else | ||
196 | coarbitrary = undefined | ||
197 | #endif | ||
198 | |||
199 | |||
200 | -- a pair of matrices that can be multiplied | ||
201 | newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show | ||
202 | instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where | ||
203 | arbitrary = do | ||
204 | n <- chooseDim | ||
205 | k <- chooseDim | ||
206 | m <- chooseDim | ||
207 | la <- vector (n*k) | ||
208 | lb <- vector (k*m) | ||
209 | return $ Consistent ((n><k) la, (k><m) lb) | ||
210 | |||
211 | #if MIN_VERSION_QuickCheck(2,0,0) | ||
212 | shrink (Consistent (x,y)) = [ Consistent (u,v) | (u,v) <- shrinkPair (x,y) ] | ||
213 | #else | ||
214 | coarbitrary = undefined | ||
215 | #endif | ||
216 | |||
217 | |||
218 | |||
219 | type RM = Matrix Double | ||
220 | type CM = Matrix (Complex Double) | ||
221 | type FM = Matrix Float | ||
222 | type ZM = Matrix (Complex Float) | ||
223 | |||
224 | |||
225 | rM m = m :: RM | ||
226 | cM m = m :: CM | ||
227 | fM m = m :: FM | ||
228 | zM m = m :: ZM | ||
229 | |||
230 | |||
231 | rHer (Her m) = m :: RM | ||
232 | cHer (Her m) = m :: CM | ||
233 | |||
234 | rRot (Rot m) = m :: RM | ||
235 | cRot (Rot m) = m :: CM | ||
236 | |||
237 | rSq (Sq m) = m :: RM | ||
238 | cSq (Sq m) = m :: CM | ||
239 | |||
240 | rWC (WC m) = m :: RM | ||
241 | cWC (WC m) = m :: CM | ||
242 | |||
243 | rSqWC (SqWC m) = m :: RM | ||
244 | cSqWC (SqWC m) = m :: CM | ||
245 | |||
246 | rPosDef (PosDef m) = m :: RM | ||
247 | cPosDef (PosDef m) = m :: CM | ||
248 | |||
249 | rConsist (Consistent (a,b)) = (a,b::RM) | ||
250 | cConsist (Consistent (a,b)) = (a,b::CM) | ||
251 | |||