summaryrefslogtreecommitdiff
path: root/lib/Numeric/GSL
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/GSL')
-rw-r--r--lib/Numeric/GSL/Matrix.hs72
-rw-r--r--lib/Numeric/GSL/Minimization.hs8
2 files changed, 51 insertions, 29 deletions
diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs
index 5a5c19e..e803c53 100644
--- a/lib/Numeric/GSL/Matrix.hs
+++ b/lib/Numeric/GSL/Matrix.hs
@@ -44,12 +44,14 @@ import Complex
44 44
45-} 45-}
46eigSg :: Matrix Double -> (Vector Double, Matrix Double) 46eigSg :: Matrix Double -> (Vector Double, Matrix Double)
47eigSg m 47eigSg = eigSg . cmat
48
49eigSg' m
48 | r == 1 = (fromList [cdat m `at` 0], singleton 1) 50 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
49 | otherwise = unsafePerformIO $ do 51 | otherwise = unsafePerformIO $ do
50 l <- createVector r 52 l <- createVector r
51 v <- createMatrix RowMajor r r 53 v <- createMatrix RowMajor r r
52 c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m] 54 c_eigS // matc m // vec l // matc v // check "eigSg" [cdat m]
53 return (l,v) 55 return (l,v)
54 where r = rows m 56 where r = rows m
55foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM 57foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
@@ -75,12 +77,14 @@ foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
75 77
76-} 78-}
77eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) 79eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double))
78eigHg m 80eigHg = eigHg' . cmat
81
82eigHg' m
79 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) 83 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1)
80 | otherwise = unsafePerformIO $ do 84 | otherwise = unsafePerformIO $ do
81 l <- createVector r 85 l <- createVector r
82 v <- createMatrix RowMajor r r 86 v <- createMatrix RowMajor r r
83 c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m] 87 c_eigH // matc m // vec l // matc v // check "eigHg" [cdat m]
84 return (l,v) 88 return (l,v)
85 where r = rows m 89 where r = rows m
86foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM 90foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
@@ -109,14 +113,14 @@ foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
109-} 113-}
110svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) 114svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
111svdg x = if rows x >= cols x 115svdg x = if rows x >= cols x
112 then svd' x 116 then svd' (cmat x)
113 else (v, s, u) where (u,s,v) = svd' (trans x) 117 else (v, s, u) where (u,s,v) = svd' (cmat (trans x))
114 118
115svd' x = unsafePerformIO $ do 119svd' x = unsafePerformIO $ do
116 u <- createMatrix RowMajor r c 120 u <- createMatrix RowMajor r c
117 s <- createVector c 121 s <- createVector c
118 v <- createMatrix RowMajor c c 122 v <- createMatrix RowMajor c c
119 c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x] 123 c_svd // matc x // matc u // vec s // matc v // check "svdg" [cdat x]
120 return (u,s,v) 124 return (u,s,v)
121 where r = rows x 125 where r = rows x
122 c = cols x 126 c = cols x
@@ -140,30 +144,36 @@ foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM
140 144
141-} 145-}
142qr :: Matrix Double -> (Matrix Double, Matrix Double) 146qr :: Matrix Double -> (Matrix Double, Matrix Double)
143qr x = unsafePerformIO $ do 147qr = qr' . cmat
148
149qr' x = unsafePerformIO $ do
144 q <- createMatrix RowMajor r r 150 q <- createMatrix RowMajor r r
145 rot <- createMatrix RowMajor r c 151 rot <- createMatrix RowMajor r c
146 c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x] 152 c_qr // matc x // matc q // matc rot // check "qr" [cdat x]
147 return (q,rot) 153 return (q,rot)
148 where r = rows x 154 where r = rows x
149 c = cols x 155 c = cols x
150foreign import ccall "gsl-aux.h QR" c_qr :: TMMM 156foreign import ccall "gsl-aux.h QR" c_qr :: TMMM
151 157
152qrPacked :: Matrix Double -> (Matrix Double, Vector Double) 158qrPacked :: Matrix Double -> (Matrix Double, Vector Double)
153qrPacked x = unsafePerformIO $ do 159qrPacked = qrPacked' . cmat
160
161qrPacked' x = unsafePerformIO $ do
154 qr <- createMatrix RowMajor r c 162 qr <- createMatrix RowMajor r c
155 tau <- createVector (min r c) 163 tau <- createVector (min r c)
156 c_qrPacked // mat cdat x // mat dat qr // vec tau // check "qrUnpacked" [cdat x] 164 c_qrPacked // matc x // matc qr // vec tau // check "qrUnpacked" [cdat x]
157 return (qr,tau) 165 return (qr,tau)
158 where r = rows x 166 where r = rows x
159 c = cols x 167 c = cols x
160foreign import ccall "gsl-aux.h QRpacked" c_qrPacked :: TMMV 168foreign import ccall "gsl-aux.h QRpacked" c_qrPacked :: TMMV
161 169
162unpackQR :: (Matrix Double, Vector Double) -> (Matrix Double, Matrix Double) 170unpackQR :: (Matrix Double, Vector Double) -> (Matrix Double, Matrix Double)
163unpackQR (qr,tau) = unsafePerformIO $ do 171unpackQR (qr,tau) = unpackQR' (cmat qr, tau)
172
173unpackQR' (qr,tau) = unsafePerformIO $ do
164 q <- createMatrix RowMajor r r 174 q <- createMatrix RowMajor r r
165 rot <- createMatrix RowMajor r c 175 rot <- createMatrix RowMajor r c
166 c_qrUnpack // mat cdat qr // vec tau // mat dat q // mat dat rot // check "qrUnpack" [cdat qr,tau] 176 c_qrUnpack // matc qr // vec tau // matc q // matc rot // check "qrUnpack" [cdat qr,tau]
167 return (q,rot) 177 return (q,rot)
168 where r = rows qr 178 where r = rows qr
169 c = cols qr 179 c = cols qr
@@ -183,17 +193,21 @@ type TMVMM = Int -> Int -> PD -> Int -> PD -> TMM
183 193
184-} 194-}
185cholR :: Matrix Double -> Matrix Double 195cholR :: Matrix Double -> Matrix Double
186cholR x = unsafePerformIO $ do 196cholR = cholR' . cmat
197
198cholR' x = unsafePerformIO $ do
187 res <- createMatrix RowMajor r r 199 res <- createMatrix RowMajor r r
188 c_cholR // mat cdat x // mat dat res // check "cholR" [cdat x] 200 c_cholR // matc x // matc res // check "cholR" [cdat x]
189 return res 201 return res
190 where r = rows x 202 where r = rows x
191foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM 203foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM
192 204
193cholC :: Matrix (Complex Double) -> Matrix (Complex Double) 205cholC :: Matrix (Complex Double) -> Matrix (Complex Double)
194cholC x = unsafePerformIO $ do 206cholC = cholC' . cmat
207
208cholC' x = unsafePerformIO $ do
195 res <- createMatrix RowMajor r r 209 res <- createMatrix RowMajor r r
196 c_cholC // mat cdat x // mat dat res // check "cholC" [cdat x] 210 c_cholC // matc x // matc res // check "cholC" [cdat x]
197 return res 211 return res
198 where r = rows x 212 where r = rows x
199foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM 213foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM
@@ -204,10 +218,12 @@ foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM
204{- -| efficient multiplication by the inverse of a matrix (for real matrices) 218{- -| efficient multiplication by the inverse of a matrix (for real matrices)
205-} 219-}
206luSolveR :: Matrix Double -> Matrix Double -> Matrix Double 220luSolveR :: Matrix Double -> Matrix Double -> Matrix Double
207luSolveR a b 221luSolveR a b = luSolveR' (cmat a) (cmat b)
222
223luSolveR' a b
208 | n1==n2 && n1==r = unsafePerformIO $ do 224 | n1==n2 && n1==r = unsafePerformIO $ do
209 s <- createMatrix RowMajor r c 225 s <- createMatrix RowMajor r c
210 c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b] 226 c_luSolveR // matc a // matc b // matc s // check "luSolveR" [cdat a, cdat b]
211 return s 227 return s
212 | otherwise = error "luSolveR of nonsquare matrix" 228 | otherwise = error "luSolveR of nonsquare matrix"
213 where n1 = rows a 229 where n1 = rows a
@@ -219,10 +235,12 @@ foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM
219{- -| efficient multiplication by the inverse of a matrix (for complex matrices). 235{- -| efficient multiplication by the inverse of a matrix (for complex matrices).
220-} 236-}
221luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 237luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
222luSolveC a b 238luSolveC a b = luSolveC' (cmat a) (cmat b)
239
240luSolveC' a b
223 | n1==n2 && n1==r = unsafePerformIO $ do 241 | n1==n2 && n1==r = unsafePerformIO $ do
224 s <- createMatrix RowMajor r c 242 s <- createMatrix RowMajor r c
225 c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b] 243 c_luSolveC // matc a // matc b // matc s // check "luSolveC" [cdat a, cdat b]
226 return s 244 return s
227 | otherwise = error "luSolveC of nonsquare matrix" 245 | otherwise = error "luSolveC of nonsquare matrix"
228 where n1 = rows a 246 where n1 = rows a
@@ -234,9 +252,11 @@ foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM
234{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) 252{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign)
235-} 253-}
236luRaux :: Matrix Double -> Vector Double 254luRaux :: Matrix Double -> Vector Double
237luRaux x = unsafePerformIO $ do 255luRaux = luRaux' . cmat
256
257luRaux' x = unsafePerformIO $ do
238 res <- createVector (r*r+r+1) 258 res <- createVector (r*r+r+1)
239 c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x] 259 c_luRaux // matc x // vec res // check "luRaux" [cdat x]
240 return res 260 return res
241 where r = rows x 261 where r = rows x
242 c = cols x 262 c = cols x
@@ -245,9 +265,11 @@ foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV
245{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) 265{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign)
246-} 266-}
247luCaux :: Matrix (Complex Double) -> Vector (Complex Double) 267luCaux :: Matrix (Complex Double) -> Vector (Complex Double)
248luCaux x = unsafePerformIO $ do 268luCaux = luCaux' . cmat
269
270luCaux' x = unsafePerformIO $ do
249 res <- createVector (r*r+r+1) 271 res <- createVector (r*r+r+1)
250 c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x] 272 c_luCaux // matc x // vec res // check "luCaux" [cdat x]
251 return res 273 return res
252 where r = rows x 274 where r = rows x
253 c = cols x 275 c = cols x
diff --git a/lib/Numeric/GSL/Minimization.hs b/lib/Numeric/GSL/Minimization.hs
index e93f8cb..f523849 100644
--- a/lib/Numeric/GSL/Minimization.hs
+++ b/lib/Numeric/GSL/Minimization.hs
@@ -186,12 +186,12 @@ foreign import ccall "wrapper"
186 186
187aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO()) 187aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO())
188aux_vTov f n p r = g where 188aux_vTov f n p r = g where
189 V {fptr = pr, ptr = t} = f x 189 v@V {fptr = pr} = f x
190 x = createV n copy "aux_vTov" [] 190 x = createV n copy "aux_vTov" []
191 copy n q = do 191 copy n q = do
192 copyArray q p n 192 copyArray q p n
193 return 0 193 return 0
194 g = withForeignPtr pr $ \_ -> copyArray r t n 194 g = withForeignPtr pr $ \_ -> copyArray r (ptr v) n
195 195
196-------------------------------------------------------------------- 196--------------------------------------------------------------------
197 197
@@ -202,10 +202,10 @@ createV n fun msg ptrs = unsafePerformIO $ do
202 202
203createM r c fun msg ptrs = unsafePerformIO $ do 203createM r c fun msg ptrs = unsafePerformIO $ do
204 r <- createMatrix RowMajor r c 204 r <- createMatrix RowMajor r c
205 fun // mat cdat r // check msg ptrs 205 fun // matc r // check msg ptrs
206 return r 206 return r
207 207
208createMIO r c fun msg ptrs = do 208createMIO r c fun msg ptrs = do
209 r <- createMatrix RowMajor r c 209 r <- createMatrix RowMajor r c
210 fun // mat cdat r // check msg ptrs 210 fun // matc r // check msg ptrs
211 return r 211 return r