diff options
-rw-r--r-- | HSSL.cabal | 2 | ||||
-rw-r--r-- | examples/tests.hs | 152 | ||||
-rw-r--r-- | lib/Data/Packed/Instances.hs | 391 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 18 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 6 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 35 | ||||
-rw-r--r-- | lib/Data/Packed/Vector.hs | 13 | ||||
-rw-r--r-- | lib/GSL.hs | 45 | ||||
-rw-r--r-- | lib/GSL/Matrix.hs | 311 | ||||
-rw-r--r-- | lib/GSL/Vector.hs | 7 | ||||
-rw-r--r-- | lib/LinearAlgebra.hs | 1 | ||||
-rw-r--r-- | lib/LinearAlgebra/Algorithms.hs | 227 |
12 files changed, 1119 insertions, 89 deletions
@@ -24,7 +24,7 @@ Exposed-modules: Data.Packed.Internal, | |||
24 | Data.Packed.Internal.Matrix, Data.Packed.Matrix, | 24 | Data.Packed.Internal.Matrix, Data.Packed.Matrix, |
25 | Data.Packed.Internal.Tensor, Data.Packed.Tensor, | 25 | Data.Packed.Internal.Tensor, Data.Packed.Tensor, |
26 | Data.Packed.Plot | 26 | Data.Packed.Plot |
27 | -- Data.Packed.Instances | 27 | Data.Packed.Instances |
28 | LAPACK, | 28 | LAPACK, |
29 | GSL.Vector, | 29 | GSL.Vector, |
30 | GSL.Matrix | 30 | GSL.Matrix |
diff --git a/examples/tests.hs b/examples/tests.hs index 37800cd..9de9339 100644 --- a/examples/tests.hs +++ b/examples/tests.hs | |||
@@ -16,12 +16,26 @@ import GSL.Fourier | |||
16 | import GSL.Polynomials | 16 | import GSL.Polynomials |
17 | import LAPACK | 17 | import LAPACK |
18 | import Test.QuickCheck | 18 | import Test.QuickCheck |
19 | import Test.HUnit | 19 | import Test.HUnit hiding ((~:)) |
20 | import Complex | 20 | import Complex |
21 | import LinearAlgebra.Algorithms | ||
22 | import GSL.Matrix | ||
23 | import Data.Packed.Instances hiding ((<>)) | ||
24 | |||
25 | dist :: (Normed t, Num t) => t -> t -> Double | ||
26 | dist a b = norm (a-b) | ||
27 | |||
28 | infixl 4 |~| | ||
29 | a |~| b = a :~8~: b | ||
30 | |||
31 | data Aprox a = (:~) a Int | ||
32 | |||
33 | (~:) :: (Normed a, Num a) => Aprox a -> a -> Bool | ||
34 | a :~n~: b = dist a b < 10^^(-n) | ||
21 | 35 | ||
22 | 36 | ||
23 | {- | 37 | {- |
24 | -- Bravo por quickCheck! | 38 | -- Bravo por quickCheck! |
25 | 39 | ||
26 | pinvProp1 tol m = (rank m == cols m) ==> pinv m <> m ~~ ident (cols m) | 40 | pinvProp1 tol m = (rank m == cols m) ==> pinv m <> m ~~ ident (cols m) |
27 | where infix 2 ~~ | 41 | where infix 2 ~~ |
@@ -32,7 +46,7 @@ pinvProp2 tol m = 0 < r && r <= c ==> (r==c) `trivial` (m <> pinv m <> m ~~ m) | |||
32 | c = cols m | 46 | c = cols m |
33 | infix 2 ~~ | 47 | infix 2 ~~ |
34 | (~~) = approxEqual tol | 48 | (~~) = approxEqual tol |
35 | 49 | ||
36 | nullspaceProp tol m = cr > 0 ==> m <> nt ~~ zeros | 50 | nullspaceProp tol m = cr > 0 ==> m <> nt ~~ zeros |
37 | where nt = trans (nullspace m) | 51 | where nt = trans (nullspace m) |
38 | cr = corank m | 52 | cr = corank m |
@@ -49,31 +63,28 @@ mz = (2 >< 3) [1,2,3,4,5,6:+(1::Double)] | |||
49 | af = (2>|<3) [1,4,2,5,3,6::Double] | 63 | af = (2>|<3) [1,4,2,5,3,6::Double] |
50 | bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double] | 64 | bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double] |
51 | 65 | ||
52 | a |=| b = rows a == rows b && | ||
53 | cols a == cols b && | ||
54 | toList (cdat a) == toList (cdat b) && | ||
55 | toList (fdat a) == toList (fdat b) | ||
56 | 66 | ||
67 | {- | ||
57 | aprox fun a b = rows a == rows b && | 68 | aprox fun a b = rows a == rows b && |
58 | cols a == cols b && | 69 | cols a == cols b && |
59 | eps > aproxL fun (toList (t a)) (toList (t b)) | 70 | epsTol > aproxL fun (toList (t a)) (toList (t b)) |
60 | where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat | 71 | where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat |
61 | 72 | ||
62 | aproxL fun v1 v2 = sum (zipWith (\a b-> fun (a-b)) v1 v2) / fromIntegral (length v1) | 73 | aproxL fun v1 v2 = sum (zipWith (\a b-> fun (a-b)) v1 v2) / fromIntegral (length v1) |
63 | 74 | ||
64 | normVR a b = toScalarR AbsSum (vectorZipR Sub a b) | 75 | normVR a b = toScalarR AbsSum (vectorZipR Sub a b) |
65 | 76 | ||
66 | a |~| b = rows a == rows b && cols a == cols b && eps > normVR (t a) (t b) | 77 | a |~| b = rows a == rows b && cols a == cols b && epsTol > normVR (t a) (t b) |
67 | where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat | 78 | where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat |
68 | 79 | ||
69 | (|~~|) = aprox magnitude | 80 | (|~~|) = aprox magnitude |
70 | 81 | ||
71 | v1 ~~ v2 = reshape 1 v1 |~~| reshape 1 v2 | 82 | v1 ~~ v2 = reshape 1 v1 |~~| reshape 1 v2 |
72 | 83 | ||
73 | u ~|~ v = normVR u v < eps | 84 | u ~|~ v = normVR u v < epsTol |
74 | 85 | -} | |
75 | 86 | ||
76 | eps = 1E-8::Double | 87 | epsTol = 1E-8::Double |
77 | 88 | ||
78 | asFortran m = (rows m >|< cols m) $ toList (fdat m) | 89 | asFortran m = (rows m >|< cols m) $ toList (fdat m) |
79 | asC m = (rows m >< cols m) $ toList (cdat m) | 90 | asC m = (rows m >< cols m) $ toList (cdat m) |
@@ -81,6 +92,9 @@ asC m = (rows m >< cols m) $ toList (cdat m) | |||
81 | mulC a b = multiply RowMajor a b | 92 | mulC a b = multiply RowMajor a b |
82 | mulF a b = multiply ColumnMajor a b | 93 | mulF a b = multiply ColumnMajor a b |
83 | 94 | ||
95 | infixl 7 <> | ||
96 | a <> b = mulF a b | ||
97 | |||
84 | cc = mulC ac bf | 98 | cc = mulC ac bf |
85 | cf = mulF af bc | 99 | cf = mulF af bc |
86 | 100 | ||
@@ -133,14 +147,14 @@ data Sym a = Sym (Matrix a) deriving Show | |||
133 | instance (Field a, Arbitrary a, Num a) => Arbitrary (Sym a) where | 147 | instance (Field a, Arbitrary a, Num a) => Arbitrary (Sym a) where |
134 | arbitrary = do | 148 | arbitrary = do |
135 | SqM m <- arbitrary | 149 | SqM m <- arbitrary |
136 | return $ Sym (m `addM` trans m) | 150 | return $ Sym (m + trans m) |
137 | coarbitrary = undefined | 151 | coarbitrary = undefined |
138 | 152 | ||
139 | data Her = Her (Matrix (Complex Double)) deriving Show | 153 | data Her = Her (Matrix (Complex Double)) deriving Show |
140 | instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where | 154 | instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where |
141 | arbitrary = do | 155 | arbitrary = do |
142 | SqM m <- arbitrary | 156 | SqM m <- arbitrary |
143 | return $ Her (m `addM` (liftMatrix conj) (trans m)) | 157 | return $ Her (m + conjTrans m) |
144 | coarbitrary = undefined | 158 | coarbitrary = undefined |
145 | 159 | ||
146 | data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show | 160 | data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show |
@@ -171,49 +185,40 @@ instance (Field a, Arbitrary a) => Arbitrary (PairV a) where | |||
171 | coarbitrary = undefined | 185 | coarbitrary = undefined |
172 | 186 | ||
173 | 187 | ||
174 | addM m1 m2 = liftMatrix2 add m1 m2 | ||
175 | |||
176 | 188 | ||
177 | type BaseType = Double | 189 | type BaseType = Double |
178 | 190 | ||
179 | svdTestR fun prod m = u <> s <> trans v |~| m | 191 | svdTestR fun m = u <> s <> trans v |~| m |
180 | && u <> trans u |~| ident (rows m) | 192 | && u <> trans u |~| ident (rows m) |
181 | && v <> trans v |~| ident (cols m) | 193 | && v <> trans v |~| ident (cols m) |
182 | where (u,s,v) = fun m | 194 | where (u,s,v) = fun m |
183 | (<>) = prod | ||
184 | 195 | ||
185 | 196 | ||
186 | svdTestC prod m = u <> s' <> (trans v) |~~| m | 197 | svdTestC m = u <> s' <> (trans v) |~| m |
187 | && u <> (liftMatrix conj) (trans u) |~~| ident (rows m) | 198 | && u <> conjTrans u |~| ident (rows m) |
188 | && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) | 199 | && v <> conjTrans v |~| ident (cols m) |
189 | where (u,s,v) = svdC m | 200 | where (u,s,v) = svdC m |
190 | (<>) = prod | ||
191 | s' = liftMatrix comp s | 201 | s' = liftMatrix comp s |
192 | 202 | ||
193 | eigTestC prod (SqM m) = (m <> v) |~~| (v <> diag s) | 203 | --svdg' m = (u,s',v) where |
194 | && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant 1 (rows m) --normalized | 204 | |
205 | eigTestC (SqM m) = (m <> v) |~| (v <> diag s) | ||
206 | && takeDiag (conjTrans v <> v) |~| constant 1 (rows m) --normalized | ||
195 | where (s,v) = eigC m | 207 | where (s,v) = eigC m |
196 | (<>) = prod | ||
197 | 208 | ||
198 | eigTestR prod (SqM m) = (liftMatrix comp m <> v) |~~| (v <> diag s) | 209 | eigTestR (SqM m) = (liftMatrix comp m <> v) |~| (v <> diag s) |
199 | -- && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant 1 (rows m) --normalized ??? | 210 | -- && takeDiag ((liftMatrix conj (trans v)) <> v) |~| constant 1 (rows m) --normalized ??? |
200 | where (s,v) = eigR m | 211 | where (s,v) = eigR m |
201 | (<>) = prod | ||
202 | 212 | ||
203 | eigTestS prod (Sym m) = (m <> v) |~| (v <> diag s) | 213 | eigTestS (Sym m) = (m <> v) |~| (v <> diag s) |
204 | && v <> trans v |~| ident (cols m) | 214 | && v <> trans v |~| ident (cols m) |
205 | where (s,v) = eigS m | 215 | where (s,v) = eigS m |
206 | (<>) = prod | ||
207 | 216 | ||
208 | eigTestH prod (Her m) = (m <> v) |~~| (v <> diag (comp s)) | 217 | eigTestH (Her m) = (m <> v) |~| (v <> diag (comp s)) |
209 | && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) | 218 | && v <> conjTrans v |~| ident (cols m) |
210 | where (s,v) = eigH m | 219 | where (s,v) = eigH m |
211 | (<>) = prod | ||
212 | |||
213 | linearSolveSQTest fun eqfun singu prod (PairSM a b) = singu a || (a <> fun a b) ==== b | ||
214 | where (<>) = prod | ||
215 | (====) = eqfun | ||
216 | 220 | ||
221 | linearSolveSQTest fun singu (PairSM a b) = singu a || (a <> fun a b) |~| b | ||
217 | 222 | ||
218 | prec = 1E-15 | 223 | prec = 1E-15 |
219 | 224 | ||
@@ -237,8 +242,7 @@ identC n = toComplex(ident n, (0::Double) <>ident n) | |||
237 | 242 | ||
238 | -------------------------------------------------------------------- | 243 | -------------------------------------------------------------------- |
239 | 244 | ||
240 | pinvTest f feq m = (m <> f m <> m) `feq` m | 245 | pinvTest f m = (m <> f m <> m) |~| m |
241 | where (<>) = mulF | ||
242 | 246 | ||
243 | pinvR m = linearSolveLSR m (ident (rows m)) | 247 | pinvR m = linearSolveLSR m (ident (rows m)) |
244 | pinvC m = linearSolveLSC m (ident (rows m)) | 248 | pinvC m = linearSolveLSC m (ident (rows m)) |
@@ -252,7 +256,7 @@ pinvSVDC m = linearSolveSVDC Nothing m (ident (rows m)) | |||
252 | polyEval cs x = foldr (\c ac->ac*x+c) 0 cs | 256 | polyEval cs x = foldr (\c ac->ac*x+c) 0 cs |
253 | 257 | ||
254 | polySolveTest' p = length p <2 || last p == 0|| 1E-8 > maximum (map magnitude $ map (polyEval (map (:+0) p)) (polySolve p)) | 258 | polySolveTest' p = length p <2 || last p == 0|| 1E-8 > maximum (map magnitude $ map (polyEval (map (:+0) p)) (polySolve p)) |
255 | where l1 |~~| l2 = eps > aproxL magnitude l1 l2 | 259 | |
256 | 260 | ||
257 | polySolveTest = assertBool "polySolve" (polySolveTest' [1,2,3,4]) | 261 | polySolveTest = assertBool "polySolve" (polySolveTest' [1,2,3,4]) |
258 | 262 | ||
@@ -267,17 +271,17 @@ quad2 f a b g1 g2 = quad h a b | |||
267 | volSphere r = 8 * quad2 (\x y -> sqrt (r*r-x*x-y*y)) | 271 | volSphere r = 8 * quad2 (\x y -> sqrt (r*r-x*x-y*y)) |
268 | 0 r (const 0) (\x->sqrt (r*r-x*x)) | 272 | 0 r (const 0) (\x->sqrt (r*r-x*x)) |
269 | 273 | ||
270 | integrateTest = assertBool "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < eps) | 274 | integrateTest = assertBool "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < epsTol) |
271 | 275 | ||
272 | 276 | ||
273 | --------------------------------------------------------------------- | 277 | --------------------------------------------------------------------- |
274 | 278 | ||
275 | arit1 u = vectorMapValR PowVS 2 (vectorMapR Sin u) | 279 | arit1 u = vectorMapValR PowVS 2 (vectorMapR Sin u) |
276 | `add` vectorMapValR PowVS 2 (vectorMapR Cos u) | 280 | `add` vectorMapValR PowVS 2 (vectorMapR Cos u) |
277 | ~|~ constant 1 (dim u) | 281 | |~| constant 1 (dim u) |
278 | 282 | ||
279 | arit2 u = (vectorMapR Cos u) `mul` (vectorMapR Tan u) | 283 | arit2 u = (vectorMapR Cos u) `mul` (vectorMapR Tan u) |
280 | ~|~ vectorMapR Sin u | 284 | |~| vectorMapR Sin u |
281 | 285 | ||
282 | 286 | ||
283 | -- arit3 (PairV u v) = | 287 | -- arit3 (PairV u v) = |
@@ -305,50 +309,48 @@ tests = TestList | |||
305 | 309 | ||
306 | main = do | 310 | main = do |
307 | putStrLn "--------- general -----" | 311 | putStrLn "--------- general -----" |
308 | quickCheck (\(Sym m) -> m |=| (trans m:: Matrix BaseType)) | 312 | quickCheck (\(Sym m) -> m == (trans m:: Matrix BaseType)) |
309 | quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) | 313 | quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) |
310 | 314 | ||
311 | quickCheck $ \m -> m |=| asC (m :: Matrix BaseType) | 315 | quickCheck $ \m -> m == asC (m :: Matrix BaseType) |
312 | quickCheck $ \m -> m |=| asFortran (m :: Matrix BaseType) | 316 | quickCheck $ \m -> m == asFortran (m :: Matrix BaseType) |
313 | quickCheck $ \m -> m |=| (asC . asFortran) (m :: Matrix BaseType) | 317 | quickCheck $ \m -> m == (asC . asFortran) (m :: Matrix BaseType) |
314 | putStrLn "--------- MULTIPLY ----" | 318 | putStrLn "--------- MULTIPLY ----" |
315 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| mulF m1 (m2 :: Matrix BaseType) | 319 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == mulF m1 (m2 :: Matrix BaseType) |
316 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| trans (mulF (trans m2) (trans m1 :: Matrix BaseType)) | 320 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == trans (mulF (trans m2) (trans m1 :: Matrix BaseType)) |
317 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| multiplyG m1 (m2 :: Matrix BaseType) | 321 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == multiplyG m1 (m2 :: Matrix BaseType) |
318 | putStrLn "--------- SVD ---------" | 322 | putStrLn "--------- SVD ---------" |
319 | quickCheck (svdTestR svdR mulC) | 323 | quickCheck (svdTestR svdR) |
320 | quickCheck (svdTestR svdR mulF) | 324 | quickCheck (svdTestR svdRdd) |
321 | quickCheck (svdTestR svdRdd mulC) | 325 | -- quickCheck (svdTestR svdg) |
322 | quickCheck (svdTestR svdRdd mulF) | 326 | quickCheck svdTestC |
323 | quickCheck (svdTestC mulC) | ||
324 | quickCheck (svdTestC mulF) | ||
325 | putStrLn "--------- EIG ---------" | 327 | putStrLn "--------- EIG ---------" |
326 | quickCheck (eigTestC mulC) | 328 | quickCheck eigTestC |
327 | quickCheck (eigTestC mulF) | 329 | quickCheck eigTestR |
328 | quickCheck (eigTestR mulC) | 330 | quickCheck eigTestS |
329 | quickCheck (eigTestR mulF) | 331 | quickCheck eigTestH |
330 | quickCheck (eigTestS mulC) | ||
331 | quickCheck (eigTestS mulF) | ||
332 | quickCheck (eigTestH mulC) | ||
333 | quickCheck (eigTestH mulF) | ||
334 | putStrLn "--------- SOLVE ---------" | 332 | putStrLn "--------- SOLVE ---------" |
335 | quickCheck (linearSolveSQTest linearSolveR (|~|) (singular svdR') mulC) | 333 | quickCheck (linearSolveSQTest linearSolveR (singular svdR')) |
336 | quickCheck (linearSolveSQTest linearSolveC (|~~|) (singular svdC') mulF) | 334 | quickCheck (linearSolveSQTest linearSolveC (singular svdC')) |
337 | quickCheck (pinvTest pinvR (|~|)) | 335 | quickCheck (pinvTest pinvR) |
338 | quickCheck (pinvTest pinvC (|~~|)) | 336 | quickCheck (pinvTest pinvC) |
339 | quickCheck (pinvTest pinvSVDR (|~|)) | 337 | quickCheck (pinvTest pinvSVDR) |
340 | quickCheck (pinvTest pinvSVDC (|~~|)) | 338 | quickCheck (pinvTest pinvSVDC) |
341 | putStrLn "--------- VEC OPER ------" | 339 | putStrLn "--------- VEC OPER ------" |
342 | quickCheck arit1 | 340 | quickCheck arit1 |
343 | quickCheck arit2 | 341 | quickCheck arit2 |
344 | putStrLn "--------- GSL ------" | 342 | putStrLn "--------- GSL ------" |
345 | runTestTT tests | 343 | runTestTT tests |
346 | quickCheck $ \v -> ifft (fft v) ~~ v | 344 | quickCheck $ \v -> ifft (fft v) |~| v |
347 | 345 | ||
348 | kk = (2><2) | 346 | kk = (2><2) |
349 | [ 1.0, 0.0 | 347 | [ 1.0, 0.0 |
350 | , -1.5, 1.0 ::Double] | 348 | , -1.5, 1.0 ::Double] |
351 | 349 | ||
352 | v = 11 # [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0::Double] | 350 | v = 11 |> [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0::Double] |
351 | |||
352 | pol = [14.125,-7.666666666666667,-14.3,-13.0,-7.0,-9.6,4.666666666666666,13.0,0.5] | ||
353 | 353 | ||
354 | pol = [14.125,-7.666666666666667,-14.3,-13.0,-7.0,-9.6,4.666666666666666,13.0,0.5] \ No newline at end of file | 354 | mm = (2><2) |
355 | [ 0.5, 0.0 | ||
356 | , 0.0, 0.0 ] :: Matrix Double | ||
diff --git a/lib/Data/Packed/Instances.hs b/lib/Data/Packed/Instances.hs new file mode 100644 index 0000000..4478469 --- /dev/null +++ b/lib/Data/Packed/Instances.hs | |||
@@ -0,0 +1,391 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | {- | | ||
4 | Module : Data.Packed.Instances | ||
5 | Copyright : (c) Alberto Ruiz 2006 | ||
6 | License : GPL-style | ||
7 | |||
8 | Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
9 | Stability : provisional | ||
10 | Portability : uses -fffi and -fglasgow-exts | ||
11 | |||
12 | Creates reasonable numeric instances for Vectors and Matrices. In the context of the standard numeric operators, one-component vectors and matrices automatically expand to match the dimensions of the other operand. | ||
13 | |||
14 | -} | ||
15 | ----------------------------------------------------------------------------- | ||
16 | |||
17 | module Data.Packed.Instances( | ||
18 | Contractible(..) | ||
19 | ) where | ||
20 | |||
21 | import Data.Packed.Internal | ||
22 | import Data.Packed.Vector | ||
23 | import Data.Packed.Matrix | ||
24 | import GSL.Vector | ||
25 | import GSL.Matrix | ||
26 | import LinearAlgebra.Algorithms | ||
27 | import Complex | ||
28 | |||
29 | instance (Eq a, Field a) => Eq (Vector a) where | ||
30 | a == b = dim a == dim b && toList a == toList b | ||
31 | |||
32 | instance (Num a, Field a) => Num (Vector a) where | ||
33 | (+) = add | ||
34 | (-) = sub | ||
35 | (*) = mul | ||
36 | signum = liftVector signum | ||
37 | abs = liftVector abs | ||
38 | fromInteger = fromList . return . fromInteger | ||
39 | |||
40 | instance (Eq a, Field a) => Eq (Matrix a) where | ||
41 | a == b = rows a == rows b && cols a == cols b && cdat a == cdat b && fdat a == fdat b | ||
42 | |||
43 | instance (Num a, Field a) => Num (Matrix a) where | ||
44 | (+) = liftMatrix2 add | ||
45 | (-) = liftMatrix2 sub | ||
46 | (*) = liftMatrix2 mul | ||
47 | signum = liftMatrix signum | ||
48 | abs = liftMatrix abs | ||
49 | fromInteger = (1><1) . return . fromInteger | ||
50 | |||
51 | --------------------------------------------------- | ||
52 | |||
53 | adaptScalar f1 f2 f3 x y | ||
54 | | dim x == 1 = f1 (x@>0) y | ||
55 | | dim y == 1 = f3 x (y@>0) | ||
56 | | otherwise = f2 x y | ||
57 | |||
58 | {- | ||
59 | subvv = vectorZip 4 | ||
60 | subvc v c = addConstant (-c) v | ||
61 | subcv c v = addConstant c (scale (-1) v) | ||
62 | |||
63 | mul = vectorZip 1 | ||
64 | |||
65 | instance Num (Vector Double) where | ||
66 | (+) = adaptScalar addConstant add (flip addConstant) | ||
67 | (-) = adaptScalar subcv subvv subvc | ||
68 | (*) = adaptScalar scale mul (flip scale) | ||
69 | abs = vectorMap 3 | ||
70 | signum = vectorMap 15 | ||
71 | fromInteger n = fromList [fromInteger n] | ||
72 | |||
73 | ---------------------------------------------------- | ||
74 | |||
75 | --addConstantC a = gmap (+a) | ||
76 | --subCvv u v = u `add` scale (-1) v | ||
77 | subCvv = vectorZipComplex 4 -- faster? | ||
78 | subCvc v c = addConstantC (-c) v | ||
79 | subCcv c v = addConstantC c (scale (-1) v) | ||
80 | |||
81 | |||
82 | instance Num (Vector (Complex Double)) where | ||
83 | (+) = adaptScalar addConstantC add (flip addConstantC) | ||
84 | (-) = adaptScalar subCcv subCvv subCvc | ||
85 | (*) = adaptScalar scale (vectorZipComplex 1) (flip scale) | ||
86 | abs = gmap abs | ||
87 | signum = gmap signum | ||
88 | fromInteger n = fromList [fromInteger n] | ||
89 | |||
90 | |||
91 | -- | adapts a function on two vectors to work on all the elements of two matrices | ||
92 | liftMatrix2' :: (Vector a -> Vector b -> Vector c) -> Matrix a -> Matrix b -> Matrix c | ||
93 | liftMatrix2' f m1@(M r1 c1 _) m2@(M r2 c2 _) | ||
94 | | sameShape m1 m2 || r1*c1==1 || r2*c2==1 | ||
95 | = reshape (max c1 c2) $ f (flatten m1) (flatten m2) | ||
96 | | otherwise = error "inconsistent matrix dimensions" | ||
97 | |||
98 | --------------------------------------------------- | ||
99 | |||
100 | instance (Eq a, Field a) => Eq (Matrix a) where | ||
101 | a == b = rows a == rows b && cdat a == cdat b | ||
102 | |||
103 | instance Num (Matrix Double) where | ||
104 | (+) = liftMatrix2' (+) | ||
105 | (-) = liftMatrix2' (-) | ||
106 | (*) = liftMatrix2' (*) | ||
107 | abs = liftMatrix abs | ||
108 | signum = liftMatrix signum | ||
109 | fromInteger n = fromLists [[fromInteger n]] | ||
110 | |||
111 | ---------------------------------------------------- | ||
112 | |||
113 | instance Num (Matrix (Complex Double)) where | ||
114 | (+) = liftMatrix2' (+) | ||
115 | (-) = liftMatrix2' (-) | ||
116 | (*) = liftMatrix2' (*) | ||
117 | abs = liftMatrix abs | ||
118 | signum = liftMatrix signum | ||
119 | fromInteger n = fromLists [[fromInteger n]] | ||
120 | |||
121 | ------------------------------------------------------ | ||
122 | |||
123 | instance Fractional (Vector Double) where | ||
124 | fromRational n = fromList [fromRational n] | ||
125 | (/) = adaptScalar f (vectorZip 2) g where | ||
126 | r `f` v = vectorZip 2 (constant r (dim v)) v | ||
127 | v `g` r = scale (recip r) v | ||
128 | |||
129 | ------------------------------------------------------- | ||
130 | |||
131 | instance Fractional (Vector (Complex Double)) where | ||
132 | fromRational n = fromList [fromRational n] | ||
133 | (/) = adaptScalar f (vectorZipComplex 2) g where | ||
134 | r `f` v = gmap ((*r).recip) v | ||
135 | v `g` r = gmap (/r) v | ||
136 | |||
137 | ------------------------------------------------------ | ||
138 | |||
139 | instance Fractional (Matrix Double) where | ||
140 | fromRational n = fromLists [[fromRational n]] | ||
141 | (/) = liftMatrix2' (/) | ||
142 | |||
143 | ------------------------------------------------------- | ||
144 | |||
145 | instance Fractional (Matrix (Complex Double)) where | ||
146 | fromRational n = fromLists [[fromRational n]] | ||
147 | (/) = liftMatrix2' (/) | ||
148 | |||
149 | --------------------------------------------------------- | ||
150 | |||
151 | instance Floating (Vector Double) where | ||
152 | sin = vectorMap 0 | ||
153 | cos = vectorMap 1 | ||
154 | tan = vectorMap 2 | ||
155 | asin = vectorMap 4 | ||
156 | acos = vectorMap 5 | ||
157 | atan = vectorMap 6 | ||
158 | sinh = vectorMap 7 | ||
159 | cosh = vectorMap 8 | ||
160 | tanh = vectorMap 9 | ||
161 | asinh = vectorMap 10 | ||
162 | acosh = vectorMap 11 | ||
163 | atanh = vectorMap 12 | ||
164 | exp = vectorMap 13 | ||
165 | log = vectorMap 14 | ||
166 | sqrt = vectorMap 16 | ||
167 | (**) = adaptScalar f (vectorZip 5) g where f s v = constant s (dim v) ** v | ||
168 | g v s = v ** constant s (dim v) | ||
169 | pi = fromList [pi] | ||
170 | |||
171 | ----------------------------------------------------------- | ||
172 | |||
173 | instance Floating (Matrix Double) where | ||
174 | sin = liftMatrix sin | ||
175 | cos = liftMatrix cos | ||
176 | tan = liftMatrix tan | ||
177 | asin = liftMatrix asin | ||
178 | acos = liftMatrix acos | ||
179 | atan = liftMatrix atan | ||
180 | sinh = liftMatrix sinh | ||
181 | cosh = liftMatrix cosh | ||
182 | tanh = liftMatrix tanh | ||
183 | asinh = liftMatrix asinh | ||
184 | acosh = liftMatrix acosh | ||
185 | atanh = liftMatrix atanh | ||
186 | exp = liftMatrix exp | ||
187 | log = liftMatrix log | ||
188 | sqrt = liftMatrix sqrt | ||
189 | (**) = liftMatrix2 (**) | ||
190 | pi = fromLists [[pi]] | ||
191 | |||
192 | ------------------------------------------------------------- | ||
193 | |||
194 | instance Floating (Vector (Complex Double)) where | ||
195 | sin = vectorMapComplex 0 | ||
196 | cos = vectorMapComplex 1 | ||
197 | tan = vectorMapComplex 2 | ||
198 | asin = vectorMapComplex 4 | ||
199 | acos = vectorMapComplex 5 | ||
200 | atan = vectorMapComplex 6 | ||
201 | sinh = vectorMapComplex 7 | ||
202 | cosh = vectorMapComplex 8 | ||
203 | tanh = vectorMapComplex 9 | ||
204 | asinh = vectorMapComplex 10 | ||
205 | acosh = vectorMapComplex 11 | ||
206 | atanh = vectorMapComplex 12 | ||
207 | exp = vectorMapComplex 13 | ||
208 | log = vectorMapComplex 14 | ||
209 | sqrt = vectorMapComplex 16 | ||
210 | (**) = adaptScalar f (vectorZipComplex 5) g where f s v = constantC s (dim v) ** v | ||
211 | g v s = v ** constantC s (dim v) | ||
212 | pi = fromList [pi] | ||
213 | |||
214 | --------------------------------------------------------------- | ||
215 | |||
216 | instance Floating (Matrix (Complex Double)) where | ||
217 | sin = liftMatrix sin | ||
218 | cos = liftMatrix cos | ||
219 | tan = liftMatrix tan | ||
220 | asin = liftMatrix asin | ||
221 | acos = liftMatrix acos | ||
222 | atan = liftMatrix atan | ||
223 | sinh = liftMatrix sinh | ||
224 | cosh = liftMatrix cosh | ||
225 | tanh = liftMatrix tanh | ||
226 | asinh = liftMatrix asinh | ||
227 | acosh = liftMatrix acosh | ||
228 | atanh = liftMatrix atanh | ||
229 | exp = liftMatrix exp | ||
230 | log = liftMatrix log | ||
231 | (**) = liftMatrix2 (**) | ||
232 | sqrt = liftMatrix sqrt | ||
233 | pi = fromLists [[pi]] | ||
234 | |||
235 | --------------------------------------------------------------- | ||
236 | -} | ||
237 | |||
238 | class Contractible a b c | a b -> c where | ||
239 | infixl 7 <> | ||
240 | {- | An overloaded operator for matrix products, matrix-vector and vector-matrix products, dot products and scaling of vectors and matrices. Type consistency is statically checked. Alternatively, you can use the specific functions described below, but using this operator you can automatically combine real and complex objects. | ||
241 | |||
242 | @v = 'fromList' [1,2,3] :: Vector Double | ||
243 | cv = 'fromList' [1+'i',2] | ||
244 | m = 'fromLists' [[1,2,3], | ||
245 | [4,5,7]] :: Matrix Double | ||
246 | cm = 'fromLists' [[ 1, 2], | ||
247 | [3+'i',7*'i'], | ||
248 | [ 'i', 1]] | ||
249 | \ | ||
250 | \> m \<\> v | ||
251 | 14. 35. | ||
252 | \ | ||
253 | \> cv \<\> m | ||
254 | 9.+1.i 12.+2.i 17.+3.i | ||
255 | \ | ||
256 | \> m \<\> cm | ||
257 | 7.+5.i 5.+14.i | ||
258 | 19.+12.i 15.+35.i | ||
259 | \ | ||
260 | \> v \<\> 'i' | ||
261 | 1.i 2.i 3.i | ||
262 | \ | ||
263 | \> v \<\> v | ||
264 | 14.0 | ||
265 | \ | ||
266 | \> cv \<\> cv | ||
267 | 4.0 :+ 2.0@ | ||
268 | |||
269 | -} | ||
270 | (<>) :: a -> b -> c | ||
271 | |||
272 | |||
273 | instance Contractible Double Double Double where | ||
274 | (<>) = (*) | ||
275 | |||
276 | instance Contractible Double (Complex Double) (Complex Double) where | ||
277 | a <> b = (a:+0) * b | ||
278 | |||
279 | instance Contractible (Complex Double) Double (Complex Double) where | ||
280 | a <> b = a * (b:+0) | ||
281 | |||
282 | instance Contractible (Complex Double) (Complex Double) (Complex Double) where | ||
283 | (<>) = (*) | ||
284 | |||
285 | --------------------------------- matrix matrix | ||
286 | |||
287 | instance Contractible (Matrix Double) (Matrix Double) (Matrix Double) where | ||
288 | (<>) = mXm | ||
289 | |||
290 | instance Contractible (Matrix (Complex Double)) (Matrix (Complex Double)) (Matrix (Complex Double)) where | ||
291 | (<>) = mXm | ||
292 | |||
293 | instance Contractible (Matrix (Complex Double)) (Matrix Double) (Matrix (Complex Double)) where | ||
294 | c <> r = c <> liftMatrix comp r | ||
295 | |||
296 | instance Contractible (Matrix Double) (Matrix (Complex Double)) (Matrix (Complex Double)) where | ||
297 | r <> c = liftMatrix comp r <> c | ||
298 | |||
299 | --------------------------------- (Matrix Double) (Vector Double) | ||
300 | |||
301 | instance Contractible (Matrix Double) (Vector Double) (Vector Double) where | ||
302 | (<>) = mXv | ||
303 | |||
304 | instance Contractible (Matrix (Complex Double)) (Vector (Complex Double)) (Vector (Complex Double)) where | ||
305 | (<>) = mXv | ||
306 | |||
307 | instance Contractible (Matrix (Complex Double)) (Vector Double) (Vector (Complex Double)) where | ||
308 | m <> v = m <> comp v | ||
309 | |||
310 | instance Contractible (Matrix Double) (Vector (Complex Double)) (Vector (Complex Double)) where | ||
311 | m <> v = liftMatrix comp m <> v | ||
312 | |||
313 | --------------------------------- (Vector Double) (Matrix Double) | ||
314 | |||
315 | instance Contractible (Vector Double) (Matrix Double) (Vector Double) where | ||
316 | (<>) = vXm | ||
317 | |||
318 | instance Contractible (Vector (Complex Double)) (Matrix (Complex Double)) (Vector (Complex Double)) where | ||
319 | (<>) = vXm | ||
320 | |||
321 | instance Contractible (Vector (Complex Double)) (Matrix Double) (Vector (Complex Double)) where | ||
322 | v <> m = v <> liftMatrix comp m | ||
323 | |||
324 | instance Contractible (Vector Double) (Matrix (Complex Double)) (Vector (Complex Double)) where | ||
325 | v <> m = comp v <> m | ||
326 | |||
327 | --------------------------------- dot product | ||
328 | |||
329 | instance Contractible (Vector Double) (Vector Double) Double where | ||
330 | (<>) = dot | ||
331 | |||
332 | instance Contractible (Vector (Complex Double)) (Vector (Complex Double)) (Complex Double) where | ||
333 | (<>) = dot | ||
334 | |||
335 | instance Contractible (Vector Double) (Vector (Complex Double)) (Complex Double) where | ||
336 | a <> b = comp a <> b | ||
337 | |||
338 | instance Contractible (Vector (Complex Double)) (Vector Double) (Complex Double) where | ||
339 | (<>) = flip (<>) | ||
340 | |||
341 | --------------------------------- scaling vectors | ||
342 | |||
343 | instance Contractible Double (Vector Double) (Vector Double) where | ||
344 | (<>) = scale | ||
345 | |||
346 | instance Contractible (Vector Double) Double (Vector Double) where | ||
347 | (<>) = flip (<>) | ||
348 | |||
349 | instance Contractible (Complex Double) (Vector (Complex Double)) (Vector (Complex Double)) where | ||
350 | (<>) = scale | ||
351 | |||
352 | instance Contractible (Vector (Complex Double)) (Complex Double) (Vector (Complex Double)) where | ||
353 | (<>) = flip (<>) | ||
354 | |||
355 | instance Contractible Double (Vector (Complex Double)) (Vector (Complex Double)) where | ||
356 | a <> v = (a:+0) <> v | ||
357 | |||
358 | instance Contractible (Vector (Complex Double)) Double (Vector (Complex Double)) where | ||
359 | (<>) = flip (<>) | ||
360 | |||
361 | instance Contractible (Complex Double) (Vector Double) (Vector (Complex Double)) where | ||
362 | a <> v = a <> comp v | ||
363 | |||
364 | instance Contractible (Vector Double) (Complex Double) (Vector (Complex Double)) where | ||
365 | (<>) = flip (<>) | ||
366 | |||
367 | --------------------------------- scaling matrices | ||
368 | |||
369 | instance Contractible Double (Matrix Double) (Matrix Double) where | ||
370 | (<>) a = liftMatrix (a <>) | ||
371 | |||
372 | instance Contractible (Matrix Double) Double (Matrix Double) where | ||
373 | (<>) = flip (<>) | ||
374 | |||
375 | instance Contractible (Complex Double) (Matrix (Complex Double)) (Matrix (Complex Double)) where | ||
376 | (<>) a = liftMatrix (a <>) | ||
377 | |||
378 | instance Contractible (Matrix (Complex Double)) (Complex Double) (Matrix (Complex Double)) where | ||
379 | (<>) = flip (<>) | ||
380 | |||
381 | instance Contractible Double (Matrix (Complex Double)) (Matrix (Complex Double)) where | ||
382 | a <> m = (a:+0) <> m | ||
383 | |||
384 | instance Contractible (Matrix (Complex Double)) Double (Matrix (Complex Double)) where | ||
385 | (<>) = flip (<>) | ||
386 | |||
387 | instance Contractible (Complex Double) (Matrix Double) (Matrix (Complex Double)) where | ||
388 | a <> m = a <> liftMatrix comp m | ||
389 | |||
390 | instance Contractible (Matrix Double) (Complex Double) (Matrix (Complex Double)) where | ||
391 | (<>) = flip (<>) | ||
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 9309d1d..dd33943 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -93,6 +93,15 @@ createMatrix order r c = do | |||
93 | p <- createVector (r*c) | 93 | p <- createVector (r*c) |
94 | return (matrixFromVector order c p) | 94 | return (matrixFromVector order c p) |
95 | 95 | ||
96 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. | ||
97 | |||
98 | @\> reshape 4 ('fromList' [1..12]) | ||
99 | (3><4) | ||
100 | [ 1.0, 2.0, 3.0, 4.0 | ||
101 | , 5.0, 6.0, 7.0, 8.0 | ||
102 | , 9.0, 10.0, 11.0, 12.0 ]@ | ||
103 | |||
104 | -} | ||
96 | reshape :: (Field t) => Int -> Vector t -> Matrix t | 105 | reshape :: (Field t) => Int -> Vector t -> Matrix t |
97 | reshape c v = matrixFromVector RowMajor c v | 106 | reshape c v = matrixFromVector RowMajor c v |
98 | 107 | ||
@@ -140,7 +149,6 @@ liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes | |||
140 | 149 | ||
141 | liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 150 | liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
142 | liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes | 151 | liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes |
143 | |||
144 | ------------------------------------------------------------------ | 152 | ------------------------------------------------------------------ |
145 | 153 | ||
146 | dotL a b = sum (zipWith (*) a b) | 154 | dotL a b = sum (zipWith (*) a b) |
@@ -200,6 +208,14 @@ multiplyD order a b | |||
200 | 208 | ||
201 | outer' u v = dat (outer u v) | 209 | outer' u v = dat (outer u v) |
202 | 210 | ||
211 | {- | Outer product of two vectors. | ||
212 | |||
213 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
214 | (3><3) | ||
215 | [ 5.0, 2.0, 3.0 | ||
216 | , 10.0, 4.0, 6.0 | ||
217 | , 15.0, 6.0, 9.0 ]@ | ||
218 | -} | ||
203 | outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t | 219 | outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t |
204 | outer u v = multiply RowMajor r c | 220 | outer u v = multiply RowMajor r c |
205 | where r = matrixFromVector RowMajor 1 u | 221 | where r = matrixFromVector RowMajor 1 u |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 25e848d..f1addf4 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -48,7 +48,7 @@ fromList l = unsafePerformIO $ do | |||
48 | toList :: Storable a => Vector a -> [a] | 48 | toList :: Storable a => Vector a -> [a] |
49 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) | 49 | toList v = unsafePerformIO $ peekArray (dim v) (ptr v) |
50 | 50 | ||
51 | n # l = if length l == n then fromList l else error "# with wrong size" | 51 | n |> l = if length l == n then fromList l else error "|> with wrong size" |
52 | 52 | ||
53 | at' :: Storable a => Vector a -> Int -> a | 53 | at' :: Storable a => Vector a -> Int -> a |
54 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n | 54 | at' v n = unsafePerformIO $ peekElemOff (ptr v) n |
@@ -58,7 +58,7 @@ at v n | n >= 0 && n < dim v = at' v n | |||
58 | | otherwise = error "vector index out of range" | 58 | | otherwise = error "vector index out of range" |
59 | 59 | ||
60 | instance (Show a, Storable a) => (Show (Vector a)) where | 60 | instance (Show a, Storable a) => (Show (Vector a)) where |
61 | show v = (show (dim v))++" # " ++ show (toList v) | 61 | show v = (show (dim v))++" |> " ++ show (toList v) |
62 | 62 | ||
63 | -- | creates a Vector taking a number of consecutive toList from another Vector | 63 | -- | creates a Vector taking a number of consecutive toList from another Vector |
64 | subVector :: Storable t => Int -- ^ index of the starting element | 64 | subVector :: Storable t => Int -- ^ index of the starting element |
@@ -129,3 +129,5 @@ constant x n | isReal id x = scast $ constantR (scast x) n | |||
129 | | isComp id x = scast $ constantC (scast x) n | 129 | | isComp id x = scast $ constantC (scast x) n |
130 | | otherwise = constantG x n | 130 | | otherwise = constantG x n |
131 | 131 | ||
132 | liftVector f = fromList . map f . toList | ||
133 | liftVector2 f u v = fromList $ zipWith f (toList u) (toList v) | ||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 0f9d998..36bf32e 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -13,11 +13,11 @@ | |||
13 | ----------------------------------------------------------------------------- | 13 | ----------------------------------------------------------------------------- |
14 | 14 | ||
15 | module Data.Packed.Matrix ( | 15 | module Data.Packed.Matrix ( |
16 | Matrix(rows,cols), Field, | 16 | Matrix(rows,cols), |
17 | fromLists, toLists, (><), (>|<), (@@>), | 17 | fromLists, toLists, (><), (>|<), (@@>), |
18 | trans, conjTrans, | 18 | trans, conjTrans, |
19 | reshape, flatten, | 19 | reshape, flatten, asRow, asColumn, |
20 | fromRows, toRows, fromColumns, toColumns, | 20 | fromRows, toRows, fromColumns, toColumns, fromBlocks, |
21 | joinVert, joinHoriz, | 21 | joinVert, joinHoriz, |
22 | flipud, fliprl, | 22 | flipud, fliprl, |
23 | liftMatrix, liftMatrix2, | 23 | liftMatrix, liftMatrix2, |
@@ -43,6 +43,22 @@ joinVert ms = case common cols ms of | |||
43 | joinHoriz :: Field t => [Matrix t] -> Matrix t | 43 | joinHoriz :: Field t => [Matrix t] -> Matrix t |
44 | joinHoriz ms = trans. joinVert . map trans $ ms | 44 | joinHoriz ms = trans. joinVert . map trans $ ms |
45 | 45 | ||
46 | {- | Creates a matrix from blocks given as a list of lists of matrices: | ||
47 | |||
48 | @\> let a = 'diag' $ 'fromList' [5,7,2] | ||
49 | \> let b = 'reshape' 4 $ 'constant' (-1) 12 | ||
50 | \> fromBlocks [[a,b],[b,a]] | ||
51 | (6><7) | ||
52 | [ 5.0, 0.0, 0.0, -1.0, -1.0, -1.0, -1.0 | ||
53 | , 0.0, 7.0, 0.0, -1.0, -1.0, -1.0, -1.0 | ||
54 | , 0.0, 0.0, 2.0, -1.0, -1.0, -1.0, -1.0 | ||
55 | , -1.0, -1.0, -1.0, -1.0, 5.0, 0.0, 0.0 | ||
56 | , -1.0, -1.0, -1.0, -1.0, 0.0, 7.0, 0.0 | ||
57 | , -1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 2.0 ]@ | ||
58 | -} | ||
59 | fromBlocks :: Field t => [[Matrix t]] -> Matrix t | ||
60 | fromBlocks = joinVert . map joinHoriz | ||
61 | |||
46 | -- | Reverse rows | 62 | -- | Reverse rows |
47 | flipud :: Field t => Matrix t -> Matrix t | 63 | flipud :: Field t => Matrix t -> Matrix t |
48 | flipud m = fromRows . reverse . toRows $ m | 64 | flipud m = fromRows . reverse . toRows $ m |
@@ -98,6 +114,11 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat | |||
98 | 114 | ||
99 | ---------------------------------------------------------------- | 115 | ---------------------------------------------------------------- |
100 | 116 | ||
117 | {- | Creates a vector by concatenation of rows | ||
118 | |||
119 | @\> flatten ('ident' 3) | ||
120 | 9 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ | ||
121 | -} | ||
101 | flatten :: Matrix t -> Vector t | 122 | flatten :: Matrix t -> Vector t |
102 | flatten = cdat | 123 | flatten = cdat |
103 | 124 | ||
@@ -106,4 +127,10 @@ fromLists :: Field t => [[t]] -> Matrix t | |||
106 | fromLists = fromRows . map fromList | 127 | fromLists = fromRows . map fromList |
107 | 128 | ||
108 | conjTrans :: Matrix (Complex Double) -> Matrix (Complex Double) | 129 | conjTrans :: Matrix (Complex Double) -> Matrix (Complex Double) |
109 | conjTrans = trans . liftMatrix conj \ No newline at end of file | 130 | conjTrans = trans . liftMatrix conj |
131 | |||
132 | asRow :: Field a => Vector a -> Matrix a | ||
133 | asRow v = reshape (dim v) v | ||
134 | |||
135 | asColumn :: Field a => Vector a -> Matrix a | ||
136 | asColumn v = reshape 1 v | ||
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs index 9d9d879..94f70be 100644 --- a/lib/Data/Packed/Vector.hs +++ b/lib/Data/Packed/Vector.hs | |||
@@ -15,7 +15,7 @@ | |||
15 | module Data.Packed.Vector ( | 15 | module Data.Packed.Vector ( |
16 | Vector(dim), Field, | 16 | Vector(dim), Field, |
17 | fromList, toList, | 17 | fromList, toList, |
18 | at, | 18 | (@>), |
19 | subVector, join, | 19 | subVector, join, |
20 | constant, | 20 | constant, |
21 | toComplex, comp, | 21 | toComplex, comp, |
@@ -26,6 +26,7 @@ module Data.Packed.Vector ( | |||
26 | 26 | ||
27 | import Data.Packed.Internal | 27 | import Data.Packed.Internal |
28 | import Complex | 28 | import Complex |
29 | import GSL.Vector | ||
29 | 30 | ||
30 | -- | creates a complex vector from vectors with real and imaginary parts | 31 | -- | creates a complex vector from vectors with real and imaginary parts |
31 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) | 32 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) |
@@ -41,10 +42,14 @@ comp v = toComplex (v,constant 0 (dim v)) | |||
41 | 42 | ||
42 | {- | Creates a real vector containing a range of values: | 43 | {- | Creates a real vector containing a range of values: |
43 | 44 | ||
44 | > > linspace 10 (-2,2) | 45 | @\> linspace 5 (-3,7) |
45 | >-2. -1.556 -1.111 -0.667 -0.222 0.222 0.667 1.111 1.556 2. | 46 | 5 |> [-3.0,-0.5,2.0,4.5,7.0]@ |
46 | |||
47 | -} | 47 | -} |
48 | linspace :: Int -> (Double, Double) -> Vector Double | 48 | linspace :: Int -> (Double, Double) -> Vector Double |
49 | linspace n (a,b) = fromList [a::Double,a+delta .. b] | 49 | linspace n (a,b) = fromList [a::Double,a+delta .. b] |
50 | where delta = (b-a)/(fromIntegral n -1) | 50 | where delta = (b-a)/(fromIntegral n -1) |
51 | |||
52 | -- | Reads a vector position. | ||
53 | (@>) :: Field t => Vector t -> Int -> t | ||
54 | infixl 9 @> | ||
55 | (@>) = at | ||
diff --git a/lib/GSL.hs b/lib/GSL.hs new file mode 100644 index 0000000..8e033c3 --- /dev/null +++ b/lib/GSL.hs | |||
@@ -0,0 +1,45 @@ | |||
1 | {- | | ||
2 | |||
3 | Module : GSL | ||
4 | Copyright : (c) Alberto Ruiz 2006-7 | ||
5 | License : GPL-style | ||
6 | |||
7 | Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
8 | Stability : provisional | ||
9 | Portability : uses -fffi and -fglasgow-exts | ||
10 | |||
11 | This module reexports the basic functionality and a collection of utilities (old interface) | ||
12 | |||
13 | -} | ||
14 | |||
15 | module GSL ( | ||
16 | |||
17 | module Data.Packed.Vector, | ||
18 | module Data.Packed.Matrix, | ||
19 | module Data.Packed.Tensor, | ||
20 | module Data.Packed.Instances, | ||
21 | module LinearAlgebra.Algorithms, | ||
22 | module LAPACK, | ||
23 | module GSL.Integration, | ||
24 | module GSL.Differentiation, | ||
25 | module GSL.Special, | ||
26 | module GSL.Fourier, | ||
27 | module GSL.Polynomials, | ||
28 | module GSL.Minimization, | ||
29 | module Data.Packed.Plot | ||
30 | |||
31 | ) where | ||
32 | |||
33 | import Data.Packed.Vector | ||
34 | import Data.Packed.Matrix | ||
35 | import Data.Packed.Tensor | ||
36 | import Data.Packed.Instances | ||
37 | import LinearAlgebra.Algorithms | ||
38 | import LAPACK | ||
39 | import GSL.Integration | ||
40 | import GSL.Differentiation | ||
41 | import GSL.Special | ||
42 | import GSL.Fourier | ||
43 | import GSL.Polynomials | ||
44 | import GSL.Minimization | ||
45 | import Data.Packed.Plot | ||
diff --git a/lib/GSL/Matrix.hs b/lib/GSL/Matrix.hs new file mode 100644 index 0000000..919c2d9 --- /dev/null +++ b/lib/GSL/Matrix.hs | |||
@@ -0,0 +1,311 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : GSL.Matrix | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable (uses FFI) | ||
10 | -- | ||
11 | -- A few linear algebra computations based on the GSL (<http://www.gnu.org/software/gsl>). | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | |||
15 | module GSL.Matrix( | ||
16 | eigSg, eigHg, | ||
17 | svdg, | ||
18 | qr, | ||
19 | chol, | ||
20 | luSolveR, luSolveC, | ||
21 | luR, luC, | ||
22 | fromFile | ||
23 | ) where | ||
24 | |||
25 | import Data.Packed.Internal | ||
26 | import Data.Packed.Matrix(fromLists,ident,takeDiag) | ||
27 | import GSL.Vector | ||
28 | import Foreign | ||
29 | import Foreign.C.Types | ||
30 | import Complex | ||
31 | import Foreign.C.String | ||
32 | |||
33 | {- | eigendecomposition of a real symmetric matrix using /gsl_eigen_symmv/. | ||
34 | |||
35 | > > let (l,v) = eigS $ 'fromLists' [[1,2],[2,1]] | ||
36 | > > l | ||
37 | > 3.000 -1.000 | ||
38 | > | ||
39 | > > v | ||
40 | > 0.707 -0.707 | ||
41 | > 0.707 0.707 | ||
42 | > | ||
43 | > > v <> diag l <> trans v | ||
44 | > 1.000 2.000 | ||
45 | > 2.000 1.000 | ||
46 | |||
47 | -} | ||
48 | eigSg :: Matrix Double -> (Vector Double, Matrix Double) | ||
49 | eigSg (m@M {rows = r}) | ||
50 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) | ||
51 | | otherwise = unsafePerformIO $ do | ||
52 | l <- createVector r | ||
53 | v <- createMatrix RowMajor r r | ||
54 | c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m] | ||
55 | return (l,v) | ||
56 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM | ||
57 | |||
58 | ------------------------------------------------------------------ | ||
59 | |||
60 | |||
61 | |||
62 | {- | eigendecomposition of a complex hermitian matrix using /gsl_eigen_hermv/ | ||
63 | |||
64 | > > let (l,v) = eigH $ 'fromLists' [[1,2+i],[2-i,3]] | ||
65 | > | ||
66 | > > l | ||
67 | > 4.449 -0.449 | ||
68 | > | ||
69 | > > v | ||
70 | > -0.544 0.839 | ||
71 | > (-0.751,0.375) (-0.487,0.243) | ||
72 | > | ||
73 | > > v <> diag l <> (conjTrans) v | ||
74 | > 1.000 (2.000,1.000) | ||
75 | > (2.000,-1.000) 3.000 | ||
76 | |||
77 | -} | ||
78 | eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) | ||
79 | eigHg (m@M {rows = r}) | ||
80 | | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) | ||
81 | | otherwise = unsafePerformIO $ do | ||
82 | l <- createVector r | ||
83 | v <- createMatrix RowMajor r r | ||
84 | c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m] | ||
85 | return (l,v) | ||
86 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM | ||
87 | |||
88 | |||
89 | {- | Singular value decomposition of a real matrix, using /gsl_linalg_SV_decomp_mod/: | ||
90 | |||
91 | @\> let (u,s,v) = svdg $ 'fromLists' [[1,2,3],[-4,1,7]] | ||
92 | \ | ||
93 | \> u | ||
94 | 0.310 -0.951 | ||
95 | 0.951 0.310 | ||
96 | \ | ||
97 | \> s | ||
98 | 8.497 2.792 | ||
99 | \ | ||
100 | \> v | ||
101 | -0.411 -0.785 | ||
102 | 0.185 -0.570 | ||
103 | 0.893 -0.243 | ||
104 | \ | ||
105 | \> u \<\> 'diag' s \<\> 'trans' v | ||
106 | 1. 2. 3. | ||
107 | -4. 1. 7.@ | ||
108 | |||
109 | -} | ||
110 | svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | ||
111 | svdg x@M {rows = r, cols = c} = if r>=c | ||
112 | then svd' x | ||
113 | else (v, s, u) where (u,s,v) = svd' (trans x) | ||
114 | |||
115 | svd' x@M {rows = r, cols = c} = unsafePerformIO $ do | ||
116 | u <- createMatrix RowMajor r c | ||
117 | s <- createVector c | ||
118 | v <- createMatrix RowMajor c c | ||
119 | c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x] | ||
120 | return (u,s,v) | ||
121 | foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM | ||
122 | |||
123 | {- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/. | ||
124 | |||
125 | @\> let (q,r) = qr $ 'fromLists' [[1,3,5,7],[2,0,-2,4]] | ||
126 | \ | ||
127 | \> q | ||
128 | -0.447 -0.894 | ||
129 | -0.894 0.447 | ||
130 | \ | ||
131 | \> r | ||
132 | -2.236 -1.342 -0.447 -6.708 | ||
133 | 0. -2.683 -5.367 -4.472 | ||
134 | \ | ||
135 | \> q \<\> r | ||
136 | 1.000 3.000 5.000 7.000 | ||
137 | 2.000 0. -2.000 4.000@ | ||
138 | |||
139 | -} | ||
140 | qr :: Matrix Double -> (Matrix Double, Matrix Double) | ||
141 | qr x@M {rows = r, cols = c} = unsafePerformIO $ do | ||
142 | q <- createMatrix RowMajor r r | ||
143 | rot <- createMatrix RowMajor r c | ||
144 | c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x] | ||
145 | return (q,rot) | ||
146 | foreign import ccall "gsl-aux.h QR" c_qr :: TMMM | ||
147 | |||
148 | {- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/. | ||
149 | |||
150 | @\> let c = chol $ 'fromLists' [[5,4],[4,5]] | ||
151 | \ | ||
152 | \> c | ||
153 | 2.236 0. | ||
154 | 1.789 1.342 | ||
155 | \ | ||
156 | \> c \<\> 'trans' c | ||
157 | 5.000 4.000 | ||
158 | 4.000 5.000@ | ||
159 | |||
160 | -} | ||
161 | chol :: Matrix Double -> Matrix Double | ||
162 | --chol x@(M r _ p) = createM [p] "chol" r r $ m c_chol x | ||
163 | chol x@M {rows = r} = unsafePerformIO $ do | ||
164 | res <- createMatrix RowMajor r r | ||
165 | c_chol // mat cdat x // mat dat res // check "chol" [cdat x] | ||
166 | return res | ||
167 | foreign import ccall "gsl-aux.h chol" c_chol :: TMM | ||
168 | |||
169 | -------------------------------------------------------- | ||
170 | |||
171 | {- -| efficient multiplication by the inverse of a matrix (for real matrices) | ||
172 | -} | ||
173 | luSolveR :: Matrix Double -> Matrix Double -> Matrix Double | ||
174 | luSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) | ||
175 | | n1==n2 && n1==r = unsafePerformIO $ do | ||
176 | s <- createMatrix RowMajor r c | ||
177 | c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b] | ||
178 | return s | ||
179 | | otherwise = error "luSolveR of nonsquare matrix" | ||
180 | |||
181 | foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM | ||
182 | |||
183 | {- -| efficient multiplication by the inverse of a matrix (for complex matrices). | ||
184 | -} | ||
185 | luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
186 | luSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) | ||
187 | | n1==n2 && n1==r = unsafePerformIO $ do | ||
188 | s <- createMatrix RowMajor r c | ||
189 | c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b] | ||
190 | return s | ||
191 | | otherwise = error "luSolveC of nonsquare matrix" | ||
192 | |||
193 | foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM | ||
194 | |||
195 | {- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) | ||
196 | -} | ||
197 | luRaux :: Matrix Double -> Vector Double | ||
198 | luRaux x@M {rows = r, cols = c} = unsafePerformIO $ do | ||
199 | res <- createVector (r*r+r+1) | ||
200 | c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x] | ||
201 | return res | ||
202 | foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV | ||
203 | |||
204 | {- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) | ||
205 | -} | ||
206 | luCaux :: Matrix (Complex Double) -> Vector (Complex Double) | ||
207 | luCaux x@M {rows = r, cols = c} = unsafePerformIO $ do | ||
208 | res <- createVector (r*r+r+1) | ||
209 | c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x] | ||
210 | return res | ||
211 | foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV | ||
212 | |||
213 | {- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>. | ||
214 | |||
215 | @\> let m = 'fromLists' [[1,2,-3],[2+3*i,-7,0],[1,-i,2*i]] | ||
216 | \> let (l,u,p,s) = luR m@ | ||
217 | |||
218 | L is the lower triangular: | ||
219 | |||
220 | @\> l | ||
221 | 1. 0. 0. | ||
222 | 0.154-0.231i 1. 0. | ||
223 | 0.154-0.231i 0.624-0.522i 1.@ | ||
224 | |||
225 | U is the upper triangular: | ||
226 | |||
227 | @\> u | ||
228 | 2.+3.i -7. 0. | ||
229 | 0. 3.077-1.615i -3. | ||
230 | 0. 0. 1.873+0.433i@ | ||
231 | |||
232 | p is a permutation: | ||
233 | |||
234 | @\> p | ||
235 | [1,0,2]@ | ||
236 | |||
237 | L \* U obtains a permuted version of the original matrix: | ||
238 | |||
239 | @\> 'extractRows' p m | ||
240 | 2.+3.i -7. 0. | ||
241 | 1. 2. -3. | ||
242 | 1. -1.i 2.i | ||
243 | \ | ||
244 | \> l \<\> u | ||
245 | 2.+3.i -7. 0. | ||
246 | 1. 2. -3. | ||
247 | 1. -1.i 2.i@ | ||
248 | |||
249 | s is the sign of the permutation, required to obtain sign of the determinant: | ||
250 | |||
251 | @\> s * product ('toList' $ 'takeDiag' u) | ||
252 | (-18.0) :+ (-16.000000000000004) | ||
253 | \> 'LinearAlgebra.Algorithms.det' m | ||
254 | (-18.0) :+ (-16.000000000000004)@ | ||
255 | |||
256 | -} | ||
257 | luR :: Matrix Double -> (Matrix Double, Matrix Double, [Int], Double) | ||
258 | luR m = (l,u,p, fromIntegral s') where | ||
259 | r = rows m | ||
260 | v = luRaux m | ||
261 | lu = reshape r $ subVector 0 (r*r) v | ||
262 | s':p = map round . toList . subVector (r*r) (r+1) $ v | ||
263 | u = triang r r 0 1`mul` lu | ||
264 | l = (triang r r 0 0 `mul` lu) `add` ident r | ||
265 | add = liftMatrix2 $ vectorZipR Add | ||
266 | mul = liftMatrix2 $ vectorZipR Mul | ||
267 | |||
268 | -- | Complex version of 'luR'. | ||
269 | luC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double), [Int], Complex Double) | ||
270 | luC m = (l,u,p, fromIntegral s') where | ||
271 | r = rows m | ||
272 | v = luCaux m | ||
273 | lu = reshape r $ subVector 0 (r*r) v | ||
274 | s':p = map (round.realPart) . toList . subVector (r*r) (r+1) $ v | ||
275 | u = triang r r 0 1 `mul` lu | ||
276 | l = (triang r r 0 0 `mul` lu) `add` ident r | ||
277 | add = liftMatrix2 $ vectorZipC Add | ||
278 | mul = liftMatrix2 $ vectorZipC Mul | ||
279 | |||
280 | extract l is = [l!!i |i<-is] | ||
281 | |||
282 | {- auxiliary function to get triangular matrices | ||
283 | -} | ||
284 | triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] | ||
285 | where el i j = if j-i>=h then v else 1 - v | ||
286 | |||
287 | {- | rearranges the rows of a matrix according to the order given in a list of integers. | ||
288 | |||
289 | > > extractRows [3,3,0,1] (ident 4) | ||
290 | > 0. 0. 0. 1. | ||
291 | > 0. 0. 0. 1. | ||
292 | > 1. 0. 0. 0. | ||
293 | > 0. 1. 0. 0. | ||
294 | |||
295 | -} | ||
296 | extractRows :: Field t => [Int] -> Matrix t -> Matrix t | ||
297 | extractRows l m = fromRows $ extract (toRows $ m) l | ||
298 | |||
299 | -------------------------------------------------------------- | ||
300 | |||
301 | -- | loads a matrix efficiently from formatted ASCII text file (the number of rows and columns must be known in advance). | ||
302 | fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double) | ||
303 | fromFile filename (r,c) = do | ||
304 | charname <- newCString filename | ||
305 | res <- createMatrix RowMajor r c | ||
306 | c_gslReadMatrix charname // mat dat res // check "gslReadMatrix" [] | ||
307 | --free charname -- TO DO: free the auxiliary CString | ||
308 | return res | ||
309 | foreign import ccall "gsl-aux.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM | ||
310 | |||
311 | --------------------------------------------------------------------------- | ||
diff --git a/lib/GSL/Vector.hs b/lib/GSL/Vector.hs index a772b34..a074254 100644 --- a/lib/GSL/Vector.hs +++ b/lib/GSL/Vector.hs | |||
@@ -18,7 +18,7 @@ module GSL.Vector ( | |||
18 | FunCodeV(..), vectorMapR, vectorMapC, | 18 | FunCodeV(..), vectorMapR, vectorMapC, |
19 | FunCodeSV(..), vectorMapValR, vectorMapValC, | 19 | FunCodeSV(..), vectorMapValR, vectorMapValC, |
20 | FunCodeVV(..), vectorZipR, vectorZipC, | 20 | FunCodeVV(..), vectorZipR, vectorZipC, |
21 | scale, addConstant, add, mul, | 21 | scale, addConstant, add, sub, mul, |
22 | ) where | 22 | ) where |
23 | 23 | ||
24 | import Data.Packed.Internal | 24 | import Data.Packed.Internal |
@@ -84,6 +84,11 @@ add u v | isReal baseOf v = scast $ vectorZipR Add (scast u) (scast v) | |||
84 | | isComp baseOf v = scast $ vectorZipC Add (scast u) (scast v) | 84 | | isComp baseOf v = scast $ vectorZipC Add (scast u) (scast v) |
85 | | otherwise = fromList $ zipWith (+) (toList u) (toList v) | 85 | | otherwise = fromList $ zipWith (+) (toList u) (toList v) |
86 | 86 | ||
87 | sub :: (Num a, Field a) => Vector a -> Vector a -> Vector a | ||
88 | sub u v | isReal baseOf v = scast $ vectorZipR Sub (scast u) (scast v) | ||
89 | | isComp baseOf v = scast $ vectorZipC Sub (scast u) (scast v) | ||
90 | | otherwise = fromList $ zipWith (-) (toList u) (toList v) | ||
91 | |||
87 | mul :: (Num a, Field a) => Vector a -> Vector a -> Vector a | 92 | mul :: (Num a, Field a) => Vector a -> Vector a -> Vector a |
88 | mul u v | isReal baseOf v = scast $ vectorZipR Mul (scast u) (scast v) | 93 | mul u v | isReal baseOf v = scast $ vectorZipR Mul (scast u) (scast v) |
89 | | isComp baseOf v = scast $ vectorZipC Mul (scast u) (scast v) | 94 | | isComp baseOf v = scast $ vectorZipC Mul (scast u) (scast v) |
diff --git a/lib/LinearAlgebra.hs b/lib/LinearAlgebra.hs index 474f9a6..b0c8b9d 100644 --- a/lib/LinearAlgebra.hs +++ b/lib/LinearAlgebra.hs | |||
@@ -15,3 +15,4 @@ Some linear algebra algorithms, implemented by means of BLAS, LAPACK or GSL. | |||
15 | module LinearAlgebra ( | 15 | module LinearAlgebra ( |
16 | 16 | ||
17 | ) where | 17 | ) where |
18 | |||
diff --git a/lib/LinearAlgebra/Algorithms.hs b/lib/LinearAlgebra/Algorithms.hs index 680612f..126549a 100644 --- a/lib/LinearAlgebra/Algorithms.hs +++ b/lib/LinearAlgebra/Algorithms.hs | |||
@@ -1,3 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts #-} | ||
1 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
2 | {- | | 3 | {- | |
3 | Module : LinearAlgebra.Algorithms | 4 | Module : LinearAlgebra.Algorithms |
@@ -13,5 +14,229 @@ Portability : uses ffi | |||
13 | ----------------------------------------------------------------------------- | 14 | ----------------------------------------------------------------------------- |
14 | 15 | ||
15 | module LinearAlgebra.Algorithms ( | 16 | module LinearAlgebra.Algorithms ( |
16 | 17 | mXm, mXv, vXm, | |
18 | inv, | ||
19 | pinv, | ||
20 | pinvTol, | ||
21 | pinvTolg, | ||
22 | Normed(..), NormType(..), | ||
23 | det, | ||
24 | eps, i | ||
17 | ) where | 25 | ) where |
26 | |||
27 | |||
28 | import Data.Packed.Internal | ||
29 | import Data.Packed.Matrix | ||
30 | import GSL.Matrix | ||
31 | import GSL.Vector | ||
32 | import LAPACK | ||
33 | import Complex | ||
34 | |||
35 | {- | Machine precision of a Double. | ||
36 | |||
37 | >> eps | ||
38 | > 2.22044604925031e-16 | ||
39 | |||
40 | (The value used by GNU-Octave) | ||
41 | |||
42 | -} | ||
43 | eps :: Double | ||
44 | eps = 2.22044604925031e-16 | ||
45 | |||
46 | {- | The imaginary unit | ||
47 | |||
48 | @> 'ident' 3 \<\> i | ||
49 | 1.i 0. 0. | ||
50 | 0. 1.i 0. | ||
51 | 0. 0. 1.i@ | ||
52 | |||
53 | -} | ||
54 | i :: Complex Double | ||
55 | i = 0:+1 | ||
56 | |||
57 | |||
58 | -- | matrix product | ||
59 | mXm :: (Num t, Field t) => Matrix t -> Matrix t -> Matrix t | ||
60 | mXm = multiply RowMajor | ||
61 | |||
62 | -- | matrix - vector product | ||
63 | mXv :: (Num t, Field t) => Matrix t -> Vector t -> Vector t | ||
64 | mXv m v = flatten $ m `mXm` (asColumn v) | ||
65 | |||
66 | -- | vector - matrix product | ||
67 | vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t | ||
68 | vXm v m = flatten $ (asRow v) `mXm` m | ||
69 | |||
70 | |||
71 | |||
72 | -- | Pseudoinverse of a real matrix | ||
73 | -- | ||
74 | -- @dispR 3 $ pinv (fromLists [[1,2], | ||
75 | -- [3,4], | ||
76 | -- [5,6]]) | ||
77 | --matrix (2x3) | ||
78 | -- -1.333 | -0.333 | 0.667 | ||
79 | -- 1.083 | 0.333 | -0.417@ | ||
80 | -- | ||
81 | |||
82 | pinv :: Matrix Double -> Matrix Double | ||
83 | pinv m = pinvTol 1 m | ||
84 | --pinv m = linearSolveSVDR Nothing m (ident (rows m)) | ||
85 | |||
86 | {- -| Pseudoinverse of a real matrix with the default tolerance used by GNU-Octave: the singular values less than max (rows, colums) * greatest singular value * 'eps' are ignored. See 'pinvTol'. | ||
87 | |||
88 | @\> let m = 'fromLists' [[ 1, 2] | ||
89 | ,[ 5, 8] | ||
90 | ,[10,-5]] | ||
91 | \> pinv m | ||
92 | 9.353e-3 4.539e-2 7.637e-2 | ||
93 | 2.231e-2 8.993e-2 -4.719e-2 | ||
94 | \ | ||
95 | \> m \<\> pinv m \<\> m | ||
96 | 1. 2. | ||
97 | 5. 8. | ||
98 | 10. -5.@ | ||
99 | |||
100 | -} | ||
101 | --pinvg :: Matrix Double -> Matrix Double | ||
102 | pinvg m = pinvTolg 1 m | ||
103 | |||
104 | {- | Pseudoinverse of a real matrix with the desired tolerance, expressed as a | ||
105 | multiplicative factor of the default tolerance used by GNU-Octave (see 'pinv'). | ||
106 | |||
107 | @\> let m = 'fromLists' [[1,0, 0] | ||
108 | ,[0,1, 0] | ||
109 | ,[0,0,1e-10]] | ||
110 | \ | ||
111 | \> 'pinv' m | ||
112 | 1. 0. 0. | ||
113 | 0. 1. 0. | ||
114 | 0. 0. 10000000000. | ||
115 | \ | ||
116 | \> pinvTol 1E8 m | ||
117 | 1. 0. 0. | ||
118 | 0. 1. 0. | ||
119 | 0. 0. 1.@ | ||
120 | |||
121 | -} | ||
122 | pinvTol :: Double -> Matrix Double -> Matrix Double | ||
123 | pinvTol t m = v' `mXm` diag s' `mXm` trans u' where | ||
124 | (u,s,v) = svdR' m | ||
125 | sl@(g:_) = toList s | ||
126 | s' = fromList . map rec $ sl | ||
127 | rec x = if x < g*tol then 1 else 1/x | ||
128 | tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps) | ||
129 | r = rows m | ||
130 | c = cols m | ||
131 | d = dim s | ||
132 | u' = takeColumns d u | ||
133 | v' = takeColumns d v | ||
134 | |||
135 | |||
136 | pinvTolg :: Double -> Matrix Double -> Matrix Double | ||
137 | pinvTolg t m = v `mXm` diag s' `mXm` trans u where | ||
138 | (u,s,v) = svdg m | ||
139 | sl@(g:_) = toList s | ||
140 | s' = fromList . map rec $ sl | ||
141 | rec x = if x < g*tol then 1 else 1/x | ||
142 | tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps) | ||
143 | |||
144 | |||
145 | |||
146 | {- | Inverse of a square matrix. | ||
147 | |||
148 | inv m = 'linearSolveR' m ('ident' ('rows' m)) | ||
149 | |||
150 | @\>inv ('fromLists' [[1,4] | ||
151 | ,[0,2]]) | ||
152 | 1. -2. | ||
153 | 0. 0.500@ | ||
154 | -} | ||
155 | inv :: Matrix Double -> Matrix Double | ||
156 | inv m = if rows m == cols m | ||
157 | then m `linearSolveR` ident (rows m) | ||
158 | else error "inv of nonsquare matrix" | ||
159 | |||
160 | |||
161 | {- - | Shortcut for the 2-norm ('pnorm' 2) | ||
162 | |||
163 | @ > norm $ 'hilb' 5 | ||
164 | 1.5670506910982311 | ||
165 | @ | ||
166 | |||
167 | @\> norm $ 'fromList' [1,-1,'i',-'i'] | ||
168 | 2.0@ | ||
169 | |||
170 | -} | ||
171 | |||
172 | |||
173 | |||
174 | {- | Determinant of a square matrix, computed from the LU decomposition. | ||
175 | |||
176 | @\> det ('fromLists' [[7,2],[3,8]]) | ||
177 | 50.0@ | ||
178 | |||
179 | -} | ||
180 | det :: Matrix Double -> Double | ||
181 | det m = s * (product $ toList $ takeDiag $ u) | ||
182 | where (_,u,_,s) = luR m | ||
183 | |||
184 | --------------------------------------------------------------------------- | ||
185 | |||
186 | norm2 :: Vector Double -> Double | ||
187 | norm2 = toScalarR Norm2 | ||
188 | |||
189 | norm1 :: Vector Double -> Double | ||
190 | norm1 = toScalarR AbsSum | ||
191 | |||
192 | vectorMax :: Vector Double -> Double | ||
193 | vectorMax = toScalarR Max | ||
194 | vectorMin :: Vector Double -> Double | ||
195 | vectorMin = toScalarR Min | ||
196 | vectorMaxIndex :: Vector Double -> Int | ||
197 | vectorMaxIndex = round . toScalarR MaxIdx | ||
198 | vectorMinIndex :: Vector Double -> Int | ||
199 | vectorMinIndex = round . toScalarR MinIdx | ||
200 | |||
201 | data NormType = Infinity | PNorm1 | PNorm2 -- PNorm Int | ||
202 | |||
203 | pnormRV PNorm2 = norm2 | ||
204 | pnormRV PNorm1 = norm1 | ||
205 | pnormRV Infinity = vectorMax . vectorMapR Abs | ||
206 | --pnormRV _ = error "pnormRV not yet defined" | ||
207 | |||
208 | pnormCV PNorm2 = norm2 . asReal | ||
209 | pnormCV PNorm1 = norm1 . liftVector magnitude | ||
210 | pnormCV Infinity = vectorMax . liftVector magnitude | ||
211 | --pnormCV _ = error "pnormCV not yet defined" | ||
212 | |||
213 | pnormRM PNorm2 m = head (toList s) where (_,s,_) = svdR' m | ||
214 | pnormRM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (vectorMapR Abs) m | ||
215 | pnormRM Infinity m = vectorMax $ liftMatrix (vectorMapR Abs) m `mXv` constant 1 (cols m) | ||
216 | --pnormRM _ _ = error "p norm not yet defined" | ||
217 | |||
218 | pnormCM PNorm2 m = head (toList s) where (_,s,_) = svdC' m | ||
219 | pnormCM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (liftVector magnitude) m | ||
220 | pnormCM Infinity m = vectorMax $ liftMatrix (liftVector magnitude) m `mXv` constant 1 (cols m) | ||
221 | --pnormCM _ _ = error "p norm not yet defined" | ||
222 | |||
223 | -- -- | computes the p-norm of a matrix or vector (with the same definitions as GNU-octave). pnorm 0 denotes \\inf-norm. See also 'norm'. | ||
224 | --pnorm :: (Container t, Field a) => Int -> t a -> Double | ||
225 | --pnorm = pnormG | ||
226 | |||
227 | class Normed t where | ||
228 | pnorm :: NormType -> t -> Double | ||
229 | norm :: t -> Double | ||
230 | norm = pnorm PNorm2 | ||
231 | |||
232 | instance Normed (Vector Double) where | ||
233 | pnorm = pnormRV | ||
234 | |||
235 | instance Normed (Vector (Complex Double)) where | ||
236 | pnorm = pnormCV | ||
237 | |||
238 | instance Normed (Matrix Double) where | ||
239 | pnorm = pnormRM | ||
240 | |||
241 | instance Normed (Matrix (Complex Double)) where | ||
242 | pnorm = pnormCM | ||