summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-09-21 18:10:59 +0000
committerAlberto Ruiz <aruiz@um.es>2007-09-21 18:10:59 +0000
commitedfaf9e0d1dcfccc9015476510a23e8cf64288be (patch)
tree2b11f0a933a7cce2362aed26ac160312e6b9431a
parent6cafd2f26a89008cc0db02e70e39f92a50ec4b4d (diff)
algorithms refactoring
-rw-r--r--HSSL.cabal3
-rw-r--r--examples/oldtests.hs120
-rw-r--r--examples/tests.hs357
-rw-r--r--lib/Data/Packed/Internal/Vector.hs2
-rw-r--r--lib/GSL.hs2
-rw-r--r--lib/LAPACK.hs36
-rw-r--r--lib/LinearAlgebra/Algorithms.hs241
-rw-r--r--lib/LinearAlgebra/Linear.hs10
8 files changed, 137 insertions, 634 deletions
diff --git a/HSSL.cabal b/HSSL.cabal
index b841588..22ac658 100644
--- a/HSSL.cabal
+++ b/HSSL.cabal
@@ -14,7 +14,7 @@ tested-with: GHC ==6.6.1
14Build-Depends: base, haskell98 14Build-Depends: base, haskell98
15Extensions: ForeignFunctionInterface 15Extensions: ForeignFunctionInterface
16--ghc-options: -Wall 16--ghc-options: -Wall
17ghc-options: -O0 17ghc-options: -O
18hs-source-dirs: lib 18hs-source-dirs: lib
19Exposed-modules: Data.Packed.Internal, 19Exposed-modules: Data.Packed.Internal,
20 Data.Packed.Internal.Common, 20 Data.Packed.Internal.Common,
@@ -51,6 +51,7 @@ Exposed-modules: Data.Packed.Internal,
51 LinearAlgebra.Instances, 51 LinearAlgebra.Instances,
52 LinearAlgebra.Interface, 52 LinearAlgebra.Interface,
53 LinearAlgebra.Algorithms 53 LinearAlgebra.Algorithms
54-- , LinearAlgebra.Tests
54 , Graphics.Plot 55 , Graphics.Plot
55-- , GSLHaskell 56-- , GSLHaskell
56Other-modules: 57Other-modules:
diff --git a/examples/oldtests.hs b/examples/oldtests.hs
deleted file mode 100644
index 7d4701c..0000000
--- a/examples/oldtests.hs
+++ /dev/null
@@ -1,120 +0,0 @@
1import Test.HUnit
2import LinearAlgebra
3import GSL hiding (exp)
4import System.Random(randomRs,mkStdGen)
5
6realMatrix = fromLists :: [[Double]] -> Matrix Double
7realVector = fromList :: [Double] -> Vector Double
8
9
10
11infixl 2 =~=
12a =~= b = pnorm PNorm1 (flatten (a - b)) < 1E-6
13
14randomMatrix seed (n,m) = reshape m $ realVector $ take (n*m) $ randomRs (-100,100) $ mkStdGen seed
15
16randomMatrixC seed (n,m) = toComplex (randomMatrix seed (n,m), randomMatrix (seed+1) (n,m))
17
18besselTest = do
19 let (r,e) = bessel_J0_e 5.0
20 let expected = -0.17759677131433830434739701
21 assertBool "bessel_J0_e" ( abs (r-expected) < e )
22
23exponentialTest = do
24 let (v,e,err) = exp_e10_e 30.0
25 let expected = exp 30.0
26 assertBool "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 )
27
28disp m = putStrLn (format " " show m)
29
30ms = realMatrix [[1,2,3]
31 ,[-4,1,7]]
32
33ms' = randomMatrix 27 (50,100)
34
35ms'' = toComplex (randomMatrix 100 (50,100),randomMatrix 101 (50,100))
36
37fullsvdTest method mat msg = do
38 let (u,s,vt) = method mat
39 assertBool msg (u <> s <> trans vt =~= mat)
40
41svdg' m = (u, diag s, v) where (u,s,v) = svdg m
42
43full_svd_Rd = svdRdd
44
45--------------------------------------------------------------------
46
47mcu = toComplex (randomMatrix 33 (20,20),randomMatrix 34 (20,20))
48
49mcur = randomMatrix 35 (40,40)
50
51-- eigenvectors are columns
52eigTest method m msg = do
53 let (s,v) = method m
54 assertBool msg $ m <> v =~= v <> diag s
55
56bigmat = m + trans m where m = randomMatrix 18 (1000,1000)
57bigmatc = mc + conjTrans mc where mc = toComplex(m,m)
58 m = randomMatrix 19 (1000,1000)
59
60--------------------------------------------------------------------
61
62invTest msg m = do
63 assertBool msg $ m <> inv m =~= ident (rows m)
64
65invComplexTest msg m = do
66 assertBool msg $ m <> invC m =~= identC (rows m)
67
68invC m = linearSolveC m (identC (rows m))
69
70identC = comp . ident
71
72--------------------------------------------------------------------
73
74pinvTest f msg m = do
75 assertBool msg $ m <> f m <> m =~= m
76
77pinvC m = linearSolveLSC m (identC (rows m))
78
79pinvSVDR m = linearSolveSVDR Nothing m (ident (rows m))
80
81pinvSVDC m = linearSolveSVDC Nothing m (identC (rows m))
82
83--------------------------------------------------------------------
84
85
86tests = TestList [
87 TestCase $ besselTest
88 , TestCase $ exponentialTest
89 , TestCase $ invTest "inv 100x100" (randomMatrix 18 (100,100))
90 , TestCase $ invComplexTest "complex inv 100x100" (randomMatrixC 18 (100,100))
91 , TestCase $ pinvTest (pinvTolg 1) "pinvg 100x50" (randomMatrix 18 (100,50))
92 , TestCase $ pinvTest pinv "pinv 100x50" (randomMatrix 18 (100,50))
93 , TestCase $ pinvTest pinv "pinv 50x100" (randomMatrix 18 (50,100))
94 , TestCase $ pinvTest pinvSVDR "pinvSVDR 100x50" (randomMatrix 18 (100,50))
95 , TestCase $ pinvTest pinvSVDR "pinvSVDR 50x100" (randomMatrix 18 (50,100))
96 , TestCase $ pinvTest pinvC "pinvC 100x50" (randomMatrixC 18 (100,50))
97 , TestCase $ pinvTest pinvC "pinvC 50x100" (randomMatrixC 18 (50,100))
98 , TestCase $ pinvTest pinvSVDC "pinvSVDC 100x50" (randomMatrixC 18 (100,50))
99 , TestCase $ pinvTest pinvSVDC "pinvSVDC 50x100" (randomMatrixC 18 (50,100))
100 , TestCase $ eigTest eigC mcu "eigC"
101 , TestCase $ eigTest eigR mcur "eigR"
102 , TestCase $ eigTest eigS (mcur+trans mcur) "eigS"
103 , TestCase $ eigTest eigSg (mcur+trans mcur) "eigSg"
104 , TestCase $ eigTest eigH (mcu+ (conjTrans) mcu) "eigH"
105 , TestCase $ eigTest eigHg (mcu+ (conjTrans) mcu) "eigHg"
106 , TestCase $ fullsvdTest svdg' ms "GSL svd small"
107 , TestCase $ fullsvdTest svdR ms "fullsvdR small"
108 , TestCase $ fullsvdTest svdR (trans ms) "fullsvdR small"
109 , TestCase $ fullsvdTest svdR ms' "fullsvdR"
110 , TestCase $ fullsvdTest svdR (trans ms') "fullsvdR"
111 , TestCase $ fullsvdTest full_svd_Rd ms' "fullsvdRd"
112 , TestCase $ fullsvdTest full_svd_Rd (trans ms') "fullsvdRd"
113 , TestCase $ fullsvdTest svdC ms'' "fullsvdC"
114 , TestCase $ fullsvdTest svdC (trans ms'') "fullsvdC"
115 , TestCase $ eigTest eigS bigmat "big eigS"
116 , TestCase $ eigTest eigH bigmatc "big eigH"
117 , TestCase $ eigTest eigR bigmat "big eigR"
118 ]
119
120main = runTestTT tests
diff --git a/examples/tests.hs b/examples/tests.hs
deleted file mode 100644
index dcc3cbf..0000000
--- a/examples/tests.hs
+++ /dev/null
@@ -1,357 +0,0 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2
3--
4-- QuickCheck tests
5--
6
7-----------------------------------------------------------------------------
8
9import Data.Packed.Internal((>|<), fdat, cdat, multiply', multiplyG, MatrixOrder(..))
10import GSL hiding (sin,cos,exp,choose)
11import LinearAlgebra hiding ((<>))
12import Test.QuickCheck
13import Test.HUnit hiding ((~:))
14
15
16dist :: (Normed t, Num t) => t -> t -> Double
17dist a b = pnorm Infinity (a-b)
18
19infixl 4 |~|
20a |~| b = a :~8~: b
21
22data Aprox a = (:~) a Int
23
24(~:) :: (Normed a, Num a) => Aprox a -> a -> Bool
25a :~n~: b = dist a b < 10^^(-n)
26
27
28{-
29-- Bravo por quickCheck!
30
31pinvProp1 tol m = (rank m == cols m) ==> pinv m <> m ~~ ident (cols m)
32 where infix 2 ~~
33 (~~) = approxEqual tol
34
35pinvProp2 tol m = 0 < r && r <= c ==> (r==c) `trivial` (m <> pinv m <> m ~~ m)
36 where r = rank m
37 c = cols m
38 infix 2 ~~
39 (~~) = approxEqual tol
40
41nullspaceProp tol m = cr > 0 ==> m <> nt ~~ zeros
42 where nt = trans (nullspace m)
43 cr = corank m
44 r = rows m
45 zeros = create [r,cr] $ replicate (r*cr) 0
46
47-}
48
49ac = (2><3) [1 .. 6::Double]
50bc = (3><4) [7 .. 18::Double]
51
52mz = (2 >< 3) [1,2,3,4,5,6:+(1::Double)]
53
54af = (2>|<3) [1,4,2,5,3,6::Double]
55bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double]
56
57
58{-
59aprox fun a b = rows a == rows b &&
60 cols a == cols b &&
61 epsTol > aproxL fun (toList (t a)) (toList (t b))
62 where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat
63
64aproxL fun v1 v2 = sum (zipWith (\a b-> fun (a-b)) v1 v2) / fromIntegral (length v1)
65
66normVR a b = toScalarR AbsSum (vectorZipR Sub a b)
67
68a |~| b = rows a == rows b && cols a == cols b && epsTol > normVR (t a) (t b)
69 where t = if (order a == RowMajor) `xor` isTrans a then cdat else fdat
70
71(|~~|) = aprox magnitude
72
73v1 ~~ v2 = reshape 1 v1 |~~| reshape 1 v2
74
75u ~|~ v = normVR u v < epsTol
76-}
77
78epsTol = 1E-8::Double
79
80asFortran m = (rows m >|< cols m) $ toList (fdat m)
81asC m = (rows m >< cols m) $ toList (cdat m)
82
83mulC a b = multiply' RowMajor a b
84mulF a b = multiply' ColumnMajor a b
85
86identC = comp . ident
87
88infixl 7 <>
89a <> b = mulF a b
90
91cc = mulC ac bf
92cf = mulF af bc
93
94r = mulC cc (trans cf)
95
96rd = (2><2)
97 [ 27736.0, 65356.0
98 , 65356.0, 154006.0 ::Double]
99
100instance (Arbitrary a, RealFloat a) => Arbitrary (Complex a) where
101 arbitrary = do
102 r <- arbitrary
103 i <- arbitrary
104 return (r:+i)
105 coarbitrary = undefined
106
107instance (Field a, Arbitrary a) => Arbitrary (Matrix a) where
108 arbitrary = do --m <- sized $ \max -> choose (1,1+3*max)
109 m <- choose (1,10)
110 n <- choose (1,10)
111 l <- vector (m*n)
112 ctype <- arbitrary
113 let h = if ctype then (m><n) else (m>|<n)
114 trMode <- arbitrary
115 let tr = if trMode then trans else id
116 return $ tr (h l)
117 coarbitrary = undefined
118
119data PairM a = PairM (Matrix a) (Matrix a) deriving Show
120instance (Num a, Field a, Arbitrary a) => Arbitrary (PairM a) where
121 arbitrary = do
122 a <- choose (1,10)
123 b <- choose (1,10)
124 c <- choose (1,10)
125 l1 <- vector (a*b)
126 l2 <- vector (b*c)
127 return $ PairM ((a><b) (map fromIntegral (l1::[Int]))) ((b><c) (map fromIntegral (l2::[Int])))
128 --return $ PairM ((a><b) l1) ((b><c) l2)
129 coarbitrary = undefined
130
131data SqM a = SqM (Matrix a) deriving Show
132instance (Field a, Arbitrary a) => Arbitrary (SqM a) where
133 arbitrary = do
134 n <- choose (1,10)
135 l <- vector (n*n)
136 return $ SqM $ (n><n) l
137 coarbitrary = undefined
138
139data Sym a = Sym (Matrix a) deriving Show
140instance (Linear Vector a, Arbitrary a) => Arbitrary (Sym a) where
141 arbitrary = do
142 SqM m <- arbitrary
143 return $ Sym (m + trans m)
144 coarbitrary = undefined
145
146data Her = Her (Matrix (Complex Double)) deriving Show
147instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where
148 arbitrary = do
149 SqM m <- arbitrary
150 return $ Her (m + conjTrans m)
151 coarbitrary = undefined
152
153data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show
154instance (Num a, Field a, Arbitrary a) => Arbitrary (PairSM a) where
155 arbitrary = do
156 a <- choose (1,10)
157 c <- choose (1,10)
158 l1 <- vector (a*a)
159 l2 <- vector (a*c)
160 return $ PairSM ((a><a) (map fromIntegral (l1::[Int]))) ((a><c) (map fromIntegral (l2::[Int])))
161 --return $ PairSM ((a><a) l1) ((a><c) l2)
162 coarbitrary = undefined
163
164instance (Field a, Arbitrary a) => Arbitrary (Vector a) where
165 arbitrary = do --m <- sized $ \max -> choose (1,1+3*max)
166 m <- choose (1,100)
167 l <- vector m
168 return $ fromList l
169 coarbitrary = undefined
170
171data PairV a = PairV (Vector a) (Vector a)
172instance (Field a, Arbitrary a) => Arbitrary (PairV a) where
173 arbitrary = do --m <- sized $ \max -> choose (1,1+3*max)
174 m <- choose (1,100)
175 l1 <- vector m
176 l2 <- vector m
177 return $ PairV (fromList l1) (fromList l2)
178 coarbitrary = undefined
179
180
181
182type BaseType = Complex Double
183
184svdTestR fun m = u <> s <> trans v |~| m
185 && u <> trans u |~| ident (rows m)
186 && v <> trans v |~| ident (cols m)
187 where (u,s,v) = fun m
188
189
190svdTestC m = u <> s' <> (trans v) |~| m
191 && u <> conjTrans u |~| identC (rows m)
192 && v <> conjTrans v |~| identC (cols m)
193 where (u,s,v) = svdC m
194 s' = liftMatrix comp s
195
196--svdg' m = (u,s',v) where
197
198eigTestC (SqM m) = (m <> v) |~| (v <> diag s)
199 && takeDiag (conjTrans v <> v) |~| comp (constant 1 (rows m)) --normalized
200 where (s,v) = eigC m
201
202eigTestR (SqM m) = (liftMatrix comp m <> v) |~| (v <> diag s)
203 -- && takeDiag ((liftMatrix conj (trans v)) <> v) |~| constant 1 (rows m) --normalized ???
204 where (s,v) = eigR m
205
206eigTestS (Sym m) = (m <> v) |~| (v <> diag s)
207 && v <> trans v |~| ident (cols m)
208 where (s,v) = eigS m
209
210eigTestH (Her m) = (m <> v) |~| (v <> diag (comp s))
211 && v <> conjTrans v |~| identC (cols m)
212 where (s,v) = eigH m
213
214linearSolveSQTest fun singu (PairSM a b) = singu a || (a <> fun a b) |~| b
215
216prec = 1E-15
217
218singular fun m = s1 < prec || s2/s1 < prec
219 where (_,ss,v) = fun m
220 s = toList ss
221 s1 = maximum s
222 s2 = minimum s
223
224{-
225invTest msg m = do
226 assertBool msg $ m <> inv m =~= ident (rows m)
227
228invComplexTest msg m = do
229 assertBool msg $ m <> invC m =~= identC (rows m)
230
231invC m = linearSolveC m (identC (rows m))
232
233identC n = toComplex(ident n, (0::Double) <>ident n)
234-}
235
236--------------------------------------------------------------------
237
238pinvTest f m = (m <> f m <> m) |~| m
239
240pinvR m = linearSolveLSR m (ident (rows m))
241pinvC m = linearSolveLSC m (identC (rows m))
242
243pinvSVDR m = linearSolveSVDR Nothing m (ident (rows m))
244
245pinvSVDC m = linearSolveSVDC Nothing m (identC (rows m))
246
247--------------------------------------------------------------------
248
249polyEval cs x = foldr (\c ac->ac*x+c) 0 cs
250
251polySolveTest' p = length p <2 || last p == 0|| 1E-8 > maximum (map magnitude $ map (polyEval (map (:+0) p)) (polySolve p))
252
253
254polySolveTest = assertBool "polySolve" (polySolveTest' [1,2,3,4])
255
256---------------------------------------------------------------------
257
258quad f a b = fst $ integrateQAGS 1E-9 100 f a b
259
260-- A multiple integral can be easily defined using partial application
261quad2 f a b g1 g2 = quad h a b
262 where h x = quad (f x) (g1 x) (g2 x)
263
264volSphere r = 8 * quad2 (\x y -> sqrt (r*r-x*x-y*y))
265 0 r (const 0) (\x->sqrt (r*r-x*x))
266
267integrateTest = assertBool "integrate" (abs (volSphere 2.5 - 4/3*pi*2.5^3) < epsTol)
268
269
270---------------------------------------------------------------------
271
272arit1 u = sin u ^ 2 + cos u ^ 2 |~| 1
273 where _ = u :: Vector Double
274
275arit2 u = sin u ** 2 + cos u ** 2 |~| 1
276 where _ = u :: Vector Double
277
278arit3 u = cos u * tan u |~| sin u
279 where _ = u :: Vector Double
280
281arit4 u = (cos u * tan u) :~6~: sin u
282 where _ = u :: Vector (Complex Double)
283
284---------------------------------------------------------------------
285
286besselTest = do
287 let (r,e) = bessel_J0_e 5.0
288 let expected = -0.17759677131433830434739701
289 assertBool "bessel_J0_e" ( abs (r-expected) < e )
290
291exponentialTest = do
292 let (v,e,err) = exp_e10_e 30.0
293 let expected = exp 30.0
294 assertBool "exp_e10_e" ( abs (v*10^e - expected) < 4E-2 )
295
296gammaTest = do
297 assertBool "gamma" (gamma 5 == 24.0)
298
299tests = TestList
300 [ TestCase $ besselTest
301 , TestCase $ exponentialTest
302 , TestCase $ gammaTest
303 , TestCase $ polySolveTest
304 , TestCase $ integrateTest
305 ]
306
307----------------------------------------------------------------------
308
309main = do
310 putStrLn "--------- general -----"
311 quickCheck (\(Sym m) -> m == (trans m:: Matrix BaseType))
312 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType])
313
314 quickCheck $ \m -> m == asC (m :: Matrix BaseType)
315 quickCheck $ \m -> m == asFortran (m :: Matrix BaseType)
316 quickCheck $ \m -> m == (asC . asFortran) (m :: Matrix BaseType)
317 putStrLn "--------- MULTIPLY ----"
318 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == mulF m1 (m2 :: Matrix BaseType)
319 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == trans (mulF (trans m2) (trans m1 :: Matrix BaseType))
320 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 == multiplyG m1 (m2 :: Matrix BaseType)
321 putStrLn "--------- SVD ---------"
322 quickCheck (svdTestR svdR)
323 quickCheck (svdTestR svdRdd)
324-- quickCheck (svdTestR svdg)
325 quickCheck svdTestC
326 putStrLn "--------- EIG ---------"
327 quickCheck eigTestC
328 quickCheck eigTestR
329 quickCheck eigTestS
330 quickCheck eigTestH
331 putStrLn "--------- SOLVE ---------"
332 quickCheck (linearSolveSQTest linearSolveR (singular svdR'))
333 quickCheck (linearSolveSQTest linearSolveC (singular svdC'))
334 quickCheck (pinvTest pinvR)
335 quickCheck (pinvTest pinvC)
336 quickCheck (pinvTest pinvSVDR)
337 quickCheck (pinvTest pinvSVDC)
338 putStrLn "--------- VEC OPER ------"
339 quickCheck arit1
340 quickCheck arit2
341 quickCheck arit3
342 quickCheck arit4
343 putStrLn "--------- GSL ------"
344 runTestTT tests
345 quickCheck $ \v -> ifft (fft v) |~| v
346
347kk = (2><2)
348 [ 1.0, 0.0
349 , -1.5, 1.0 ::Double]
350
351v = 11 |> [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0::Double]
352
353pol = [14.125,-7.666666666666667,-14.3,-13.0,-7.0,-9.6,4.666666666666666,13.0,0.5]
354
355mm = (2><2)
356 [ 0.5, 0.0
357 , 0.0, 0.0 ] :: Matrix Double
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index f0ef8b6..ebe6371 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -24,6 +24,7 @@ import Data.List(transpose)
24import Debug.Trace(trace) 24import Debug.Trace(trace)
25import Foreign.C.String(peekCString) 25import Foreign.C.String(peekCString)
26import Foreign.C.Types 26import Foreign.C.Types
27import Data.Monoid
27 28
28-- | A one-dimensional array of objects stored in a contiguous memory block. 29-- | A one-dimensional array of objects stored in a contiguous memory block.
29data Vector t = V { dim :: Int -- ^ number of elements 30data Vector t = V { dim :: Int -- ^ number of elements
@@ -177,3 +178,4 @@ liftVector f = fromList . map f . toList
177liftVector2 :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c 178liftVector2 :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c
178liftVector2 f u v = fromList $ zipWith f (toList u) (toList v) 179liftVector2 f u v = fromList $ zipWith f (toList u) (toList v)
179 180
181-----------------------------------------------------------------
diff --git a/lib/GSL.hs b/lib/GSL.hs
index 8b6365b..d65f8ff 100644
--- a/lib/GSL.hs
+++ b/lib/GSL.hs
@@ -8,7 +8,7 @@ Maintainer : Alberto Ruiz (aruiz at um dot es)
8Stability : provisional 8Stability : provisional
9Portability : uses -fffi and -fglasgow-exts 9Portability : uses -fffi and -fglasgow-exts
10 10
11This module reexports all the GSL functions (except those in "LinearAlgebra"). 11This module reexports all the available GSL functions (except those in "LinearAlgebra").
12 12
13-} 13-}
14 14
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs
index 2b92a2a..54eea8a 100644
--- a/lib/LAPACK.hs
+++ b/lib/LAPACK.hs
@@ -14,7 +14,7 @@
14----------------------------------------------------------------------------- 14-----------------------------------------------------------------------------
15 15
16module LAPACK ( 16module LAPACK (
17 svdR, svdR', svdRdd, svdRdd', svdC, svdC', 17 svdR, svdRdd, svdC,
18 eigC, eigR, eigS, eigH, 18 eigC, eigR, eigS, eigH,
19 linearSolveR, linearSolveC, 19 linearSolveR, linearSolveC,
20 linearSolveLSR, linearSolveLSC, 20 linearSolveLSR, linearSolveLSC,
@@ -26,7 +26,6 @@ import Data.Packed.Internal.Vector
26import Data.Packed.Internal.Matrix 26import Data.Packed.Internal.Matrix
27import Data.Packed.Vector 27import Data.Packed.Vector
28import Data.Packed.Matrix 28import Data.Packed.Matrix
29--import LinearAlgebra.Linear(scale)
30import GSL.Vector(vectorMapValR, FunCodeSV(Scale)) 29import GSL.Vector(vectorMapValR, FunCodeSV(Scale))
31import Complex 30import Complex
32import Foreign 31import Foreign
@@ -36,14 +35,9 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM
36 35
37-- | Wrapper for LAPACK's /dgesvd/, which computes the full svd decomposition of a real matrix. 36-- | Wrapper for LAPACK's /dgesvd/, which computes the full svd decomposition of a real matrix.
38-- 37--
39-- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@. 38-- @(u,s,v)=full svdR m@ so that @m=u \<\> s \<\> 'trans' v@.
40svdR :: Matrix Double -> (Matrix Double, Matrix Double, Matrix Double) 39svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
41svdR x = (u, diagRect s r c, v) where (u,s,v) = svdR' x 40svdR x = unsafePerformIO $ do
42 r = rows x
43 c = cols x
44
45svdR' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
46svdR' x = unsafePerformIO $ do
47 u <- createMatrix ColumnMajor r r 41 u <- createMatrix ColumnMajor r r
48 s <- createVector (min r c) 42 s <- createVector (min r c)
49 v <- createMatrix ColumnMajor c c 43 v <- createMatrix ColumnMajor c c
@@ -56,14 +50,9 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM
56 50
57-- | Wrapper for LAPACK's /dgesvd/, which computes the full svd decomposition of a real matrix. 51-- | Wrapper for LAPACK's /dgesvd/, which computes the full svd decomposition of a real matrix.
58-- 52--
59-- @(u,s,v)=svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@. 53-- @(u,s,v)=full svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@.
60svdRdd :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double) 54svdRdd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
61svdRdd x = (u, diagRect s r c, v) where (u,s,v) = svdRdd' x 55svdRdd x = unsafePerformIO $ do
62 r = rows x
63 c = cols x
64
65svdRdd' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
66svdRdd' x = unsafePerformIO $ do
67 u <- createMatrix ColumnMajor r r 56 u <- createMatrix ColumnMajor r r
68 s <- createVector (min r c) 57 s <- createVector (min r c)
69 v <- createMatrix ColumnMajor c c 58 v <- createMatrix ColumnMajor c c
@@ -77,14 +66,9 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM
77 66
78-- | Wrapper for LAPACK's /zgesvd/, which computes the full svd decomposition of a complex matrix. 67-- | Wrapper for LAPACK's /zgesvd/, which computes the full svd decomposition of a complex matrix.
79-- 68--
80-- @(u,s,v)=svdC m@ so that @m=u \<\> s \<\> 'trans' v@. 69-- @(u,s,v)=full svdC m@ so that @m=u \<\> comp s \<\> 'trans' v@.
81svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix Double, Matrix (Complex Double)) 70svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
82svdC x = (u, diagRect s r c, v) where (u,s,v) = svdC' x 71svdC x = unsafePerformIO $ do
83 r = rows x
84 c = cols x
85
86svdC' :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
87svdC' x = unsafePerformIO $ do
88 u <- createMatrix ColumnMajor r r 72 u <- createMatrix ColumnMajor r r
89 s <- createVector (min r c) 73 s <- createVector (min r c)
90 v <- createMatrix ColumnMajor c c 74 v <- createMatrix ColumnMajor c c
diff --git a/lib/LinearAlgebra/Algorithms.hs b/lib/LinearAlgebra/Algorithms.hs
index 3112ce6..7953386 100644
--- a/lib/LinearAlgebra/Algorithms.hs
+++ b/lib/LinearAlgebra/Algorithms.hs
@@ -14,15 +14,13 @@ Portability : uses ffi
14----------------------------------------------------------------------------- 14-----------------------------------------------------------------------------
15 15
16module LinearAlgebra.Algorithms ( 16module LinearAlgebra.Algorithms (
17 -- mXv, vXm, 17 GMatrix(..),
18 inv, 18 Normed(..), NormType(..),
19 pinv, 19 det,inv,pinv,full,economy,
20 pinvTol, 20 pinvTol,
21 pinvTolg, 21-- pinvTolg,
22 nullspacePrec, 22 nullspacePrec,
23 nullVector, 23 nullVector,
24 Normed(..), NormType(..),
25 det,
26 eps, i 24 eps, i
27) where 25) where
28 26
@@ -33,6 +31,69 @@ import GSL.Matrix
33import GSL.Vector 31import GSL.Vector
34import LAPACK 32import LAPACK
35import Complex 33import Complex
34import LinearAlgebra.Linear
35
36class (Linear Matrix t) => GMatrix t where
37 svd :: Matrix t -> (Matrix t, Vector Double, Matrix t)
38 lu :: Matrix t -> (Matrix t, Matrix t, [Int], t)
39 linearSolve :: Matrix t -> Matrix t -> Matrix t
40 linearSolveSVD :: Matrix t -> Matrix t -> Matrix t
41 ctrans :: Matrix t -> Matrix t
42 eig :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
43 eigSH :: Matrix t -> (Vector Double, Matrix t)
44
45instance GMatrix Double where
46 svd = svdR
47 lu = luR
48 linearSolve = linearSolveR
49 linearSolveSVD = linearSolveSVDR Nothing
50 ctrans = trans
51 eig = eigR
52 eigSH = eigS
53
54instance GMatrix (Complex Double) where
55 svd = svdC
56 lu = luC
57 linearSolve = linearSolveC
58 linearSolveSVD = linearSolveSVDC Nothing
59 ctrans = conjTrans
60 eig = eigC
61 eigSH = eigH
62
63square m = rows m == cols m
64
65det :: GMatrix t => Matrix t -> t
66det m | square m = s * (product $ toList $ takeDiag $ u)
67 | otherwise = error "det of nonsquare matrix"
68 where (_,u,_,s) = lu m
69
70inv :: GMatrix t => Matrix t -> Matrix t
71inv m | square m = m `linearSolve` ident (rows m)
72 | otherwise = error "inv of nonsquare matrix"
73
74pinv :: GMatrix t => Matrix t -> Matrix t
75pinv m = linearSolveSVD m (ident (rows m))
76
77
78full svd m = (u, d ,v) where
79 (u,s,v) = svd m
80 d = diagRect s r c
81 r = rows m
82 c = cols m
83
84economy svd m = (u', subVector 0 d s, v') where
85 (u,s,v) = svd m
86 sl@(g:_) = toList (complex s)
87 s' = fromList . filter rec $ sl
88 rec x = magnitude x > magnitude g*tol
89 t = 1
90 tol = (fromIntegral (max (rows m) (cols m)) * magnitude g * t * eps)
91 r = rows m
92 c = cols m
93 d = dim s'
94 u' = takeColumns d u
95 v' = takeColumns d v
96
36 97
37{- | Machine precision of a Double. 98{- | Machine precision of a Double.
38 99
@@ -70,119 +131,6 @@ vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t
70vXm v m = flatten $ (asRow v) `mXm` m 131vXm v m = flatten $ (asRow v) `mXm` m
71 132
72 133
73
74-- | Pseudoinverse of a real matrix
75--
76-- @dispR 3 $ pinv (fromLists [[1,2],
77-- [3,4],
78-- [5,6]])
79--matrix (2x3)
80-- -1.333 | -0.333 | 0.667
81-- 1.083 | 0.333 | -0.417@
82--
83
84pinv :: Matrix Double -> Matrix Double
85pinv m = pinvTol 1 m
86--pinv m = linearSolveSVDR Nothing m (ident (rows m))
87
88{- -| Pseudoinverse of a real matrix with the default tolerance used by GNU-Octave: the singular values less than max (rows, colums) * greatest singular value * 'eps' are ignored. See 'pinvTol'.
89
90@\> let m = 'fromLists' [[ 1, 2]
91 ,[ 5, 8]
92 ,[10,-5]]
93\> pinv m
949.353e-3 4.539e-2 7.637e-2
952.231e-2 8.993e-2 -4.719e-2
96\
97\> m \<\> pinv m \<\> m
98 1. 2.
99 5. 8.
10010. -5.@
101
102-}
103--pinvg :: Matrix Double -> Matrix Double
104pinvg m = pinvTolg 1 m
105
106{- | Pseudoinverse of a real matrix with the desired tolerance, expressed as a
107multiplicative factor of the default tolerance used by GNU-Octave (see 'pinv').
108
109@\> let m = 'fromLists' [[1,0, 0]
110 ,[0,1, 0]
111 ,[0,0,1e-10]]
112\
113\> 'pinv' m
1141. 0. 0.
1150. 1. 0.
1160. 0. 10000000000.
117\
118\> pinvTol 1E8 m
1191. 0. 0.
1200. 1. 0.
1210. 0. 1.@
122
123-}
124pinvTol :: Double -> Matrix Double -> Matrix Double
125pinvTol t m = v' `mXm` diag s' `mXm` trans u' where
126 (u,s,v) = svdR' m
127 sl@(g:_) = toList s
128 s' = fromList . map rec $ sl
129 rec x = if x < g*tol then 1 else 1/x
130 tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps)
131 r = rows m
132 c = cols m
133 d = dim s
134 u' = takeColumns d u
135 v' = takeColumns d v
136
137
138pinvTolg :: Double -> Matrix Double -> Matrix Double
139pinvTolg t m = v `mXm` diag s' `mXm` trans u where
140 (u,s,v) = svdg m
141 sl@(g:_) = toList s
142 s' = fromList . map rec $ sl
143 rec x = if x < g*tol then 1 else 1/x
144 tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps)
145
146
147
148{- | Inverse of a square matrix.
149
150inv m = 'linearSolveR' m ('ident' ('rows' m))
151
152@\>inv ('fromLists' [[1,4]
153 ,[0,2]])
1541. -2.
1550. 0.500@
156-}
157inv :: Matrix Double -> Matrix Double
158inv m = if rows m == cols m
159 then m `linearSolveR` ident (rows m)
160 else error "inv of nonsquare matrix"
161
162
163{- - | Shortcut for the 2-norm ('pnorm' 2)
164
165@ > norm $ 'hilb' 5
1661.5670506910982311
167@
168
169@\> norm $ 'fromList' [1,-1,'i',-'i']
1702.0@
171
172-}
173
174
175
176{- | Determinant of a square matrix, computed from the LU decomposition.
177
178@\> det ('fromLists' [[7,2],[3,8]])
17950.0@
180
181-}
182det :: Matrix Double -> Double
183det m = s * (product $ toList $ takeDiag $ u)
184 where (_,u,_,s) = luR m
185
186--------------------------------------------------------------------------- 134---------------------------------------------------------------------------
187 135
188norm2 :: Vector Double -> Double 136norm2 :: Vector Double -> Double
@@ -212,12 +160,12 @@ pnormCV PNorm1 = norm1 . liftVector magnitude
212pnormCV Infinity = vectorMax . liftVector magnitude 160pnormCV Infinity = vectorMax . liftVector magnitude
213--pnormCV _ = error "pnormCV not yet defined" 161--pnormCV _ = error "pnormCV not yet defined"
214 162
215pnormRM PNorm2 m = head (toList s) where (_,s,_) = svdR' m 163pnormRM PNorm2 m = head (toList s) where (_,s,_) = svdR m
216pnormRM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (vectorMapR Abs) m 164pnormRM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (vectorMapR Abs) m
217pnormRM Infinity m = vectorMax $ liftMatrix (vectorMapR Abs) m `mXv` constant 1 (cols m) 165pnormRM Infinity m = vectorMax $ liftMatrix (vectorMapR Abs) m `mXv` constant 1 (cols m)
218--pnormRM _ _ = error "p norm not yet defined" 166--pnormRM _ _ = error "p norm not yet defined"
219 167
220pnormCM PNorm2 m = head (toList s) where (_,s,_) = svdC' m 168pnormCM PNorm2 m = head (toList s) where (_,s,_) = svdC m
221pnormCM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (liftVector magnitude) m 169pnormCM PNorm1 m = vectorMax $ constant 1 (rows m) `vXm` liftMatrix (liftVector magnitude) m
222pnormCM Infinity m = vectorMax $ liftMatrix (liftVector magnitude) m `mXv` constant 1 (cols m) 170pnormCM Infinity m = vectorMax $ liftMatrix (liftVector magnitude) m `mXv` constant 1 (cols m)
223--pnormCM _ _ = error "p norm not yet defined" 171--pnormCM _ _ = error "p norm not yet defined"
@@ -245,17 +193,52 @@ instance Normed (Matrix (Complex Double)) where
245 193
246----------------------------------------------------------------------- 194-----------------------------------------------------------------------
247 195
248-- | The nullspace of a real matrix from its SVD decomposition. 196-- | The nullspace of a matrix from its SVD decomposition.
249nullspacePrec :: Double -- ^ relative tolerance in 'eps' units 197nullspacePrec :: GMatrix t
250 -> Matrix Double -- ^ input matrix 198 => Double -- ^ relative tolerance in 'eps' units
251 -> [Vector Double] -- ^ list of unitary vectors spanning the nullspace 199 -> Matrix t -- ^ input matrix
200 -> [Vector t] -- ^ list of unitary vectors spanning the nullspace
252nullspacePrec t m = ns where 201nullspacePrec t m = ns where
253 (_,s,v) = svdR' m 202 (_,s,v) = svd m
254 sl@(g:_) = toList s 203 sl@(g:_) = toList s
255 tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps) 204 tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps)
256 rank = length (filter (> g*tol) sl) 205 rank = length (filter (> g*tol) sl)
257 ns = drop rank (toColumns v) 206-- ns = drop rank (toColumns v)
207 ns = drop rank $ toRows $ ctrans v
258 208
259-- | The nullspace of a real matrix, assumed to be one-dimensional, with default tolerance (shortcut for @last . nullspacePrec 1@). 209-- | The nullspace of a matrix, assumed to be one-dimensional, with default tolerance (shortcut for @last . nullspacePrec 1@).
260nullVector :: Matrix Double -> Vector Double 210nullVector :: GMatrix t => Matrix t -> Vector t
261nullVector = last . nullspacePrec 1 211nullVector = last . nullspacePrec 1
212
213------------------------------------------------------------------------
214
215{- | Pseudoinverse of a real matrix with the desired tolerance, expressed as a
216multiplicative factor of the default tolerance used by GNU-Octave (see 'pinv').
217
218@\> let m = 'fromLists' [[1,0, 0]
219 ,[0,1, 0]
220 ,[0,0,1e-10]]
221\
222\> 'pinv' m
2231. 0. 0.
2240. 1. 0.
2250. 0. 10000000000.
226\
227\> pinvTol 1E8 m
2281. 0. 0.
2290. 1. 0.
2300. 0. 1.@
231
232-}
233pinvTol :: Double -> Matrix Double -> Matrix Double
234pinvTol t m = v' `mXm` diag s' `mXm` trans u' where
235 (u,s,v) = svdR m
236 sl@(g:_) = toList s
237 s' = fromList . map rec $ sl
238 rec x = if x < g*tol then 1 else 1/x
239 tol = (fromIntegral (max (rows m) (cols m)) * g * t * eps)
240 r = rows m
241 c = cols m
242 d = dim s
243 u' = takeColumns d u
244 v' = takeColumns d v
diff --git a/lib/LinearAlgebra/Linear.hs b/lib/LinearAlgebra/Linear.hs
index c12e30b..2f1bc6f 100644
--- a/lib/LinearAlgebra/Linear.hs
+++ b/lib/LinearAlgebra/Linear.hs
@@ -37,6 +37,8 @@ class (Field e) => Linear c e where
37 fromComplex :: RealFloat e => c (Complex e) -> (c e, c e) 37 fromComplex :: RealFloat e => c (Complex e) -> (c e, c e)
38 comp :: RealFloat e => c e -> c (Complex e) 38 comp :: RealFloat e => c e -> c (Complex e)
39 conj :: RealFloat e => c (Complex e) -> c (Complex e) 39 conj :: RealFloat e => c (Complex e) -> c (Complex e)
40 real :: c Double -> c e
41 complex :: c e -> c (Complex Double)
40 42
41instance Linear Vector Double where 43instance Linear Vector Double where
42 scale = vectorMapValR Scale 44 scale = vectorMapValR Scale
@@ -50,6 +52,8 @@ instance Linear Vector Double where
50 fromComplex = Data.Packed.Internal.fromComplex 52 fromComplex = Data.Packed.Internal.fromComplex
51 comp = Data.Packed.Internal.comp 53 comp = Data.Packed.Internal.comp
52 conj = Data.Packed.Internal.conj 54 conj = Data.Packed.Internal.conj
55 real = id
56 complex = LinearAlgebra.Linear.comp
53 57
54instance Linear Vector (Complex Double) where 58instance Linear Vector (Complex Double) where
55 scale = vectorMapValC Scale 59 scale = vectorMapValC Scale
@@ -63,6 +67,8 @@ instance Linear Vector (Complex Double) where
63 fromComplex = undefined 67 fromComplex = undefined
64 comp = undefined 68 comp = undefined
65 conj = undefined 69 conj = undefined
70 real = LinearAlgebra.Linear.comp
71 complex = id
66 72
67instance Linear Matrix Double where 73instance Linear Matrix Double where
68 scale x = liftMatrix (scale x) 74 scale x = liftMatrix (scale x)
@@ -78,6 +84,8 @@ instance Linear Matrix Double where
78 c = cols z 84 c = cols z
79 comp = liftMatrix Data.Packed.Internal.comp 85 comp = liftMatrix Data.Packed.Internal.comp
80 conj = liftMatrix Data.Packed.Internal.conj 86 conj = liftMatrix Data.Packed.Internal.conj
87 real = id
88 complex = LinearAlgebra.Linear.comp
81 89
82instance Linear Matrix (Complex Double) where 90instance Linear Matrix (Complex Double) where
83 scale x = liftMatrix (scale x) 91 scale x = liftMatrix (scale x)
@@ -91,6 +99,8 @@ instance Linear Matrix (Complex Double) where
91 fromComplex = undefined 99 fromComplex = undefined
92 comp = undefined 100 comp = undefined
93 conj = undefined 101 conj = undefined
102 real = LinearAlgebra.Linear.comp
103 complex = id
94 104
95-------------------------------------------------- 105--------------------------------------------------
96 106