diff options
Diffstat (limited to 'lib/Numeric/GSL/Matrix.hs')
-rw-r--r-- | lib/Numeric/GSL/Matrix.hs | 52 |
1 files changed, 32 insertions, 20 deletions
diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs index e803c53..09a0be4 100644 --- a/lib/Numeric/GSL/Matrix.hs +++ b/lib/Numeric/GSL/Matrix.hs | |||
@@ -51,7 +51,8 @@ eigSg' m | |||
51 | | otherwise = unsafePerformIO $ do | 51 | | otherwise = unsafePerformIO $ do |
52 | l <- createVector r | 52 | l <- createVector r |
53 | v <- createMatrix RowMajor r r | 53 | v <- createMatrix RowMajor r r |
54 | c_eigS // matc m // vec l // matc v // check "eigSg" [cdat m] | 54 | ww3 withMatrix m withVector l withMatrix v $ \m l v -> |
55 | c_eigS // m // l // v // check "eigSg" | ||
55 | return (l,v) | 56 | return (l,v) |
56 | where r = rows m | 57 | where r = rows m |
57 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM | 58 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM |
@@ -84,7 +85,8 @@ eigHg' m | |||
84 | | otherwise = unsafePerformIO $ do | 85 | | otherwise = unsafePerformIO $ do |
85 | l <- createVector r | 86 | l <- createVector r |
86 | v <- createMatrix RowMajor r r | 87 | v <- createMatrix RowMajor r r |
87 | c_eigH // matc m // vec l // matc v // check "eigHg" [cdat m] | 88 | ww3 withMatrix m withVector l withMatrix v $ \m l v -> |
89 | c_eigH // m // l // v // check "eigHg" | ||
88 | return (l,v) | 90 | return (l,v) |
89 | where r = rows m | 91 | where r = rows m |
90 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM | 92 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM |
@@ -120,7 +122,8 @@ svd' x = unsafePerformIO $ do | |||
120 | u <- createMatrix RowMajor r c | 122 | u <- createMatrix RowMajor r c |
121 | s <- createVector c | 123 | s <- createVector c |
122 | v <- createMatrix RowMajor c c | 124 | v <- createMatrix RowMajor c c |
123 | c_svd // matc x // matc u // vec s // matc v // check "svdg" [cdat x] | 125 | ww4 withMatrix x withMatrix u withVector s withMatrix v $ \x u s v -> |
126 | c_svd // x // u // s // v // check "svdg" | ||
124 | return (u,s,v) | 127 | return (u,s,v) |
125 | where r = rows x | 128 | where r = rows x |
126 | c = cols x | 129 | c = cols x |
@@ -149,7 +152,8 @@ qr = qr' . cmat | |||
149 | qr' x = unsafePerformIO $ do | 152 | qr' x = unsafePerformIO $ do |
150 | q <- createMatrix RowMajor r r | 153 | q <- createMatrix RowMajor r r |
151 | rot <- createMatrix RowMajor r c | 154 | rot <- createMatrix RowMajor r c |
152 | c_qr // matc x // matc q // matc rot // check "qr" [cdat x] | 155 | ww3 withMatrix x withMatrix q withMatrix rot $ \x q rot -> |
156 | c_qr // x // q // rot // check "qr" | ||
153 | return (q,rot) | 157 | return (q,rot) |
154 | where r = rows x | 158 | where r = rows x |
155 | c = cols x | 159 | c = cols x |
@@ -161,7 +165,8 @@ qrPacked = qrPacked' . cmat | |||
161 | qrPacked' x = unsafePerformIO $ do | 165 | qrPacked' x = unsafePerformIO $ do |
162 | qr <- createMatrix RowMajor r c | 166 | qr <- createMatrix RowMajor r c |
163 | tau <- createVector (min r c) | 167 | tau <- createVector (min r c) |
164 | c_qrPacked // matc x // matc qr // vec tau // check "qrUnpacked" [cdat x] | 168 | ww3 withMatrix x withMatrix qr withVector tau $ \x qr tau -> |
169 | c_qrPacked // x // qr // tau // check "qrUnpacked" | ||
165 | return (qr,tau) | 170 | return (qr,tau) |
166 | where r = rows x | 171 | where r = rows x |
167 | c = cols x | 172 | c = cols x |
@@ -172,9 +177,10 @@ unpackQR (qr,tau) = unpackQR' (cmat qr, tau) | |||
172 | 177 | ||
173 | unpackQR' (qr,tau) = unsafePerformIO $ do | 178 | unpackQR' (qr,tau) = unsafePerformIO $ do |
174 | q <- createMatrix RowMajor r r | 179 | q <- createMatrix RowMajor r r |
175 | rot <- createMatrix RowMajor r c | 180 | res <- createMatrix RowMajor r c |
176 | c_qrUnpack // matc qr // vec tau // matc q // matc rot // check "qrUnpack" [cdat qr,tau] | 181 | ww4 withMatrix qr withVector tau withMatrix q withMatrix res $ \qr tau q res -> |
177 | return (q,rot) | 182 | c_qrUnpack // qr // tau // q // res // check "qrUnpack" |
183 | return (q,res) | ||
178 | where r = rows qr | 184 | where r = rows qr |
179 | c = cols qr | 185 | c = cols qr |
180 | foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM | 186 | foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM |
@@ -196,20 +202,22 @@ cholR :: Matrix Double -> Matrix Double | |||
196 | cholR = cholR' . cmat | 202 | cholR = cholR' . cmat |
197 | 203 | ||
198 | cholR' x = unsafePerformIO $ do | 204 | cholR' x = unsafePerformIO $ do |
199 | res <- createMatrix RowMajor r r | 205 | r <- createMatrix RowMajor n n |
200 | c_cholR // matc x // matc res // check "cholR" [cdat x] | 206 | ww2 withMatrix x withMatrix r $ \x r -> |
201 | return res | 207 | c_cholR // x // r // check "cholR" |
202 | where r = rows x | 208 | return r |
209 | where n = rows x | ||
203 | foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM | 210 | foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM |
204 | 211 | ||
205 | cholC :: Matrix (Complex Double) -> Matrix (Complex Double) | 212 | cholC :: Matrix (Complex Double) -> Matrix (Complex Double) |
206 | cholC = cholC' . cmat | 213 | cholC = cholC' . cmat |
207 | 214 | ||
208 | cholC' x = unsafePerformIO $ do | 215 | cholC' x = unsafePerformIO $ do |
209 | res <- createMatrix RowMajor r r | 216 | r <- createMatrix RowMajor n n |
210 | c_cholC // matc x // matc res // check "cholC" [cdat x] | 217 | ww2 withMatrix x withMatrix r $ \x r -> |
211 | return res | 218 | c_cholC // x // r // check "cholC" |
212 | where r = rows x | 219 | return r |
220 | where n = rows x | ||
213 | foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM | 221 | foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM |
214 | 222 | ||
215 | 223 | ||
@@ -223,7 +231,8 @@ luSolveR a b = luSolveR' (cmat a) (cmat b) | |||
223 | luSolveR' a b | 231 | luSolveR' a b |
224 | | n1==n2 && n1==r = unsafePerformIO $ do | 232 | | n1==n2 && n1==r = unsafePerformIO $ do |
225 | s <- createMatrix RowMajor r c | 233 | s <- createMatrix RowMajor r c |
226 | c_luSolveR // matc a // matc b // matc s // check "luSolveR" [cdat a, cdat b] | 234 | ww3 withMatrix a withMatrix b withMatrix s $ \ a b s -> |
235 | c_luSolveR // a // b // s // check "luSolveR" | ||
227 | return s | 236 | return s |
228 | | otherwise = error "luSolveR of nonsquare matrix" | 237 | | otherwise = error "luSolveR of nonsquare matrix" |
229 | where n1 = rows a | 238 | where n1 = rows a |
@@ -240,7 +249,8 @@ luSolveC a b = luSolveC' (cmat a) (cmat b) | |||
240 | luSolveC' a b | 249 | luSolveC' a b |
241 | | n1==n2 && n1==r = unsafePerformIO $ do | 250 | | n1==n2 && n1==r = unsafePerformIO $ do |
242 | s <- createMatrix RowMajor r c | 251 | s <- createMatrix RowMajor r c |
243 | c_luSolveC // matc a // matc b // matc s // check "luSolveC" [cdat a, cdat b] | 252 | ww3 withMatrix a withMatrix b withMatrix s $ \ a b s -> |
253 | c_luSolveC // a // b // s // check "luSolveC" | ||
244 | return s | 254 | return s |
245 | | otherwise = error "luSolveC of nonsquare matrix" | 255 | | otherwise = error "luSolveC of nonsquare matrix" |
246 | where n1 = rows a | 256 | where n1 = rows a |
@@ -256,7 +266,8 @@ luRaux = luRaux' . cmat | |||
256 | 266 | ||
257 | luRaux' x = unsafePerformIO $ do | 267 | luRaux' x = unsafePerformIO $ do |
258 | res <- createVector (r*r+r+1) | 268 | res <- createVector (r*r+r+1) |
259 | c_luRaux // matc x // vec res // check "luRaux" [cdat x] | 269 | ww2 withMatrix x withVector res $ \x res -> |
270 | c_luRaux // x // res // check "luRaux" | ||
260 | return res | 271 | return res |
261 | where r = rows x | 272 | where r = rows x |
262 | c = cols x | 273 | c = cols x |
@@ -269,7 +280,8 @@ luCaux = luCaux' . cmat | |||
269 | 280 | ||
270 | luCaux' x = unsafePerformIO $ do | 281 | luCaux' x = unsafePerformIO $ do |
271 | res <- createVector (r*r+r+1) | 282 | res <- createVector (r*r+r+1) |
272 | c_luCaux // matc x // vec res // check "luCaux" [cdat x] | 283 | ww2 withMatrix x withVector res $ \x res -> |
284 | c_luCaux // x // res // check "luCaux" | ||
273 | return res | 285 | return res |
274 | where r = rows x | 286 | where r = rows x |
275 | c = cols x | 287 | c = cols x |