diff options
Diffstat (limited to 'lib/Numeric/GSL')
-rw-r--r-- | lib/Numeric/GSL/Matrix.hs | 72 | ||||
-rw-r--r-- | lib/Numeric/GSL/Minimization.hs | 8 |
2 files changed, 51 insertions, 29 deletions
diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs index 5a5c19e..e803c53 100644 --- a/lib/Numeric/GSL/Matrix.hs +++ b/lib/Numeric/GSL/Matrix.hs | |||
@@ -44,12 +44,14 @@ import Complex | |||
44 | 44 | ||
45 | -} | 45 | -} |
46 | eigSg :: Matrix Double -> (Vector Double, Matrix Double) | 46 | eigSg :: Matrix Double -> (Vector Double, Matrix Double) |
47 | eigSg m | 47 | eigSg = eigSg . cmat |
48 | |||
49 | eigSg' m | ||
48 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) | 50 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) |
49 | | otherwise = unsafePerformIO $ do | 51 | | otherwise = unsafePerformIO $ do |
50 | l <- createVector r | 52 | l <- createVector r |
51 | v <- createMatrix RowMajor r r | 53 | v <- createMatrix RowMajor r r |
52 | c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m] | 54 | c_eigS // matc m // vec l // matc v // check "eigSg" [cdat m] |
53 | return (l,v) | 55 | return (l,v) |
54 | where r = rows m | 56 | where r = rows m |
55 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM | 57 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM |
@@ -75,12 +77,14 @@ foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM | |||
75 | 77 | ||
76 | -} | 78 | -} |
77 | eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) | 79 | eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) |
78 | eigHg m | 80 | eigHg = eigHg' . cmat |
81 | |||
82 | eigHg' m | ||
79 | | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) | 83 | | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) |
80 | | otherwise = unsafePerformIO $ do | 84 | | otherwise = unsafePerformIO $ do |
81 | l <- createVector r | 85 | l <- createVector r |
82 | v <- createMatrix RowMajor r r | 86 | v <- createMatrix RowMajor r r |
83 | c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m] | 87 | c_eigH // matc m // vec l // matc v // check "eigHg" [cdat m] |
84 | return (l,v) | 88 | return (l,v) |
85 | where r = rows m | 89 | where r = rows m |
86 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM | 90 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM |
@@ -109,14 +113,14 @@ foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM | |||
109 | -} | 113 | -} |
110 | svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 114 | svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
111 | svdg x = if rows x >= cols x | 115 | svdg x = if rows x >= cols x |
112 | then svd' x | 116 | then svd' (cmat x) |
113 | else (v, s, u) where (u,s,v) = svd' (trans x) | 117 | else (v, s, u) where (u,s,v) = svd' (cmat (trans x)) |
114 | 118 | ||
115 | svd' x = unsafePerformIO $ do | 119 | svd' x = unsafePerformIO $ do |
116 | u <- createMatrix RowMajor r c | 120 | u <- createMatrix RowMajor r c |
117 | s <- createVector c | 121 | s <- createVector c |
118 | v <- createMatrix RowMajor c c | 122 | v <- createMatrix RowMajor c c |
119 | c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x] | 123 | c_svd // matc x // matc u // vec s // matc v // check "svdg" [cdat x] |
120 | return (u,s,v) | 124 | return (u,s,v) |
121 | where r = rows x | 125 | where r = rows x |
122 | c = cols x | 126 | c = cols x |
@@ -140,30 +144,36 @@ foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM | |||
140 | 144 | ||
141 | -} | 145 | -} |
142 | qr :: Matrix Double -> (Matrix Double, Matrix Double) | 146 | qr :: Matrix Double -> (Matrix Double, Matrix Double) |
143 | qr x = unsafePerformIO $ do | 147 | qr = qr' . cmat |
148 | |||
149 | qr' x = unsafePerformIO $ do | ||
144 | q <- createMatrix RowMajor r r | 150 | q <- createMatrix RowMajor r r |
145 | rot <- createMatrix RowMajor r c | 151 | rot <- createMatrix RowMajor r c |
146 | c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x] | 152 | c_qr // matc x // matc q // matc rot // check "qr" [cdat x] |
147 | return (q,rot) | 153 | return (q,rot) |
148 | where r = rows x | 154 | where r = rows x |
149 | c = cols x | 155 | c = cols x |
150 | foreign import ccall "gsl-aux.h QR" c_qr :: TMMM | 156 | foreign import ccall "gsl-aux.h QR" c_qr :: TMMM |
151 | 157 | ||
152 | qrPacked :: Matrix Double -> (Matrix Double, Vector Double) | 158 | qrPacked :: Matrix Double -> (Matrix Double, Vector Double) |
153 | qrPacked x = unsafePerformIO $ do | 159 | qrPacked = qrPacked' . cmat |
160 | |||
161 | qrPacked' x = unsafePerformIO $ do | ||
154 | qr <- createMatrix RowMajor r c | 162 | qr <- createMatrix RowMajor r c |
155 | tau <- createVector (min r c) | 163 | tau <- createVector (min r c) |
156 | c_qrPacked // mat cdat x // mat dat qr // vec tau // check "qrUnpacked" [cdat x] | 164 | c_qrPacked // matc x // matc qr // vec tau // check "qrUnpacked" [cdat x] |
157 | return (qr,tau) | 165 | return (qr,tau) |
158 | where r = rows x | 166 | where r = rows x |
159 | c = cols x | 167 | c = cols x |
160 | foreign import ccall "gsl-aux.h QRpacked" c_qrPacked :: TMMV | 168 | foreign import ccall "gsl-aux.h QRpacked" c_qrPacked :: TMMV |
161 | 169 | ||
162 | unpackQR :: (Matrix Double, Vector Double) -> (Matrix Double, Matrix Double) | 170 | unpackQR :: (Matrix Double, Vector Double) -> (Matrix Double, Matrix Double) |
163 | unpackQR (qr,tau) = unsafePerformIO $ do | 171 | unpackQR (qr,tau) = unpackQR' (cmat qr, tau) |
172 | |||
173 | unpackQR' (qr,tau) = unsafePerformIO $ do | ||
164 | q <- createMatrix RowMajor r r | 174 | q <- createMatrix RowMajor r r |
165 | rot <- createMatrix RowMajor r c | 175 | rot <- createMatrix RowMajor r c |
166 | c_qrUnpack // mat cdat qr // vec tau // mat dat q // mat dat rot // check "qrUnpack" [cdat qr,tau] | 176 | c_qrUnpack // matc qr // vec tau // matc q // matc rot // check "qrUnpack" [cdat qr,tau] |
167 | return (q,rot) | 177 | return (q,rot) |
168 | where r = rows qr | 178 | where r = rows qr |
169 | c = cols qr | 179 | c = cols qr |
@@ -183,17 +193,21 @@ type TMVMM = Int -> Int -> PD -> Int -> PD -> TMM | |||
183 | 193 | ||
184 | -} | 194 | -} |
185 | cholR :: Matrix Double -> Matrix Double | 195 | cholR :: Matrix Double -> Matrix Double |
186 | cholR x = unsafePerformIO $ do | 196 | cholR = cholR' . cmat |
197 | |||
198 | cholR' x = unsafePerformIO $ do | ||
187 | res <- createMatrix RowMajor r r | 199 | res <- createMatrix RowMajor r r |
188 | c_cholR // mat cdat x // mat dat res // check "cholR" [cdat x] | 200 | c_cholR // matc x // matc res // check "cholR" [cdat x] |
189 | return res | 201 | return res |
190 | where r = rows x | 202 | where r = rows x |
191 | foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM | 203 | foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM |
192 | 204 | ||
193 | cholC :: Matrix (Complex Double) -> Matrix (Complex Double) | 205 | cholC :: Matrix (Complex Double) -> Matrix (Complex Double) |
194 | cholC x = unsafePerformIO $ do | 206 | cholC = cholC' . cmat |
207 | |||
208 | cholC' x = unsafePerformIO $ do | ||
195 | res <- createMatrix RowMajor r r | 209 | res <- createMatrix RowMajor r r |
196 | c_cholC // mat cdat x // mat dat res // check "cholC" [cdat x] | 210 | c_cholC // matc x // matc res // check "cholC" [cdat x] |
197 | return res | 211 | return res |
198 | where r = rows x | 212 | where r = rows x |
199 | foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM | 213 | foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM |
@@ -204,10 +218,12 @@ foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM | |||
204 | {- -| efficient multiplication by the inverse of a matrix (for real matrices) | 218 | {- -| efficient multiplication by the inverse of a matrix (for real matrices) |
205 | -} | 219 | -} |
206 | luSolveR :: Matrix Double -> Matrix Double -> Matrix Double | 220 | luSolveR :: Matrix Double -> Matrix Double -> Matrix Double |
207 | luSolveR a b | 221 | luSolveR a b = luSolveR' (cmat a) (cmat b) |
222 | |||
223 | luSolveR' a b | ||
208 | | n1==n2 && n1==r = unsafePerformIO $ do | 224 | | n1==n2 && n1==r = unsafePerformIO $ do |
209 | s <- createMatrix RowMajor r c | 225 | s <- createMatrix RowMajor r c |
210 | c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b] | 226 | c_luSolveR // matc a // matc b // matc s // check "luSolveR" [cdat a, cdat b] |
211 | return s | 227 | return s |
212 | | otherwise = error "luSolveR of nonsquare matrix" | 228 | | otherwise = error "luSolveR of nonsquare matrix" |
213 | where n1 = rows a | 229 | where n1 = rows a |
@@ -219,10 +235,12 @@ foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM | |||
219 | {- -| efficient multiplication by the inverse of a matrix (for complex matrices). | 235 | {- -| efficient multiplication by the inverse of a matrix (for complex matrices). |
220 | -} | 236 | -} |
221 | luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 237 | luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
222 | luSolveC a b | 238 | luSolveC a b = luSolveC' (cmat a) (cmat b) |
239 | |||
240 | luSolveC' a b | ||
223 | | n1==n2 && n1==r = unsafePerformIO $ do | 241 | | n1==n2 && n1==r = unsafePerformIO $ do |
224 | s <- createMatrix RowMajor r c | 242 | s <- createMatrix RowMajor r c |
225 | c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b] | 243 | c_luSolveC // matc a // matc b // matc s // check "luSolveC" [cdat a, cdat b] |
226 | return s | 244 | return s |
227 | | otherwise = error "luSolveC of nonsquare matrix" | 245 | | otherwise = error "luSolveC of nonsquare matrix" |
228 | where n1 = rows a | 246 | where n1 = rows a |
@@ -234,9 +252,11 @@ foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM | |||
234 | {- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) | 252 | {- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) |
235 | -} | 253 | -} |
236 | luRaux :: Matrix Double -> Vector Double | 254 | luRaux :: Matrix Double -> Vector Double |
237 | luRaux x = unsafePerformIO $ do | 255 | luRaux = luRaux' . cmat |
256 | |||
257 | luRaux' x = unsafePerformIO $ do | ||
238 | res <- createVector (r*r+r+1) | 258 | res <- createVector (r*r+r+1) |
239 | c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x] | 259 | c_luRaux // matc x // vec res // check "luRaux" [cdat x] |
240 | return res | 260 | return res |
241 | where r = rows x | 261 | where r = rows x |
242 | c = cols x | 262 | c = cols x |
@@ -245,9 +265,11 @@ foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV | |||
245 | {- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) | 265 | {- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) |
246 | -} | 266 | -} |
247 | luCaux :: Matrix (Complex Double) -> Vector (Complex Double) | 267 | luCaux :: Matrix (Complex Double) -> Vector (Complex Double) |
248 | luCaux x = unsafePerformIO $ do | 268 | luCaux = luCaux' . cmat |
269 | |||
270 | luCaux' x = unsafePerformIO $ do | ||
249 | res <- createVector (r*r+r+1) | 271 | res <- createVector (r*r+r+1) |
250 | c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x] | 272 | c_luCaux // matc x // vec res // check "luCaux" [cdat x] |
251 | return res | 273 | return res |
252 | where r = rows x | 274 | where r = rows x |
253 | c = cols x | 275 | c = cols x |
diff --git a/lib/Numeric/GSL/Minimization.hs b/lib/Numeric/GSL/Minimization.hs index e93f8cb..f523849 100644 --- a/lib/Numeric/GSL/Minimization.hs +++ b/lib/Numeric/GSL/Minimization.hs | |||
@@ -186,12 +186,12 @@ foreign import ccall "wrapper" | |||
186 | 186 | ||
187 | aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO()) | 187 | aux_vTov :: (Vector Double -> Vector Double) -> (Int -> Ptr Double -> Ptr Double -> IO()) |
188 | aux_vTov f n p r = g where | 188 | aux_vTov f n p r = g where |
189 | V {fptr = pr, ptr = t} = f x | 189 | v@V {fptr = pr} = f x |
190 | x = createV n copy "aux_vTov" [] | 190 | x = createV n copy "aux_vTov" [] |
191 | copy n q = do | 191 | copy n q = do |
192 | copyArray q p n | 192 | copyArray q p n |
193 | return 0 | 193 | return 0 |
194 | g = withForeignPtr pr $ \_ -> copyArray r t n | 194 | g = withForeignPtr pr $ \_ -> copyArray r (ptr v) n |
195 | 195 | ||
196 | -------------------------------------------------------------------- | 196 | -------------------------------------------------------------------- |
197 | 197 | ||
@@ -202,10 +202,10 @@ createV n fun msg ptrs = unsafePerformIO $ do | |||
202 | 202 | ||
203 | createM r c fun msg ptrs = unsafePerformIO $ do | 203 | createM r c fun msg ptrs = unsafePerformIO $ do |
204 | r <- createMatrix RowMajor r c | 204 | r <- createMatrix RowMajor r c |
205 | fun // mat cdat r // check msg ptrs | 205 | fun // matc r // check msg ptrs |
206 | return r | 206 | return r |
207 | 207 | ||
208 | createMIO r c fun msg ptrs = do | 208 | createMIO r c fun msg ptrs = do |
209 | r <- createMatrix RowMajor r c | 209 | r <- createMatrix RowMajor r c |
210 | fun // mat cdat r // check msg ptrs | 210 | fun // matc r // check msg ptrs |
211 | return r | 211 | return r |