summaryrefslogtreecommitdiff
path: root/lib/GSL
diff options
context:
space:
mode:
Diffstat (limited to 'lib/GSL')
-rw-r--r--lib/GSL/Compat.hs3
-rw-r--r--lib/GSL/Matrix.hs42
2 files changed, 31 insertions, 14 deletions
diff --git a/lib/GSL/Compat.hs b/lib/GSL/Compat.hs
index 809a1f5..1d6f7b9 100644
--- a/lib/GSL/Compat.hs
+++ b/lib/GSL/Compat.hs
@@ -38,7 +38,7 @@ adaptScalar f1 f2 f3 x y
38 | dim y == 1 = f3 x (y@>0) 38 | dim y == 1 = f3 x (y@>0)
39 | otherwise = f2 x y 39 | otherwise = f2 x y
40 40
41liftMatrix2' :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 41liftMatrix2' :: (Field t, Field a, Field b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
42liftMatrix2' f m1 m2 | compat' m1 m2 = reshape (max (cols m1) (cols m2)) (f (cdat m1) (cdat m2)) 42liftMatrix2' f m1 m2 | compat' m1 m2 = reshape (max (cols m1) (cols m2)) (f (cdat m1) (cdat m2))
43 | otherwise = error "nonconformant matrices in liftMatrix2'" 43 | otherwise = error "nonconformant matrices in liftMatrix2'"
44 44
@@ -63,6 +63,7 @@ instance (Eq a, Field a) => Eq (Matrix a) where
63 63
64instance (Field a, Linear Vector a) => Num (Matrix a) where 64instance (Field a, Linear Vector a) => Num (Matrix a) where
65 (+) = liftMatrix2' (+) 65 (+) = liftMatrix2' (+)
66 (-) = liftMatrix2' (-)
66 negate = liftMatrix negate 67 negate = liftMatrix negate
67 (*) = liftMatrix2' (*) 68 (*) = liftMatrix2' (*)
68 signum = liftMatrix signum 69 signum = liftMatrix signum
diff --git a/lib/GSL/Matrix.hs b/lib/GSL/Matrix.hs
index 26c5e2a..15710df 100644
--- a/lib/GSL/Matrix.hs
+++ b/lib/GSL/Matrix.hs
@@ -46,13 +46,14 @@ import Foreign.C.String
46 46
47-} 47-}
48eigSg :: Matrix Double -> (Vector Double, Matrix Double) 48eigSg :: Matrix Double -> (Vector Double, Matrix Double)
49eigSg (m@M {rows = r}) 49eigSg m
50 | r == 1 = (fromList [cdat m `at` 0], singleton 1) 50 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
51 | otherwise = unsafePerformIO $ do 51 | otherwise = unsafePerformIO $ do
52 l <- createVector r 52 l <- createVector r
53 v <- createMatrix RowMajor r r 53 v <- createMatrix RowMajor r r
54 c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m] 54 c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m]
55 return (l,v) 55 return (l,v)
56 where r = rows m
56foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM 57foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
57 58
58------------------------------------------------------------------ 59------------------------------------------------------------------
@@ -76,13 +77,14 @@ foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
76 77
77-} 78-}
78eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) 79eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double))
79eigHg (m@M {rows = r}) 80eigHg m
80 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) 81 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1)
81 | otherwise = unsafePerformIO $ do 82 | otherwise = unsafePerformIO $ do
82 l <- createVector r 83 l <- createVector r
83 v <- createMatrix RowMajor r r 84 v <- createMatrix RowMajor r r
84 c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m] 85 c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m]
85 return (l,v) 86 return (l,v)
87 where r = rows m
86foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM 88foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
87 89
88 90
@@ -108,16 +110,18 @@ foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
108 110
109-} 111-}
110svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) 112svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
111svdg x@M {rows = r, cols = c} = if r>=c 113svdg x = if rows x >= cols x
112 then svd' x 114 then svd' x
113 else (v, s, u) where (u,s,v) = svd' (trans x) 115 else (v, s, u) where (u,s,v) = svd' (trans x)
114 116
115svd' x@M {rows = r, cols = c} = unsafePerformIO $ do 117svd' x = unsafePerformIO $ do
116 u <- createMatrix RowMajor r c 118 u <- createMatrix RowMajor r c
117 s <- createVector c 119 s <- createVector c
118 v <- createMatrix RowMajor c c 120 v <- createMatrix RowMajor c c
119 c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x] 121 c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x]
120 return (u,s,v) 122 return (u,s,v)
123 where r = rows x
124 c = cols x
121foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM 125foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM
122 126
123{- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/. 127{- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/.
@@ -138,11 +142,13 @@ foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM
138 142
139-} 143-}
140qr :: Matrix Double -> (Matrix Double, Matrix Double) 144qr :: Matrix Double -> (Matrix Double, Matrix Double)
141qr x@M {rows = r, cols = c} = unsafePerformIO $ do 145qr x = unsafePerformIO $ do
142 q <- createMatrix RowMajor r r 146 q <- createMatrix RowMajor r r
143 rot <- createMatrix RowMajor r c 147 rot <- createMatrix RowMajor r c
144 c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x] 148 c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x]
145 return (q,rot) 149 return (q,rot)
150 where r = rows x
151 c = cols x
146foreign import ccall "gsl-aux.h QR" c_qr :: TMMM 152foreign import ccall "gsl-aux.h QR" c_qr :: TMMM
147 153
148{- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/. 154{- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/.
@@ -159,11 +165,11 @@ foreign import ccall "gsl-aux.h QR" c_qr :: TMMM
159 165
160-} 166-}
161chol :: Matrix Double -> Matrix Double 167chol :: Matrix Double -> Matrix Double
162--chol x@(M r _ p) = createM [p] "chol" r r $ m c_chol x 168chol x = unsafePerformIO $ do
163chol x@M {rows = r} = unsafePerformIO $ do
164 res <- createMatrix RowMajor r r 169 res <- createMatrix RowMajor r r
165 c_chol // mat cdat x // mat dat res // check "chol" [cdat x] 170 c_chol // mat cdat x // mat dat res // check "chol" [cdat x]
166 return res 171 return res
172 where r = rows x
167foreign import ccall "gsl-aux.h chol" c_chol :: TMM 173foreign import ccall "gsl-aux.h chol" c_chol :: TMM
168 174
169-------------------------------------------------------- 175--------------------------------------------------------
@@ -171,43 +177,53 @@ foreign import ccall "gsl-aux.h chol" c_chol :: TMM
171{- -| efficient multiplication by the inverse of a matrix (for real matrices) 177{- -| efficient multiplication by the inverse of a matrix (for real matrices)
172-} 178-}
173luSolveR :: Matrix Double -> Matrix Double -> Matrix Double 179luSolveR :: Matrix Double -> Matrix Double -> Matrix Double
174luSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) 180luSolveR a b
175 | n1==n2 && n1==r = unsafePerformIO $ do 181 | n1==n2 && n1==r = unsafePerformIO $ do
176 s <- createMatrix RowMajor r c 182 s <- createMatrix RowMajor r c
177 c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b] 183 c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b]
178 return s 184 return s
179 | otherwise = error "luSolveR of nonsquare matrix" 185 | otherwise = error "luSolveR of nonsquare matrix"
180 186 where n1 = rows a
187 n2 = cols a
188 r = rows b
189 c = cols b
181foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM 190foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM
182 191
183{- -| efficient multiplication by the inverse of a matrix (for complex matrices). 192{- -| efficient multiplication by the inverse of a matrix (for complex matrices).
184-} 193-}
185luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 194luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
186luSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) 195luSolveC a b
187 | n1==n2 && n1==r = unsafePerformIO $ do 196 | n1==n2 && n1==r = unsafePerformIO $ do
188 s <- createMatrix RowMajor r c 197 s <- createMatrix RowMajor r c
189 c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b] 198 c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b]
190 return s 199 return s
191 | otherwise = error "luSolveC of nonsquare matrix" 200 | otherwise = error "luSolveC of nonsquare matrix"
192 201 where n1 = rows a
202 n2 = cols a
203 r = rows b
204 c = cols b
193foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM 205foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM
194 206
195{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) 207{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign)
196-} 208-}
197luRaux :: Matrix Double -> Vector Double 209luRaux :: Matrix Double -> Vector Double
198luRaux x@M {rows = r, cols = c} = unsafePerformIO $ do 210luRaux x = unsafePerformIO $ do
199 res <- createVector (r*r+r+1) 211 res <- createVector (r*r+r+1)
200 c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x] 212 c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x]
201 return res 213 return res
214 where r = rows x
215 c = cols x
202foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV 216foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV
203 217
204{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) 218{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign)
205-} 219-}
206luCaux :: Matrix (Complex Double) -> Vector (Complex Double) 220luCaux :: Matrix (Complex Double) -> Vector (Complex Double)
207luCaux x@M {rows = r, cols = c} = unsafePerformIO $ do 221luCaux x = unsafePerformIO $ do
208 res <- createVector (r*r+r+1) 222 res <- createVector (r*r+r+1)
209 c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x] 223 c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x]
210 return res 224 return res
225 where r = rows x
226 c = cols x
211foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV 227foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV
212 228
213{- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>. 229{- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>.