summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--HSSL.cabal2
-rw-r--r--examples/tests.hs152
-rw-r--r--lib/Data/Packed/Instances.hs391
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs18
-rw-r--r--lib/Data/Packed/Internal/Vector.hs6
-rw-r--r--lib/Data/Packed/Matrix.hs35
-rw-r--r--lib/Data/Packed/Vector.hs13
-rw-r--r--lib/GSL.hs45
-rw-r--r--lib/GSL/Matrix.hs311
-rw-r--r--lib/GSL/Vector.hs7
-rw-r--r--lib/LinearAlgebra.hs1
-rw-r--r--lib/LinearAlgebra/Algorithms.hs227
12 files changed, 1119 insertions, 89 deletions
diff --git a/HSSL.cabal b/HSSL.cabal
index fc9db60..84ca5f5 100644
--- a/HSSL.cabal
+++ b/HSSL.cabal
@@ -24,7 +24,7 @@ Exposed-modules: Data.Packed.Internal,
24 Data.Packed.Internal.Matrix, Data.Packed.Matrix, 24 Data.Packed.Internal.Matrix, Data.Packed.Matrix,
25 Data.Packed.Internal.Tensor, Data.Packed.Tensor, 25 Data.Packed.Internal.Tensor, Data.Packed.Tensor,
26 Data.Packed.Plot 26 Data.Packed.Plot
27 -- Data.Packed.Instances 27 Data.Packed.Instances
28 LAPACK, 28 LAPACK,
29 GSL.Vector, 29 GSL.Vector,
30 GSL.Matrix 30 GSL.Matrix
diff --git a/examples/tests.hs b/examples/tests.hs
index 37800cd..9de9339 100644
--- a/examples/tests.hs
+++ b/examples/tests.hs
@@ -16,12 +16,26 @@ import GSL.Fourier
16import GSL.Polynomials 16import GSL.Polynomials
17import LAPACK 17import LAPACK
18import Test.QuickCheck 18import Test.QuickCheck
19import Test.HUnit 19import Test.HUnit hiding ((~:))
20import Complex 20import Complex
21import LinearAlgebra.Algorithms
22import GSL.Matrix
23import Data.Packed.Instances hiding ((<>))
24
25dist :: (Normed t, Num t) => t -> t -> Double
26dist a b = norm (a-b)
27
28infixl 4 |~|
29a |~| b = a :~8~: b
30
31data Aprox a = (:~) a Int
32
33(~:) :: (Normed a, Num a) => Aprox a -> a -> Bool
34a :~n~: b = dist a b < 10^^(-n)
21 35
22 36
23{- 37{-
24-- Bravo por quickCheck! 38-- Bravo por quickCheck!
25 39
26pinvProp1 tol m = (rank m == cols m) ==> pinv m <> m ~~ ident (cols m) 40pinvProp1 tol m = (rank m == cols m) ==> pinv m <> m ~~ ident (cols m)
27 where infix 2 ~~ 41 where infix 2 ~~
@@ -32,7 +46,7 @@ pinvProp2 tol m = 0 < r && r <= c ==> (r==c) `trivial` (m <> pinv m <> m ~~ m)
32 c = cols m 46 c = cols m
33 infix 2 ~~ 47 infix 2 ~~
34 (~~) = approxEqual tol 48 (~~) = approxEqual tol
35 49
36nullspaceProp tol m = cr > 0 ==> m <> nt ~~ zeros 50nullspaceProp tol m = cr > 0 ==> m <> nt ~~ zeros
37 where nt = trans (nullspace m) 51 where nt = trans (nullspace m)
38 cr = corank m 52 cr = corank m
@@ -49,31 +63,28 @@ mz = (2 >< 3) [1,2,3,4,5,6:+(1::Double)]
49af = (2>|<3) [1,4,2,5,3,6::Double] 63af = (2>|<3) [1,4,2,5,3,6::Double]
50bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double] 64bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double]
51 65
52a |=| b = rows a == rows b &&
53 cols a == cols b &&
54 toList (cdat a) == toList (cdat b) &&
55 toList (fdat a) == toList (fdat b)
56 66
67{-
57aprox fun a b = rows a == rows b && 68aprox fun a b = rows a == rows b &&
58 cols a == cols b && 69 cols a == cols b &&
59 eps > aproxL fun (toList (t a)) (toList (t b)) 70 epsTol > aproxL fun (toList (t a)) (toList (t b))
60 where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat 71 where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat
61 72
62aproxL fun v1 v2 = sum (zipWith (\a b-> fun (a-b)) v1 v2) / fromIntegral (length v1) 73aproxL fun v1 v2 = sum (zipWith (\a b-> fun (a-b)) v1 v2) / fromIntegral (length v1)
63 74
64normVR a b = toScalarR AbsSum (vectorZipR Sub a b) 75normVR a b = toScalarR AbsSum (vectorZipR Sub a b)
65 76
66a |~| b = rows a == rows b && cols a == cols b && eps > normVR (t a) (t b) 77a |~| b = rows a == rows b && cols a == cols b && epsTol > normVR (t a) (t b)
67 where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat 78 where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat
68 79
69(|~~|) = aprox magnitude 80(|~~|) = aprox magnitude
70 81
71v1 ~~ v2 = reshape 1 v1 |~~| reshape 1 v2 82v1 ~~ v2 = reshape 1 v1 |~~| reshape 1 v2
72 83
73u ~|~ v = normVR u v < eps 84u ~|~ v = normVR u v < epsTol
74 85-}
75 86
76eps = 1E-8::Double 87epsTol = 1E-8::Double
77 88
78asFortran m = (rows m >|< cols m) $ toList (fdat m) 89asFortran m = (rows m >|< cols m) $ toList (fdat m)
79asC m = (rows m >< cols m) $ toList (cdat m) 90asC m = (rows m >< cols m) $ toList (cdat m)
@@ -81,6 +92,9 @@ asC m = (rows m >< cols m) $ toList (cdat m)
81mulC a b = multiply RowMajor a b 92mulC a b = multiply RowMajor a b
82mulF a b = multiply ColumnMajor a b 93mulF a b = multiply ColumnMajor a b
83 94
95infixl 7 <>
96a <> b = mulF a b
97
84cc = mulC ac bf 98cc = mulC ac bf
85cf = mulF af bc 99cf = mulF af bc
86 100
@@ -133,14 +147,14 @@ data Sym a = Sym (Matrix a) deriving Show
133instance (Field a, Arbitrary a, Num a) => Arbitrary (Sym a) where 147instance (Field a, Arbitrary a, Num a) => Arbitrary (Sym a) where
134 arbitrary = do 148 arbitrary = do
135 SqM m <- arbitrary 149 SqM m <- arbitrary
136 return $ Sym (m `addM` trans m) 150 return $ Sym (m + trans m)
137 coarbitrary = undefined 151 coarbitrary = undefined
138 152
139data Her = Her (Matrix (Complex Double)) deriving Show 153data Her = Her (Matrix (Complex Double)) deriving Show
140instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where 154instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where
141 arbitrary = do 155 arbitrary = do
142 SqM m <- arbitrary 156 SqM m <- arbitrary
143 return $ Her (m `addM` (liftMatrix conj) (trans m)) 157 return $ Her (m + conjTrans m)
144 coarbitrary = undefined 158 coarbitrary = undefined
145 159
146data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show 160data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show
@@ -171,49 +185,40 @@ instance (Field a, Arbitrary a) => Arbitrary (PairV a) where
171 coarbitrary = undefined 185 coarbitrary = undefined
172 186
173 187
174addM m1 m2 = liftMatrix2 add m1 m2
175
176 188
177type BaseType = Double 189type BaseType = Double
178 190
179svdTestR fun prod m = u <> s <> trans v |~| m 191svdTestR fun m = u <> s <> trans v |~| m
180 && u <> trans u |~| ident (rows m) 192 && u <> trans u |~| ident (rows m)
181 && v <> trans v |~| ident (cols m) 193 && v <> trans v |~| ident (cols m)
182 where (u,s,v) = fun m 194 where (u,s,v) = fun m
183 (<>) = prod
184 195
185 196
186svdTestC prod m = u <> s' <> (trans v) |~~| m 197svdTestC m = u <> s' <> (trans v) |~| m
187 && u <> (liftMatrix conj) (trans u) |~~| ident (rows m) 198 && u <> conjTrans u |~| ident (rows m)
188 && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) 199 && v <> conjTrans v |~| ident (cols m)
189 where (u,s,v) = svdC m 200 where (u,s,v) = svdC m
190 (<>) = prod
191 s' = liftMatrix comp s 201 s' = liftMatrix comp s
192 202
193eigTestC prod (SqM m) = (m <> v) |~~| (v <> diag s) 203--svdg' m = (u,s',v) where
194 && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant 1 (rows m) --normalized 204
205eigTestC (SqM m) = (m <> v) |~| (v <> diag s)
206 && takeDiag (conjTrans v <> v) |~| constant 1 (rows m) --normalized
195 where (s,v) = eigC m 207 where (s,v) = eigC m
196 (<>) = prod
197 208
198eigTestR prod (SqM m) = (liftMatrix comp m <> v) |~~| (v <> diag s) 209eigTestR (SqM m) = (liftMatrix comp m <> v) |~| (v <> diag s)
199 -- && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant 1 (rows m) --normalized ??? 210 -- && takeDiag ((liftMatrix conj (trans v)) <> v) |~| constant 1 (rows m) --normalized ???
200 where (s,v) = eigR m 211 where (s,v) = eigR m
201 (<>) = prod
202 212
203eigTestS prod (Sym m) = (m <> v) |~| (v <> diag s) 213eigTestS (Sym m) = (m <> v) |~| (v <> diag s)
204 && v <> trans v |~| ident (cols m) 214 && v <> trans v |~| ident (cols m)
205 where (s,v) = eigS m 215 where (s,v) = eigS m
206 (<>) = prod
207 216
208eigTestH prod (Her m) = (m <> v) |~~| (v <> diag (comp s)) 217eigTestH (Her m) = (m <> v) |~| (v <> diag (comp s))
209 && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) 218 && v <> conjTrans v |~| ident (cols m)
210 where (s,v) = eigH m 219 where (s,v) = eigH m
211 (<>) = prod
212
213linearSolveSQTest fun eqfun singu prod (PairSM a b) = singu a || (a <> fun a b) ==== b
214 where (<>) = prod
215 (====) = eqfun
216 220
221linearSolveSQTest fun singu (PairSM a b) = singu a || (a <> fun a b) |~| b
217 222
218prec = 1E-15 223prec = 1E-15
219 224
@@ -237,8 +242,7 @@ identC n = toComplex(ident n, (0::Double) <>ident n)
237 242
238-------------------------------------------------------------------- 243--------------------------------------------------------------------
239 244
240pinvTest f feq m = (m <> f m <> m) `feq` m 245pinvTest f m = (m <> f m <> m) |~| m
241 where (<>) = mulF
242 246
243pinvR m = linearSolveLSR m (ident (rows m)) 247pinvR m = linearSolveLSR m (ident (rows m))
244pinvC m = linearSolveLSC m (ident (rows m)) 248pinvC m = linearSolveLSC m (ident (rows m))
@@ -252,7 +256,7 @@ pinvSVDC m = linearSolveSVDC Nothing m (ident (rows m))
252polyEval cs x = foldr (\c ac->ac*x+c) 0 cs 256polyEval cs x = foldr (\c ac->ac*x+c) 0 cs
253 257
254polySolveTest' p = length p <2 || last p == 0|| 1E-8 > maximum (map magnitude $ map (polyEval (map (:+0) p)) (polySolve p)) 258polySolveTest' p = length p <2 || last p == 0|| 1E-8 > maximum (map magnitude $ map (polyEval (map (:+0) p)) (polySolve p))
255 where l1 |~~| l2 = eps > aproxL magnitude l1 l2 259
256 260
257polySolveTest = assertBool "polySolve" (polySolveTest' [1,2,3,4]) 261polySolveTest = assertBool "polySolve" (polySolveTest' [1,2,3,4])
258 262
@@ -267,17 +271,17 @@ quad2 f a b g1 g2 = quad h a b
267volSphere r = 8 * quad2 (\x y -> sqrt (r*r-x*x-y*y)) 271volSphere r = 8 * quad2 (\x y -> sqrt (r*r-x*x-y*y))
268 0 r (const 0) (\x->sqrt (r*r-x*x)) 272 0 r (const 0) (\x->sqrt (r*r-x*x))
269 273
270integrateTest = assertBool "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < eps) 274integrateTest = assertBool "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < epsTol)
271 275
272 276
273--------------------------------------------------------------------- 277---------------------------------------------------------------------
274 278
275arit1 u = vectorMapValR PowVS 2 (vectorMapR Sin u) 279arit1 u = vectorMapValR PowVS 2 (vectorMapR Sin u)
276 `add` vectorMapValR PowVS 2 (vectorMapR Cos u) 280 `add` vectorMapValR PowVS 2 (vectorMapR Cos u)
277 ~|~ constant 1 (dim u) 281 |~| constant 1 (dim u)
278 282
279arit2 u = (vectorMapR Cos u) `mul` (vectorMapR Tan u) 283arit2 u = (vectorMapR Cos u) `mul` (vectorMapR Tan u)
280 ~|~ vectorMapR Sin u 284 |~| vectorMapR Sin u
281 285
282 286
283-- arit3 (PairV u v) = 287-- arit3 (PairV u v) =
@@ -305,50 +309,48 @@ tests = TestList
305 309
306main = do 310main = do
307 putStrLn "--------- general -----" 311 putStrLn "--------- general -----"
308 quickCheck (\(Sym m) -> m |=| (trans m:: Matrix BaseType)) 312 quickCheck (\(Sym m) -> m == (trans m:: Matrix BaseType))
309 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) 313 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType])
310 314
311 quickCheck $ \m -> m |=| asC (m :: Matrix BaseType) 315 quickCheck $ \m -> m == asC (m :: Matrix BaseType)
312 quickCheck $ \m -> m |=| asFortran (m :: Matrix BaseType) 316 quickCheck $ \m -> m == asFortran (m :: Matrix BaseType)
313 quickCheck $ \m -> m |=| (asC . asFortran) (m :: Matrix BaseType) 317 quickCheck $ \m -> m == (asC . asFortran) (m :: Matrix BaseType)
314 putStrLn "--------- MULTIPLY ----" 318 putStrLn "--------- MULTIPLY ----"
315 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| mulF m1 (m2 :: Matrix BaseType) 319 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == mulF m1 (m2 :: Matrix BaseType)
316 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| trans (mulF (trans m2) (trans m1 :: Matrix BaseType)) 320 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == trans (mulF (trans m2) (trans m1 :: Matrix BaseType))
317 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| multiplyG m1 (m2 :: Matrix BaseType) 321 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == multiplyG m1 (m2 :: Matrix BaseType)
318 putStrLn "--------- SVD ---------" 322 putStrLn "--------- SVD ---------"
319 quickCheck (svdTestR svdR mulC) 323 quickCheck (svdTestR svdR)
320 quickCheck (svdTestR svdR mulF) 324 quickCheck (svdTestR svdRdd)
321 quickCheck (svdTestR svdRdd mulC) 325-- quickCheck (svdTestR svdg)
322 quickCheck (svdTestR svdRdd mulF) 326 quickCheck svdTestC
323 quickCheck (svdTestC mulC)
324 quickCheck (svdTestC mulF)
325 putStrLn "--------- EIG ---------" 327 putStrLn "--------- EIG ---------"
326 quickCheck (eigTestC mulC) 328 quickCheck eigTestC
327 quickCheck (eigTestC mulF) 329 quickCheck eigTestR
328 quickCheck (eigTestR mulC) 330 quickCheck eigTestS
329 quickCheck (eigTestR mulF) 331 quickCheck eigTestH
330 quickCheck (eigTestS mulC)
331 quickCheck (eigTestS mulF)
332 quickCheck (eigTestH mulC)
333 quickCheck (eigTestH mulF)
334 putStrLn "--------- SOLVE ---------" 332 putStrLn "--------- SOLVE ---------"
335 quickCheck (linearSolveSQTest linearSolveR (|~|) (singular svdR') mulC) 333 quickCheck (linearSolveSQTest linearSolveR (singular svdR'))
336 quickCheck (linearSolveSQTest linearSolveC (|~~|) (singular svdC') mulF) 334 quickCheck (linearSolveSQTest linearSolveC (singular svdC'))
337 quickCheck (pinvTest pinvR (|~|)) 335 quickCheck (pinvTest pinvR)
338 quickCheck (pinvTest pinvC (|~~|)) 336 quickCheck (pinvTest pinvC)
339 quickCheck (pinvTest pinvSVDR (|~|)) 337 quickCheck (pinvTest pinvSVDR)
340 quickCheck (pinvTest pinvSVDC (|~~|)) 338 quickCheck (pinvTest pinvSVDC)
341 putStrLn "--------- VEC OPER ------" 339 putStrLn "--------- VEC OPER ------"
342 quickCheck arit1 340 quickCheck arit1
343 quickCheck arit2 341 quickCheck arit2
344 putStrLn "--------- GSL ------" 342 putStrLn "--------- GSL ------"
345 runTestTT tests 343 runTestTT tests
346 quickCheck $ \v -> ifft (fft v) ~~ v 344 quickCheck $ \v -> ifft (fft v) |~| v
347 345
348kk = (2><2) 346kk = (2><2)
349 [ 1.0, 0.0 347 [ 1.0, 0.0
350 , -1.5, 1.0 ::Double] 348 , -1.5, 1.0 ::Double]
351 349
352v = 11 # [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0::Double] 350v = 11 |> [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0::Double]
351
352pol = [14.125,-7.666666666666667,-14.3,-13.0,-7.0,-9.6,4.666666666666666,13.0,0.5]
353 353
354pol = [14.125,-7.666666666666667,-14.3,-13.0,-7.0,-9.6,4.666666666666666,13.0,0.5] \ No newline at end of file 354mm = (2><2)
355 [ 0.5, 0.0
356 , 0.0, 0.0 ] :: Matrix Double
diff --git a/lib/Data/Packed/Instances.hs b/lib/Data/Packed/Instances.hs
new file mode 100644
index 0000000..4478469
--- /dev/null
+++ b/lib/Data/Packed/Instances.hs
@@ -0,0 +1,391 @@
1{-# OPTIONS_GHC -fglasgow-exts #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Data.Packed.Instances
5Copyright : (c) Alberto Ruiz 2006
6License : GPL-style
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10Portability : uses -fffi and -fglasgow-exts
11
12Creates reasonable numeric instances for Vectors and Matrices. In the context of the standard numeric operators, one-component vectors and matrices automatically expand to match the dimensions of the other operand.
13
14-}
15-----------------------------------------------------------------------------
16
17module Data.Packed.Instances(
18 Contractible(..)
19) where
20
21import Data.Packed.Internal
22import Data.Packed.Vector
23import Data.Packed.Matrix
24import GSL.Vector
25import GSL.Matrix
26import LinearAlgebra.Algorithms
27import Complex
28
29instance (Eq a, Field a) => Eq (Vector a) where
30 a == b = dim a == dim b && toList a == toList b
31
32instance (Num a, Field a) => Num (Vector a) where
33 (+) = add
34 (-) = sub
35 (*) = mul
36 signum = liftVector signum
37 abs = liftVector abs
38 fromInteger = fromList . return . fromInteger
39
40instance (Eq a, Field a) => Eq (Matrix a) where
41 a == b = rows a == rows b && cols a == cols b && cdat a == cdat b && fdat a == fdat b
42
43instance (Num a, Field a) => Num (Matrix a) where
44 (+) = liftMatrix2 add
45 (-) = liftMatrix2 sub
46 (*) = liftMatrix2 mul
47 signum = liftMatrix signum
48 abs = liftMatrix abs
49 fromInteger = (1><1) . return . fromInteger
50
51---------------------------------------------------
52
53adaptScalar f1 f2 f3 x y
54 | dim x == 1 = f1 (x@>0) y
55 | dim y == 1 = f3 x (y@>0)
56 | otherwise = f2 x y
57
58{-
59subvv = vectorZip 4
60subvc v c = addConstant (-c) v
61subcv c v = addConstant c (scale (-1) v)
62
63mul = vectorZip 1
64
65instance Num (Vector Double) where
66 (+) = adaptScalar addConstant add (flip addConstant)
67 (-) = adaptScalar subcv subvv subvc
68 (*) = adaptScalar scale mul (flip scale)
69 abs = vectorMap 3
70 signum = vectorMap 15
71 fromInteger n = fromList [fromInteger n]
72
73----------------------------------------------------
74
75--addConstantC a = gmap (+a)
76--subCvv u v = u `add` scale (-1) v
77subCvv = vectorZipComplex 4 -- faster?
78subCvc v c = addConstantC (-c) v
79subCcv c v = addConstantC c (scale (-1) v)
80
81
82instance Num (Vector (Complex Double)) where
83 (+) = adaptScalar addConstantC add (flip addConstantC)
84 (-) = adaptScalar subCcv subCvv subCvc
85 (*) = adaptScalar scale (vectorZipComplex 1) (flip scale)
86 abs = gmap abs
87 signum = gmap signum
88 fromInteger n = fromList [fromInteger n]
89
90
91-- | adapts a function on two vectors to work on all the elements of two matrices
92liftMatrix2' :: (Vector a -> Vector b -> Vector c) -> Matrix a -> Matrix b -> Matrix c
93liftMatrix2' f m1@(M r1 c1 _) m2@(M r2 c2 _)
94 | sameShape m1 m2 || r1*c1==1 || r2*c2==1
95 = reshape (max c1 c2) $ f (flatten m1) (flatten m2)
96 | otherwise = error "inconsistent matrix dimensions"
97
98---------------------------------------------------
99
100instance (Eq a, Field a) => Eq (Matrix a) where
101 a == b = rows a == rows b && cdat a == cdat b
102
103instance Num (Matrix Double) where
104 (+) = liftMatrix2' (+)
105 (-) = liftMatrix2' (-)
106 (*) = liftMatrix2' (*)
107 abs = liftMatrix abs
108 signum = liftMatrix signum
109 fromInteger n = fromLists [[fromInteger n]]
110
111----------------------------------------------------
112
113instance Num (Matrix (Complex Double)) where
114 (+) = liftMatrix2' (+)
115 (-) = liftMatrix2' (-)
116 (*) = liftMatrix2' (*)
117 abs = liftMatrix abs
118 signum = liftMatrix signum
119 fromInteger n = fromLists [[fromInteger n]]
120
121------------------------------------------------------
122
123instance Fractional (Vector Double) where
124 fromRational n = fromList [fromRational n]
125 (/) = adaptScalar f (vectorZip 2) g where
126 r `f` v = vectorZip 2 (constant r (dim v)) v
127 v `g` r = scale (recip r) v
128
129-------------------------------------------------------
130
131instance Fractional (Vector (Complex Double)) where
132 fromRational n = fromList [fromRational n]
133 (/) = adaptScalar f (vectorZipComplex 2) g where
134 r `f` v = gmap ((*r).recip) v
135 v `g` r = gmap (/r) v
136
137------------------------------------------------------
138
139instance Fractional (Matrix Double) where
140 fromRational n = fromLists [[fromRational n]]
141 (/) = liftMatrix2' (/)
142
143-------------------------------------------------------
144
145instance Fractional (Matrix (Complex Double)) where
146 fromRational n = fromLists [[fromRational n]]
147 (/) = liftMatrix2' (/)
148
149---------------------------------------------------------
150
151instance Floating (Vector Double) where
152 sin = vectorMap 0
153 cos = vectorMap 1
154 tan = vectorMap 2
155 asin = vectorMap 4
156 acos = vectorMap 5
157 atan = vectorMap 6
158 sinh = vectorMap 7
159 cosh = vectorMap 8
160 tanh = vectorMap 9
161 asinh = vectorMap 10
162 acosh = vectorMap 11
163 atanh = vectorMap 12
164 exp = vectorMap 13
165 log = vectorMap 14
166 sqrt = vectorMap 16
167 (**) = adaptScalar f (vectorZip 5) g where f s v = constant s (dim v) ** v
168 g v s = v ** constant s (dim v)
169 pi = fromList [pi]
170
171-----------------------------------------------------------
172
173instance Floating (Matrix Double) where
174 sin = liftMatrix sin
175 cos = liftMatrix cos
176 tan = liftMatrix tan
177 asin = liftMatrix asin
178 acos = liftMatrix acos
179 atan = liftMatrix atan
180 sinh = liftMatrix sinh
181 cosh = liftMatrix cosh
182 tanh = liftMatrix tanh
183 asinh = liftMatrix asinh
184 acosh = liftMatrix acosh
185 atanh = liftMatrix atanh
186 exp = liftMatrix exp
187 log = liftMatrix log
188 sqrt = liftMatrix sqrt
189 (**) = liftMatrix2 (**)
190 pi = fromLists [[pi]]
191
192-------------------------------------------------------------
193
194instance Floating (Vector (Complex Double)) where
195 sin = vectorMapComplex 0
196 cos = vectorMapComplex 1
197 tan = vectorMapComplex 2
198 asin = vectorMapComplex 4
199 acos = vectorMapComplex 5
200 atan = vectorMapComplex 6
201 sinh = vectorMapComplex 7
202 cosh = vectorMapComplex 8
203 tanh = vectorMapComplex 9
204 asinh = vectorMapComplex 10
205 acosh = vectorMapComplex 11
206 atanh = vectorMapComplex 12
207 exp = vectorMapComplex 13
208 log = vectorMapComplex 14
209 sqrt = vectorMapComplex 16
210 (**) = adaptScalar f (vectorZipComplex 5) g where f s v = constantC s (dim v) ** v
211 g v s = v ** constantC s (dim v)
212 pi = fromList [pi]
213
214---------------------------------------------------------------
215
216instance Floating (Matrix (Complex Double)) where
217 sin = liftMatrix sin
218 cos = liftMatrix cos
219 tan = liftMatrix tan
220 asin = liftMatrix asin
221 acos = liftMatrix acos
222 atan = liftMatrix atan
223 sinh = liftMatrix sinh
224 cosh = liftMatrix cosh
225 tanh = liftMatrix tanh
226 asinh = liftMatrix asinh
227 acosh = liftMatrix acosh
228 atanh = liftMatrix atanh
229 exp = liftMatrix exp
230 log = liftMatrix log
231 (**) = liftMatrix2 (**)
232 sqrt = liftMatrix sqrt
233 pi = fromLists [[pi]]
234
235---------------------------------------------------------------
236-}
237
238class Contractible a b c | a b -> c where
239 infixl 7 <>
240{- | An overloaded operator for matrix products, matrix-vector and vector-matrix products, dot products and scaling of vectors and matrices. Type consistency is statically checked. Alternatively, you can use the specific functions described below, but using this operator you can automatically combine real and complex objects.
241
242@v = 'fromList' [1,2,3] :: Vector Double
243cv = 'fromList' [1+'i',2]
244m = 'fromLists' [[1,2,3],
245 [4,5,7]] :: Matrix Double
246cm = 'fromLists' [[ 1, 2],
247 [3+'i',7*'i'],
248 [ 'i', 1]]
249\
250\> m \<\> v
25114. 35.
252\
253\> cv \<\> m
2549.+1.i 12.+2.i 17.+3.i
255\
256\> m \<\> cm
257 7.+5.i 5.+14.i
25819.+12.i 15.+35.i
259\
260\> v \<\> 'i'
2611.i 2.i 3.i
262\
263\> v \<\> v
26414.0
265\
266\> cv \<\> cv
2674.0 :+ 2.0@
268
269-}
270 (<>) :: a -> b -> c
271
272
273instance Contractible Double Double Double where
274 (<>) = (*)
275
276instance Contractible Double (Complex Double) (Complex Double) where
277 a <> b = (a:+0) * b
278
279instance Contractible (Complex Double) Double (Complex Double) where
280 a <> b = a * (b:+0)
281
282instance Contractible (Complex Double) (Complex Double) (Complex Double) where
283 (<>) = (*)
284
285--------------------------------- matrix matrix
286
287instance Contractible (Matrix Double) (Matrix Double) (Matrix Double) where
288 (<>) = mXm
289
290instance Contractible (Matrix (Complex Double)) (Matrix (Complex Double)) (Matrix (Complex Double)) where
291 (<>) = mXm
292
293instance Contractible (Matrix (Complex Double)) (Matrix Double) (Matrix (Complex Double)) where
294 c <> r = c <> liftMatrix comp r
295
296instance Contractible (Matrix Double) (Matrix (Complex Double)) (Matrix (Complex Double)) where
297 r <> c = liftMatrix comp r <> c
298
299--------------------------------- (Matrix Double) (Vector Double)
300
301instance Contractible (Matrix Double) (Vector Double) (Vector Double) where
302 (<>) = mXv
303
304instance Contractible (Matrix (Complex Double)) (Vector (Complex Double)) (Vector (Complex Double)) where
305 (<>) = mXv
306
307instance Contractible (Matrix (Complex Double)) (Vector Double) (Vector (Complex Double)) where
308 m <> v = m <> comp v
309
310instance Contractible (Matrix Double) (Vector (Complex Double)) (Vector (Complex Double)) where
311 m <> v = liftMatrix comp m <> v
312
313--------------------------------- (Vector Double) (Matrix Double)
314
315instance Contractible (Vector Double) (Matrix Double) (Vector Double) where
316 (<>) = vXm
317
318instance Contractible (Vector (Complex Double)) (Matrix (Complex Double)) (Vector (Complex Double)) where
319 (<>) = vXm
320
321instance Contractible (Vector (Complex Double)) (Matrix Double) (Vector (Complex Double)) where
322 v <> m = v <> liftMatrix comp m
323
324instance Contractible (Vector Double) (Matrix (Complex Double)) (Vector (Complex Double)) where
325 v <> m = comp v <> m
326
327--------------------------------- dot product
328
329instance Contractible (Vector Double) (Vector Double) Double where
330 (<>) = dot
331
332instance Contractible (Vector (Complex Double)) (Vector (Complex Double)) (Complex Double) where
333 (<>) = dot
334
335instance Contractible (Vector Double) (Vector (Complex Double)) (Complex Double) where
336 a <> b = comp a <> b
337
338instance Contractible (Vector (Complex Double)) (Vector Double) (Complex Double) where
339 (<>) = flip (<>)
340
341--------------------------------- scaling vectors
342
343instance Contractible Double (Vector Double) (Vector Double) where
344 (<>) = scale
345
346instance Contractible (Vector Double) Double (Vector Double) where
347 (<>) = flip (<>)
348
349instance Contractible (Complex Double) (Vector (Complex Double)) (Vector (Complex Double)) where
350 (<>) = scale
351
352instance Contractible (Vector (Complex Double)) (Complex Double) (Vector (Complex Double)) where
353 (<>) = flip (<>)
354
355instance Contractible Double (Vector (Complex Double)) (Vector (Complex Double)) where
356 a <> v = (a:+0) <> v
357
358instance Contractible (Vector (Complex Double)) Double (Vector (Complex Double)) where
359 (<>) = flip (<>)
360
361instance Contractible (Complex Double) (Vector Double) (Vector (Complex Double)) where
362 a <> v = a <> comp v
363
364instance Contractible (Vector Double) (Complex Double) (Vector (Complex Double)) where
365 (<>) = flip (<>)
366
367--------------------------------- scaling matrices
368
369instance Contractible Double (Matrix Double) (Matrix Double) where
370 (<>) a = liftMatrix (a <>)
371
372instance Contractible (Matrix Double) Double (Matrix Double) where
373 (<>) = flip (<>)
374
375instance Contractible (Complex Double) (Matrix (Complex Double)) (Matrix (Complex Double)) where
376 (<>) a = liftMatrix (a <>)
377
378instance Contractible (Matrix (Complex Double)) (Complex Double) (Matrix (Complex Double)) where
379 (<>) = flip (<>)
380
381instance Contractible Double (Matrix (Complex Double)) (Matrix (Complex Double)) where
382 a <> m = (a:+0) <> m
383
384instance Contractible (Matrix (Complex Double)) Double (Matrix (Complex Double)) where
385 (<>) = flip (<>)
386
387instance Contractible (Complex Double) (Matrix Double) (Matrix (Complex Double)) where
388 a <> m = a <> liftMatrix comp m
389
390instance Contractible (Matrix Double) (Complex Double) (Matrix (Complex Double)) where
391 (<>) = flip (<>)
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 9309d1d..dd33943 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -93,6 +93,15 @@ createMatrix order r c = do
93 p <- createVector (r*c) 93 p <- createVector (r*c)
94 return (matrixFromVector order c p) 94 return (matrixFromVector order c p)
95 95
96{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns.
97
98@\> reshape 4 ('fromList' [1..12])
99(3><4)
100 [ 1.0, 2.0, 3.0, 4.0
101 , 5.0, 6.0, 7.0, 8.0
102 , 9.0, 10.0, 11.0, 12.0 ]@
103
104-}
96reshape :: (Field t) => Int -> Vector t -> Matrix t 105reshape :: (Field t) => Int -> Vector t -> Matrix t
97reshape c v = matrixFromVector RowMajor c v 106reshape c v = matrixFromVector RowMajor c v
98 107
@@ -140,7 +149,6 @@ liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes
140 149
141liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 150liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
142liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes 151liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes
143
144------------------------------------------------------------------ 152------------------------------------------------------------------
145 153
146dotL a b = sum (zipWith (*) a b) 154dotL a b = sum (zipWith (*) a b)
@@ -200,6 +208,14 @@ multiplyD order a b
200 208
201outer' u v = dat (outer u v) 209outer' u v = dat (outer u v)
202 210
211{- | Outer product of two vectors.
212
213@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
214(3><3)
215 [ 5.0, 2.0, 3.0
216 , 10.0, 4.0, 6.0
217 , 15.0, 6.0, 9.0 ]@
218-}
203outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t 219outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t
204outer u v = multiply RowMajor r c 220outer u v = multiply RowMajor r c
205 where r = matrixFromVector RowMajor 1 u 221 where r = matrixFromVector RowMajor 1 u
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 25e848d..f1addf4 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -48,7 +48,7 @@ fromList l = unsafePerformIO $ do
48toList :: Storable a => Vector a -> [a] 48toList :: Storable a => Vector a -> [a]
49toList v = unsafePerformIO $ peekArray (dim v) (ptr v) 49toList v = unsafePerformIO $ peekArray (dim v) (ptr v)
50 50
51n # l = if length l == n then fromList l else error "# with wrong size" 51n |> l = if length l == n then fromList l else error "|> with wrong size"
52 52
53at' :: Storable a => Vector a -> Int -> a 53at' :: Storable a => Vector a -> Int -> a
54at' v n = unsafePerformIO $ peekElemOff (ptr v) n 54at' v n = unsafePerformIO $ peekElemOff (ptr v) n
@@ -58,7 +58,7 @@ at v n | n >= 0 && n < dim v = at' v n
58 | otherwise = error "vector index out of range" 58 | otherwise = error "vector index out of range"
59 59
60instance (Show a, Storable a) => (Show (Vector a)) where 60instance (Show a, Storable a) => (Show (Vector a)) where
61 show v = (show (dim v))++" # " ++ show (toList v) 61 show v = (show (dim v))++" |> " ++ show (toList v)
62 62
63-- | creates a Vector taking a number of consecutive toList from another Vector 63-- | creates a Vector taking a number of consecutive toList from another Vector
64subVector :: Storable t => Int -- ^ index of the starting element 64subVector :: Storable t => Int -- ^ index of the starting element
@@ -129,3 +129,5 @@ constant x n | isReal id x = scast $ constantR (scast x) n
129 | isComp id x = scast $ constantC (scast x) n 129 | isComp id x = scast $ constantC (scast x) n
130 | otherwise = constantG x n 130 | otherwise = constantG x n
131 131
132liftVector f = fromList . map f . toList
133liftVector2 f u v = fromList $ zipWith f (toList u) (toList v)
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index 0f9d998..36bf32e 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -13,11 +13,11 @@
13----------------------------------------------------------------------------- 13-----------------------------------------------------------------------------
14 14
15module Data.Packed.Matrix ( 15module Data.Packed.Matrix (
16 Matrix(rows,cols), Field, 16 Matrix(rows,cols),
17 fromLists, toLists, (><), (>|<), (@@>), 17 fromLists, toLists, (><), (>|<), (@@>),
18 trans, conjTrans, 18 trans, conjTrans,
19 reshape, flatten, 19 reshape, flatten, asRow, asColumn,
20 fromRows, toRows, fromColumns, toColumns, 20 fromRows, toRows, fromColumns, toColumns, fromBlocks,
21 joinVert, joinHoriz, 21 joinVert, joinHoriz,
22 flipud, fliprl, 22 flipud, fliprl,
23 liftMatrix, liftMatrix2, 23 liftMatrix, liftMatrix2,
@@ -43,6 +43,22 @@ joinVert ms = case common cols ms of
43joinHoriz :: Field t => [Matrix t] -> Matrix t 43joinHoriz :: Field t => [Matrix t] -> Matrix t
44joinHoriz ms = trans. joinVert . map trans $ ms 44joinHoriz ms = trans. joinVert . map trans $ ms
45 45
46{- | Creates a matrix from blocks given as a list of lists of matrices:
47
48@\> let a = 'diag' $ 'fromList' [5,7,2]
49\> let b = 'reshape' 4 $ 'constant' (-1) 12
50\> fromBlocks [[a,b],[b,a]]
51(6><7)
52 [ 5.0, 0.0, 0.0, -1.0, -1.0, -1.0, -1.0
53 , 0.0, 7.0, 0.0, -1.0, -1.0, -1.0, -1.0
54 , 0.0, 0.0, 2.0, -1.0, -1.0, -1.0, -1.0
55 , -1.0, -1.0, -1.0, -1.0, 5.0, 0.0, 0.0
56 , -1.0, -1.0, -1.0, -1.0, 0.0, 7.0, 0.0
57 , -1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 2.0 ]@
58-}
59fromBlocks :: Field t => [[Matrix t]] -> Matrix t
60fromBlocks = joinVert . map joinHoriz
61
46-- | Reverse rows 62-- | Reverse rows
47flipud :: Field t => Matrix t -> Matrix t 63flipud :: Field t => Matrix t -> Matrix t
48flipud m = fromRows . reverse . toRows $ m 64flipud m = fromRows . reverse . toRows $ m
@@ -98,6 +114,11 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat
98 114
99---------------------------------------------------------------- 115----------------------------------------------------------------
100 116
117{- | Creates a vector by concatenation of rows
118
119@\> flatten ('ident' 3)
1209 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@
121-}
101flatten :: Matrix t -> Vector t 122flatten :: Matrix t -> Vector t
102flatten = cdat 123flatten = cdat
103 124
@@ -106,4 +127,10 @@ fromLists :: Field t => [[t]] -> Matrix t
106fromLists = fromRows . map fromList 127fromLists = fromRows . map fromList
107 128
108conjTrans :: Matrix (Complex Double) -> Matrix (Complex Double) 129conjTrans :: Matrix (Complex Double) -> Matrix (Complex Double)
109conjTrans = trans . liftMatrix conj \ No newline at end of file 130conjTrans = trans . liftMatrix conj
131
132asRow :: Field a => Vector a -> Matrix a
133asRow v = reshape (dim v) v
134
135asColumn :: Field a => Vector a -> Matrix a
136asColumn v = reshape 1 v
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs
index 9d9d879..94f70be 100644
--- a/lib/Data/Packed/Vector.hs
+++ b/lib/Data/Packed/Vector.hs
@@ -15,7 +15,7 @@
15module Data.Packed.Vector ( 15module Data.Packed.Vector (
16 Vector(dim), Field, 16 Vector(dim), Field,
17 fromList, toList, 17 fromList, toList,
18 at, 18 (@>),
19 subVector, join, 19 subVector, join,
20 constant, 20 constant,
21 toComplex, comp, 21 toComplex, comp,
@@ -26,6 +26,7 @@ module Data.Packed.Vector (
26 26
27import Data.Packed.Internal 27import Data.Packed.Internal
28import Complex 28import Complex
29import GSL.Vector
29 30
30-- | creates a complex vector from vectors with real and imaginary parts 31-- | creates a complex vector from vectors with real and imaginary parts
31toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) 32toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double)
@@ -41,10 +42,14 @@ comp v = toComplex (v,constant 0 (dim v))
41 42
42{- | Creates a real vector containing a range of values: 43{- | Creates a real vector containing a range of values:
43 44
44> > linspace 10 (-2,2) 45@\> linspace 5 (-3,7)
45>-2. -1.556 -1.111 -0.667 -0.222 0.222 0.667 1.111 1.556 2. 465 |> [-3.0,-0.5,2.0,4.5,7.0]@
46
47-} 47-}
48linspace :: Int -> (Double, Double) -> Vector Double 48linspace :: Int -> (Double, Double) -> Vector Double
49linspace n (a,b) = fromList [a::Double,a+delta .. b] 49linspace n (a,b) = fromList [a::Double,a+delta .. b]
50 where delta = (b-a)/(fromIntegral n -1) 50 where delta = (b-a)/(fromIntegral n -1)
51
52-- | Reads a vector position.
53(@>) :: Field t => Vector t -> Int -> t
54infixl 9 @>
55(@>) = at
diff --git a/lib/GSL.hs b/lib/GSL.hs
new file mode 100644
index 0000000..8e033c3
--- /dev/null
+++ b/lib/GSL.hs
@@ -0,0 +1,45 @@
1{- |
2
3Module : GSL
4Copyright : (c) Alberto Ruiz 2006-7
5License : GPL-style
6
7Maintainer : Alberto Ruiz (aruiz at um dot es)
8Stability : provisional
9Portability : uses -fffi and -fglasgow-exts
10
11This module reexports the basic functionality and a collection of utilities (old interface)
12
13-}
14
15module GSL (
16
17module Data.Packed.Vector,
18module Data.Packed.Matrix,
19module Data.Packed.Tensor,
20module Data.Packed.Instances,
21module LinearAlgebra.Algorithms,
22module LAPACK,
23module GSL.Integration,
24module GSL.Differentiation,
25module GSL.Special,
26module GSL.Fourier,
27module GSL.Polynomials,
28module GSL.Minimization,
29module Data.Packed.Plot
30
31) where
32
33import Data.Packed.Vector
34import Data.Packed.Matrix
35import Data.Packed.Tensor
36import Data.Packed.Instances
37import LinearAlgebra.Algorithms
38import LAPACK
39import GSL.Integration
40import GSL.Differentiation
41import GSL.Special
42import GSL.Fourier
43import GSL.Polynomials
44import GSL.Minimization
45import Data.Packed.Plot
diff --git a/lib/GSL/Matrix.hs b/lib/GSL/Matrix.hs
new file mode 100644
index 0000000..919c2d9
--- /dev/null
+++ b/lib/GSL/Matrix.hs
@@ -0,0 +1,311 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : GSL.Matrix
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable (uses FFI)
10--
11-- A few linear algebra computations based on the GSL (<http://www.gnu.org/software/gsl>).
12--
13-----------------------------------------------------------------------------
14
15module GSL.Matrix(
16 eigSg, eigHg,
17 svdg,
18 qr,
19 chol,
20 luSolveR, luSolveC,
21 luR, luC,
22 fromFile
23) where
24
25import Data.Packed.Internal
26import Data.Packed.Matrix(fromLists,ident,takeDiag)
27import GSL.Vector
28import Foreign
29import Foreign.C.Types
30import Complex
31import Foreign.C.String
32
33{- | eigendecomposition of a real symmetric matrix using /gsl_eigen_symmv/.
34
35> > let (l,v) = eigS $ 'fromLists' [[1,2],[2,1]]
36> > l
37> 3.000 -1.000
38>
39> > v
40> 0.707 -0.707
41> 0.707 0.707
42>
43> > v <> diag l <> trans v
44> 1.000 2.000
45> 2.000 1.000
46
47-}
48eigSg :: Matrix Double -> (Vector Double, Matrix Double)
49eigSg (m@M {rows = r})
50 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
51 | otherwise = unsafePerformIO $ do
52 l <- createVector r
53 v <- createMatrix RowMajor r r
54 c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m]
55 return (l,v)
56foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
57
58------------------------------------------------------------------
59
60
61
62{- | eigendecomposition of a complex hermitian matrix using /gsl_eigen_hermv/
63
64> > let (l,v) = eigH $ 'fromLists' [[1,2+i],[2-i,3]]
65>
66> > l
67> 4.449 -0.449
68>
69> > v
70> -0.544 0.839
71> (-0.751,0.375) (-0.487,0.243)
72>
73> > v <> diag l <> (conjTrans) v
74> 1.000 (2.000,1.000)
75> (2.000,-1.000) 3.000
76
77-}
78eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double))
79eigHg (m@M {rows = r})
80 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1)
81 | otherwise = unsafePerformIO $ do
82 l <- createVector r
83 v <- createMatrix RowMajor r r
84 c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m]
85 return (l,v)
86foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
87
88
89{- | Singular value decomposition of a real matrix, using /gsl_linalg_SV_decomp_mod/:
90
91@\> let (u,s,v) = svdg $ 'fromLists' [[1,2,3],[-4,1,7]]
92\
93\> u
940.310 -0.951
950.951 0.310
96\
97\> s
988.497 2.792
99\
100\> v
101-0.411 -0.785
102 0.185 -0.570
103 0.893 -0.243
104\
105\> u \<\> 'diag' s \<\> 'trans' v
106 1. 2. 3.
107-4. 1. 7.@
108
109-}
110svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
111svdg x@M {rows = r, cols = c} = if r>=c
112 then svd' x
113 else (v, s, u) where (u,s,v) = svd' (trans x)
114
115svd' x@M {rows = r, cols = c} = unsafePerformIO $ do
116 u <- createMatrix RowMajor r c
117 s <- createVector c
118 v <- createMatrix RowMajor c c
119 c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x]
120 return (u,s,v)
121foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM
122
123{- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/.
124
125@\> let (q,r) = qr $ 'fromLists' [[1,3,5,7],[2,0,-2,4]]
126\
127\> q
128-0.447 -0.894
129-0.894 0.447
130\
131\> r
132-2.236 -1.342 -0.447 -6.708
133 0. -2.683 -5.367 -4.472
134\
135\> q \<\> r
1361.000 3.000 5.000 7.000
1372.000 0. -2.000 4.000@
138
139-}
140qr :: Matrix Double -> (Matrix Double, Matrix Double)
141qr x@M {rows = r, cols = c} = unsafePerformIO $ do
142 q <- createMatrix RowMajor r r
143 rot <- createMatrix RowMajor r c
144 c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x]
145 return (q,rot)
146foreign import ccall "gsl-aux.h QR" c_qr :: TMMM
147
148{- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/.
149
150@\> let c = chol $ 'fromLists' [[5,4],[4,5]]
151\
152\> c
1532.236 0.
1541.789 1.342
155\
156\> c \<\> 'trans' c
1575.000 4.000
1584.000 5.000@
159
160-}
161chol :: Matrix Double -> Matrix Double
162--chol x@(M r _ p) = createM [p] "chol" r r $ m c_chol x
163chol x@M {rows = r} = unsafePerformIO $ do
164 res <- createMatrix RowMajor r r
165 c_chol // mat cdat x // mat dat res // check "chol" [cdat x]
166 return res
167foreign import ccall "gsl-aux.h chol" c_chol :: TMM
168
169--------------------------------------------------------
170
171{- -| efficient multiplication by the inverse of a matrix (for real matrices)
172-}
173luSolveR :: Matrix Double -> Matrix Double -> Matrix Double
174luSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c})
175 | n1==n2 && n1==r = unsafePerformIO $ do
176 s <- createMatrix RowMajor r c
177 c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b]
178 return s
179 | otherwise = error "luSolveR of nonsquare matrix"
180
181foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM
182
183{- -| efficient multiplication by the inverse of a matrix (for complex matrices).
184-}
185luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
186luSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c})
187 | n1==n2 && n1==r = unsafePerformIO $ do
188 s <- createMatrix RowMajor r c
189 c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b]
190 return s
191 | otherwise = error "luSolveC of nonsquare matrix"
192
193foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM
194
195{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign)
196-}
197luRaux :: Matrix Double -> Vector Double
198luRaux x@M {rows = r, cols = c} = unsafePerformIO $ do
199 res <- createVector (r*r+r+1)
200 c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x]
201 return res
202foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV
203
204{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign)
205-}
206luCaux :: Matrix (Complex Double) -> Vector (Complex Double)
207luCaux x@M {rows = r, cols = c} = unsafePerformIO $ do
208 res <- createVector (r*r+r+1)
209 c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x]
210 return res
211foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV
212
213{- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>.
214
215@\> let m = 'fromLists' [[1,2,-3],[2+3*i,-7,0],[1,-i,2*i]]
216\> let (l,u,p,s) = luR m@
217
218L is the lower triangular:
219
220@\> l
221 1. 0. 0.
2220.154-0.231i 1. 0.
2230.154-0.231i 0.624-0.522i 1.@
224
225U is the upper triangular:
226
227@\> u
2282.+3.i -7. 0.
229 0. 3.077-1.615i -3.
230 0. 0. 1.873+0.433i@
231
232p is a permutation:
233
234@\> p
235[1,0,2]@
236
237L \* U obtains a permuted version of the original matrix:
238
239@\> 'extractRows' p m
240 2.+3.i -7. 0.
241 1. 2. -3.
242 1. -1.i 2.i
243\
244\> l \<\> u
245 2.+3.i -7. 0.
246 1. 2. -3.
247 1. -1.i 2.i@
248
249s is the sign of the permutation, required to obtain sign of the determinant:
250
251@\> s * product ('toList' $ 'takeDiag' u)
252(-18.0) :+ (-16.000000000000004)
253\> 'LinearAlgebra.Algorithms.det' m
254(-18.0) :+ (-16.000000000000004)@
255
256 -}
257luR :: Matrix Double -> (Matrix Double, Matrix Double, [Int], Double)
258luR m = (l,u,p, fromIntegral s') where
259 r = rows m
260 v = luRaux m
261 lu = reshape r $ subVector 0 (r*r) v
262 s':p = map round . toList . subVector (r*r) (r+1) $ v
263 u = triang r r 0 1`mul` lu
264 l = (triang r r 0 0 `mul` lu) `add` ident r
265 add = liftMatrix2 $ vectorZipR Add
266 mul = liftMatrix2 $ vectorZipR Mul
267
268-- | Complex version of 'luR'.
269luC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double), [Int], Complex Double)
270luC m = (l,u,p, fromIntegral s') where
271 r = rows m
272 v = luCaux m
273 lu = reshape r $ subVector 0 (r*r) v
274 s':p = map (round.realPart) . toList . subVector (r*r) (r+1) $ v
275 u = triang r r 0 1 `mul` lu
276 l = (triang r r 0 0 `mul` lu) `add` ident r
277 add = liftMatrix2 $ vectorZipC Add
278 mul = liftMatrix2 $ vectorZipC Mul
279
280extract l is = [l!!i |i<-is]
281
282{- auxiliary function to get triangular matrices
283-}
284triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]]
285 where el i j = if j-i>=h then v else 1 - v
286
287{- | rearranges the rows of a matrix according to the order given in a list of integers.
288
289> > extractRows [3,3,0,1] (ident 4)
290> 0. 0. 0. 1.
291> 0. 0. 0. 1.
292> 1. 0. 0. 0.
293> 0. 1. 0. 0.
294
295-}
296extractRows :: Field t => [Int] -> Matrix t -> Matrix t
297extractRows l m = fromRows $ extract (toRows $ m) l
298
299--------------------------------------------------------------
300
301-- | loads a matrix efficiently from formatted ASCII text file (the number of rows and columns must be known in advance).
302fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double)
303fromFile filename (r,c) = do
304 charname <- newCString filename
305 res <- createMatrix RowMajor r c
306 c_gslReadMatrix charname // mat dat res // check "gslReadMatrix" []
307 --free charname -- TO DO: free the auxiliary CString
308 return res
309foreign import ccall "gsl-aux.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM
310
311---------------------------------------------------------------------------
diff --git a/lib/GSL/Vector.hs b/lib/GSL/Vector.hs
index a772b34..a074254 100644
--- a/lib/GSL/Vector.hs
+++ b/lib/GSL/Vector.hs
@@ -18,7 +18,7 @@ module GSL.Vector (
18 FunCodeV(..), vectorMapR, vectorMapC, 18 FunCodeV(..), vectorMapR, vectorMapC,
19 FunCodeSV(..), vectorMapValR, vectorMapValC, 19 FunCodeSV(..), vectorMapValR, vectorMapValC,
20 FunCodeVV(..), vectorZipR, vectorZipC, 20 FunCodeVV(..), vectorZipR, vectorZipC,
21 scale, addConstant, add, mul, 21 scale, addConstant, add, sub, mul,
22) where 22) where
23 23
24import Data.Packed.Internal 24import Data.Packed.Internal
@@ -84,6 +84,11 @@ add u v | isReal baseOf v = scast $ vectorZipR Add (scast u) (scast v)
84 | isComp baseOf v = scast $ vectorZipC Add (scast u) (scast v) 84 | isComp baseOf v = scast $ vectorZipC Add (scast u) (scast v)
85 | otherwise = fromList $ zipWith (+) (toList u) (toList v) 85 | otherwise = fromList $ zipWith (+) (toList u) (toList v)
86 86
87sub :: (Num a, Field a) => Vector a -> Vector a -> Vector a
88sub u v | isReal baseOf v = scast $ vectorZipR Sub (scast u) (scast v)
89 | isComp baseOf v = scast $ vectorZipC Sub (scast u) (scast v)
90 | otherwise = fromList $ zipWith (-) (toList u) (toList v)
91
87mul :: (Num a, Field a) => Vector a -> Vector a -> Vector a 92mul :: (Num a, Field a) => Vector a -> Vector a -> Vector a
88mul u v | isReal baseOf v = scast $ vectorZipR Mul (scast u) (scast v) 93mul u v | isReal baseOf v = scast $ vectorZipR Mul (scast u) (scast v)
89 | isComp baseOf v = scast $ vectorZipC Mul (scast u) (scast v) 94 | isComp baseOf v = scast $ vectorZipC Mul (scast u) (scast v)
diff --git a/lib/LinearAlgebra.hs b/lib/LinearAlgebra.hs
index 474f9a6..b0c8b9d 100644
--- a/lib/LinearAlgebra.hs
+++ b/lib/LinearAlgebra.hs
@@ -15,3 +15,4 @@ Some linear algebra algorithms, implemented by means of BLAS, LAPACK or GSL.
15module LinearAlgebra ( 15module LinearAlgebra (
16 16
17) where 17) where
18
diff --git a/lib/LinearAlgebra/Algorithms.hs b/lib/LinearAlgebra/Algorithms.hs
index 680612f..126549a 100644
--- a/lib/LinearAlgebra/Algorithms.hs
+++ b/lib/LinearAlgebra/Algorithms.hs
@@ -1,3 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts #-}
1----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
2{- | 3{- |
3Module : LinearAlgebra.Algorithms 4Module : LinearAlgebra.Algorithms
@@ -13,5 +14,229 @@ Portability : uses ffi
13----------------------------------------------------------------------------- 14-----------------------------------------------------------------------------
14 15
15module LinearAlgebra.Algorithms ( 16module LinearAlgebra.Algorithms (
16 17 mXm, mXv, vXm,
18 inv,
19 pinv,
20 pinvTol,
21 pinvTolg,
22 Normed(..), NormType(..),
23 det,
24 eps, i
17) where 25) where
26
27
28import Data.Packed.Internal
29import Data.Packed.Matrix
30import GSL.Matrix
31import GSL.Vector
32import LAPACK
33import Complex
34
35{- | Machine precision of a Double.
36
37>> eps
38> 2.22044604925031e-16
39
40(The value used by GNU-Octave)
41
42-}
43eps :: Double
44eps = 2.22044604925031e-16
45
46{- | The imaginary unit
47
48@> 'ident' 3 \<\> i
491.i 0. 0.
50 0. 1.i 0.
51 0. 0. 1.i@
52
53-}
54i :: Complex Double
55i = 0:+1
56
57
58-- | matrix product
59mXm :: (Num t, Field t) => Matrix t -> Matrix t -> Matrix t
60mXm = multiply RowMajor
61
62-- | matrix - vector product
63mXv :: (Num t, Field t) => Matrix t -> Vector t -> Vector t
64mXv m v = flatten $ m `mXm` (asColumn v)
65
66-- | vector - matrix product
67vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t
68vXm v m = flatten $ (asRow v) `mXm` m
69
70
71
72-- | Pseudoinverse of a real matrix
73--
74-- @dispR 3 $ pinv (fromLists [[1,2],
75-- [3,4],
76-- [5,6]])
77--matrix (2x3)
78-- -1.333 | -0.333 | 0.667
79-- 1.083 | 0.333 | -0.417@
80--
81
82pinv :: Matrix Double -> Matrix Double
83pinv m = pinvTol 1 m
84--pinv m = linearSolveSVDR Nothing m (ident (rows m))
85
86{- -| Pseudoinverse of a real matrix with the default tolerance used by GNU-Octave: the singular values less than max (rows, colums) * greatest singular value * 'eps' are ignored. See 'pinvTol'.
87
88@\> let m = 'fromLists' [[ 1, 2]
89 ,[ 5, 8]
90 ,[10,-5]]
91\> pinv m
929.353e-3 4.539e-2 7.637e-2
932.231e-2 8.993e-2 -4.719e-2
94\
95\> m \<\> pinv m \<\> m
96 1. 2.
97 5. 8.
9810. -5.@
99
100-}
101--pinvg :: Matrix Double -> Matrix Double
102pinvg m = pinvTolg 1 m
103
104{- | Pseudoinverse of a real matrix with the desired tolerance, expressed as a
105multiplicative factor of the default tolerance used by GNU-Octave (see 'pinv').
106
107@\> let m = 'fromLists' [[1,0, 0]
108 ,[0,1, 0]
109 ,[0,0,1e-10]]
110\
111\> 'pinv' m
1121. 0. 0.
1130. 1. 0.
1140. 0. 10000000000.
115\
116\> pinvTol 1E8 m
1171. 0. 0.
1180. 1. 0.
1190. 0. 1.@
120
121-}
122pinvTol :: Double -> Matrix Double -> Matrix Double
123pinvTol t m = v' `mXm` diag s' `mXm` trans u' where
124 (u,s,v) = svdR' m
125 sl@(g:_) = toList s
126 s' = fromList . map rec $ sl
127 rec x = if x < g*tol then 1 else 1/x
128 tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps)
129 r = rows m
130 c = cols m
131 d = dim s
132 u' = takeColumns d u
133 v' = takeColumns d v
134
135
136pinvTolg :: Double -> Matrix Double -> Matrix Double
137pinvTolg t m = v `mXm` diag s' `mXm` trans u where
138 (u,s,v) = svdg m
139 sl@(g:_) = toList s
140 s' = fromList . map rec $ sl
141 rec x = if x < g*tol then 1 else 1/x
142 tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps)
143
144
145
146{- | Inverse of a square matrix.
147
148inv m = 'linearSolveR' m ('ident' ('rows' m))
149
150@\>inv ('fromLists' [[1,4]
151 ,[0,2]])
1521. -2.
1530. 0.500@
154-}
155inv :: Matrix Double -> Matrix Double
156inv m = if rows m == cols m
157 then m `linearSolveR` ident (rows m)
158 else error "inv of nonsquare matrix"
159
160
161{- - | Shortcut for the 2-norm ('pnorm' 2)
162
163@ > norm $ 'hilb' 5
1641.5670506910982311
165@
166
167@\> norm $ 'fromList' [1,-1,'i',-'i']
1682.0@
169
170-}
171
172
173
174{- | Determinant of a square matrix, computed from the LU decomposition.
175
176@\> det ('fromLists' [[7,2],[3,8]])
17750.0@
178
179-}
180det :: Matrix Double -> Double
181det m = s * (product $ toList $ takeDiag $ u)
182 where (_,u,_,s) = luR m
183
184---------------------------------------------------------------------------
185
186norm2 :: Vector Double -> Double
187norm2 = toScalarR Norm2
188
189norm1 :: Vector Double -> Double
190norm1 = toScalarR AbsSum
191
192vectorMax :: Vector Double -> Double
193vectorMax = toScalarR Max
194vectorMin :: Vector Double -> Double
195vectorMin = toScalarR Min
196vectorMaxIndex :: Vector Double -> Int
197vectorMaxIndex = round . toScalarR MaxIdx
198vectorMinIndex :: Vector Double -> Int
199vectorMinIndex = round . toScalarR MinIdx
200
201data NormType = Infinity | PNorm1 | PNorm2 -- PNorm Int
202
203pnormRV PNorm2 = norm2
204pnormRV PNorm1 = norm1
205pnormRV Infinity = vectorMax . vectorMapR Abs
206--pnormRV _ = error "pnormRV not yet defined"
207
208pnormCV PNorm2 = norm2 . asReal
209pnormCV PNorm1 = norm1 . liftVector magnitude
210pnormCV Infinity = vectorMax . liftVector magnitude
211--pnormCV _ = error "pnormCV not yet defined"
212
213pnormRM PNorm2 m = head (toList s) where (_,s,_) = svdR' m
214pnormRM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (vectorMapR Abs) m
215pnormRM Infinity m = vectorMax $ liftMatrix (vectorMapR Abs) m `mXv` constant 1 (cols m)
216--pnormRM _ _ = error "p norm not yet defined"
217
218pnormCM PNorm2 m = head (toList s) where (_,s,_) = svdC' m
219pnormCM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (liftVector magnitude) m
220pnormCM Infinity m = vectorMax $ liftMatrix (liftVector magnitude) m `mXv` constant 1 (cols m)
221--pnormCM _ _ = error "p norm not yet defined"
222
223-- -- | computes the p-norm of a matrix or vector (with the same definitions as GNU-octave). pnorm 0 denotes \\inf-norm. See also 'norm'.
224--pnorm :: (Container t, Field a) => Int -> t a -> Double
225--pnorm = pnormG
226
227class Normed t where
228 pnorm :: NormType -> t -> Double
229 norm :: t -> Double
230 norm = pnorm PNorm2
231
232instance Normed (Vector Double) where
233 pnorm = pnormRV
234
235instance Normed (Vector (Complex Double)) where
236 pnorm = pnormCV
237
238instance Normed (Matrix Double) where
239 pnorm = pnormRM
240
241instance Normed (Matrix (Complex Double)) where
242 pnorm = pnormCM