diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-19 13:55:39 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-19 13:55:39 +0200 |
commit | db50bc11dafa6834a4367427156306674063ed6b (patch) | |
tree | 721e9d0235168be1d0ebb2bd1dd254a66251f274 /packages/base/src/Internal/LAPACK.hs | |
parent | 7f9c7b5adf8f05653d15f19358f41c1916e8db70 (diff) |
removed the annoying appN adapter for the foreign functions.
replaced by several overloaded app variants in the style of
the module Internal.Foreign contributed by Mike Ledger.
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 54 |
1 files changed, 32 insertions, 22 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 8df568d..3a9abbb 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -17,7 +17,7 @@ module Internal.LAPACK where | |||
17 | 17 | ||
18 | import Internal.Devel | 18 | import Internal.Devel |
19 | import Internal.Vector | 19 | import Internal.Vector |
20 | import Internal.Matrix | 20 | import Internal.Matrix hiding ((#)) |
21 | import Internal.Conversion | 21 | import Internal.Conversion |
22 | import Internal.Element | 22 | import Internal.Element |
23 | import Foreign.Ptr(nullPtr) | 23 | import Foreign.Ptr(nullPtr) |
@@ -27,6 +27,16 @@ import System.IO.Unsafe(unsafePerformIO) | |||
27 | 27 | ||
28 | ----------------------------------------------------------------------------------- | 28 | ----------------------------------------------------------------------------------- |
29 | 29 | ||
30 | infixl 1 # | ||
31 | a # b = applyRaw a b | ||
32 | {-# INLINE (#) #-} | ||
33 | |||
34 | infixl 1 #! | ||
35 | a #! b = apply a b | ||
36 | {-# INLINE (#!) #-} | ||
37 | |||
38 | ----------------------------------------------------------------------------------- | ||
39 | |||
30 | type TMMM t = t ..> t ..> t ..> Ok | 40 | type TMMM t = t ..> t ..> t ..> Ok |
31 | 41 | ||
32 | type F = Float | 42 | type F = Float |
@@ -49,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do | |||
49 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ | 59 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ |
50 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 60 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
51 | s <- createMatrix ColumnMajor (rows a) (cols b) | 61 | s <- createMatrix ColumnMajor (rows a) (cols b) |
52 | app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st | 62 | f (isT a) (isT b) # (tt a) # (tt b) # s #| st |
53 | return s | 63 | return s |
54 | 64 | ||
55 | -- | Matrix product based on BLAS's /dgemm/. | 65 | -- | Matrix product based on BLAS's /dgemm/. |
@@ -73,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do | |||
73 | when (cols a /= rows b) $ error $ | 83 | when (cols a /= rows b) $ error $ |
74 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 84 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
75 | s <- createMatrix ColumnMajor (rows a) (cols b) | 85 | s <- createMatrix ColumnMajor (rows a) (cols b) |
76 | app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI" | 86 | c_multiplyI m #! a #! b #! s #|"c_multiplyI" |
77 | return s | 87 | return s |
78 | 88 | ||
79 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z | 89 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z |
@@ -81,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do | |||
81 | when (cols a /= rows b) $ error $ | 91 | when (cols a /= rows b) $ error $ |
82 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 92 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
83 | s <- createMatrix ColumnMajor (rows a) (cols b) | 93 | s <- createMatrix ColumnMajor (rows a) (cols b) |
84 | app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL" | 94 | c_multiplyL m #! a #! b #! s #|"c_multiplyL" |
85 | return s | 95 | return s |
86 | 96 | ||
87 | ----------------------------------------------------------------------------- | 97 | ----------------------------------------------------------------------------- |
@@ -113,7 +123,7 @@ svdAux f st x = unsafePerformIO $ do | |||
113 | u <- createMatrix ColumnMajor r r | 123 | u <- createMatrix ColumnMajor r r |
114 | s <- createVector (min r c) | 124 | s <- createVector (min r c) |
115 | v <- createMatrix ColumnMajor c c | 125 | v <- createMatrix ColumnMajor c c |
116 | app4 f mat x mat u vec s mat v st | 126 | f # x # u # s # v #| st |
117 | return (u,s,v) | 127 | return (u,s,v) |
118 | where r = rows x | 128 | where r = rows x |
119 | c = cols x | 129 | c = cols x |
@@ -139,7 +149,7 @@ thinSVDAux f st x = unsafePerformIO $ do | |||
139 | u <- createMatrix ColumnMajor r q | 149 | u <- createMatrix ColumnMajor r q |
140 | s <- createVector q | 150 | s <- createVector q |
141 | v <- createMatrix ColumnMajor q c | 151 | v <- createMatrix ColumnMajor q c |
142 | app4 f mat x mat u vec s mat v st | 152 | f # x # u # s # v #| st |
143 | return (u,s,v) | 153 | return (u,s,v) |
144 | where r = rows x | 154 | where r = rows x |
145 | c = cols x | 155 | c = cols x |
@@ -164,7 +174,7 @@ svCd = svAux zgesdd "svCd" . fmat | |||
164 | 174 | ||
165 | svAux f st x = unsafePerformIO $ do | 175 | svAux f st x = unsafePerformIO $ do |
166 | s <- createVector q | 176 | s <- createVector q |
167 | app2 g mat x vec s st | 177 | g # x # s #| st |
168 | return s | 178 | return s |
169 | where r = rows x | 179 | where r = rows x |
170 | c = cols x | 180 | c = cols x |
@@ -183,7 +193,7 @@ rightSVC = rightSVAux zgesvd "rightSVC" . fmat | |||
183 | rightSVAux f st x = unsafePerformIO $ do | 193 | rightSVAux f st x = unsafePerformIO $ do |
184 | s <- createVector q | 194 | s <- createVector q |
185 | v <- createMatrix ColumnMajor c c | 195 | v <- createMatrix ColumnMajor c c |
186 | app3 g mat x vec s mat v st | 196 | g # x # s # v #| st |
187 | return (s,v) | 197 | return (s,v) |
188 | where r = rows x | 198 | where r = rows x |
189 | c = cols x | 199 | c = cols x |
@@ -202,7 +212,7 @@ leftSVC = leftSVAux zgesvd "leftSVC" . fmat | |||
202 | leftSVAux f st x = unsafePerformIO $ do | 212 | leftSVAux f st x = unsafePerformIO $ do |
203 | u <- createMatrix ColumnMajor r r | 213 | u <- createMatrix ColumnMajor r r |
204 | s <- createVector q | 214 | s <- createVector q |
205 | app3 g mat x mat u vec s st | 215 | g # x # u # s #| st |
206 | return (u,s) | 216 | return (u,s) |
207 | where r = rows x | 217 | where r = rows x |
208 | c = cols x | 218 | c = cols x |
@@ -219,7 +229,7 @@ foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok | |||
219 | eigAux f st m = unsafePerformIO $ do | 229 | eigAux f st m = unsafePerformIO $ do |
220 | l <- createVector r | 230 | l <- createVector r |
221 | v <- createMatrix ColumnMajor r r | 231 | v <- createMatrix ColumnMajor r r |
222 | app3 g mat m vec l mat v st | 232 | g # m # l # v #| st |
223 | return (l,v) | 233 | return (l,v) |
224 | where r = rows m | 234 | where r = rows m |
225 | g ra ca pa = f ra ca pa 0 0 nullPtr | 235 | g ra ca pa = f ra ca pa 0 0 nullPtr |
@@ -232,7 +242,7 @@ eigC = eigAux zgeev "eigC" . fmat | |||
232 | 242 | ||
233 | eigOnlyAux f st m = unsafePerformIO $ do | 243 | eigOnlyAux f st m = unsafePerformIO $ do |
234 | l <- createVector r | 244 | l <- createVector r |
235 | app2 g mat m vec l st | 245 | g # m # l #| st |
236 | return l | 246 | return l |
237 | where r = rows m | 247 | where r = rows m |
238 | g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr | 248 | g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr |
@@ -255,7 +265,7 @@ eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) | |||
255 | eigRaux m = unsafePerformIO $ do | 265 | eigRaux m = unsafePerformIO $ do |
256 | l <- createVector r | 266 | l <- createVector r |
257 | v <- createMatrix ColumnMajor r r | 267 | v <- createMatrix ColumnMajor r r |
258 | app3 g mat m vec l mat v "eigR" | 268 | g # m # l # v #| "eigR" |
259 | return (l,v) | 269 | return (l,v) |
260 | where r = rows m | 270 | where r = rows m |
261 | g ra ca pa = dgeev ra ca pa 0 0 nullPtr | 271 | g ra ca pa = dgeev ra ca pa 0 0 nullPtr |
@@ -282,7 +292,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat | |||
282 | eigSHAux f st m = unsafePerformIO $ do | 292 | eigSHAux f st m = unsafePerformIO $ do |
283 | l <- createVector r | 293 | l <- createVector r |
284 | v <- createMatrix ColumnMajor r r | 294 | v <- createMatrix ColumnMajor r r |
285 | app3 f mat m vec l mat v st | 295 | f # m # l # v #| st |
286 | return (l,v) | 296 | return (l,v) |
287 | where r = rows m | 297 | where r = rows m |
288 | 298 | ||
@@ -332,7 +342,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: TMMM C | |||
332 | linearSolveSQAux g f st a b | 342 | linearSolveSQAux g f st a b |
333 | | n1==n2 && n1==r = unsafePerformIO . g $ do | 343 | | n1==n2 && n1==r = unsafePerformIO . g $ do |
334 | s <- createMatrix ColumnMajor r c | 344 | s <- createMatrix ColumnMajor r c |
335 | app3 f mat a mat b mat s st | 345 | f # a # b # s #| st |
336 | return s | 346 | return s |
337 | | otherwise = error $ st ++ " of nonsquare matrix" | 347 | | otherwise = error $ st ++ " of nonsquare matrix" |
338 | where n1 = rows a | 348 | where n1 = rows a |
@@ -371,7 +381,7 @@ foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C | |||
371 | 381 | ||
372 | linearSolveAux f st a b = unsafePerformIO $ do | 382 | linearSolveAux f st a b = unsafePerformIO $ do |
373 | r <- createMatrix ColumnMajor (max m n) nrhs | 383 | r <- createMatrix ColumnMajor (max m n) nrhs |
374 | app3 f mat a mat b mat r st | 384 | f # a # b # r #| st |
375 | return r | 385 | return r |
376 | where m = rows a | 386 | where m = rows a |
377 | n = cols a | 387 | n = cols a |
@@ -412,7 +422,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R | |||
412 | 422 | ||
413 | cholAux f st a = do | 423 | cholAux f st a = do |
414 | r <- createMatrix ColumnMajor n n | 424 | r <- createMatrix ColumnMajor n n |
415 | app2 f mat a mat r st | 425 | f # a # r #| st |
416 | return r | 426 | return r |
417 | where n = rows a | 427 | where n = rows a |
418 | 428 | ||
@@ -450,7 +460,7 @@ qrC = qrAux zgeqr2 "qrC" . fmat | |||
450 | qrAux f st a = unsafePerformIO $ do | 460 | qrAux f st a = unsafePerformIO $ do |
451 | r <- createMatrix ColumnMajor m n | 461 | r <- createMatrix ColumnMajor m n |
452 | tau <- createVector mn | 462 | tau <- createVector mn |
453 | app3 f mat a vec tau mat r st | 463 | f # a # tau # r #| st |
454 | return (r,tau) | 464 | return (r,tau) |
455 | where | 465 | where |
456 | m = rows a | 466 | m = rows a |
@@ -469,7 +479,7 @@ qrgrC = qrgrAux zungqr "qrgrC" | |||
469 | 479 | ||
470 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 480 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
471 | res <- createMatrix ColumnMajor (rows a) n | 481 | res <- createMatrix ColumnMajor (rows a) n |
472 | app3 f mat (fmat a) vec (subVector 0 n tau') mat res st | 482 | f # (fmat a) # (subVector 0 n tau') # res #| st |
473 | return res | 483 | return res |
474 | where | 484 | where |
475 | tau' = vjoin [tau, constantD 0 n] | 485 | tau' = vjoin [tau, constantD 0 n] |
@@ -489,7 +499,7 @@ hessC = hessAux zgehrd "hessC" . fmat | |||
489 | hessAux f st a = unsafePerformIO $ do | 499 | hessAux f st a = unsafePerformIO $ do |
490 | r <- createMatrix ColumnMajor m n | 500 | r <- createMatrix ColumnMajor m n |
491 | tau <- createVector (mn-1) | 501 | tau <- createVector (mn-1) |
492 | app3 f mat a vec tau mat r st | 502 | f # a # tau # r #| st |
493 | return (r,tau) | 503 | return (r,tau) |
494 | where m = rows a | 504 | where m = rows a |
495 | n = cols a | 505 | n = cols a |
@@ -510,7 +520,7 @@ schurC = schurAux zgees "schurC" . fmat | |||
510 | schurAux f st a = unsafePerformIO $ do | 520 | schurAux f st a = unsafePerformIO $ do |
511 | u <- createMatrix ColumnMajor n n | 521 | u <- createMatrix ColumnMajor n n |
512 | s <- createMatrix ColumnMajor n n | 522 | s <- createMatrix ColumnMajor n n |
513 | app3 f mat a mat u mat s st | 523 | f # a # u # s #| st |
514 | return (u,s) | 524 | return (u,s) |
515 | where n = rows a | 525 | where n = rows a |
516 | 526 | ||
@@ -529,7 +539,7 @@ luC = luAux zgetrf "luC" . fmat | |||
529 | luAux f st a = unsafePerformIO $ do | 539 | luAux f st a = unsafePerformIO $ do |
530 | lu <- createMatrix ColumnMajor n m | 540 | lu <- createMatrix ColumnMajor n m |
531 | piv <- createVector (min n m) | 541 | piv <- createVector (min n m) |
532 | app3 f mat a vec piv mat lu st | 542 | f # a # piv # lu #| st |
533 | return (lu, map (pred.round) (toList piv)) | 543 | return (lu, map (pred.round) (toList piv)) |
534 | where n = rows a | 544 | where n = rows a |
535 | m = cols a | 545 | m = cols a |
@@ -552,7 +562,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) | |||
552 | lusAux f st a piv b | 562 | lusAux f st a piv b |
553 | | n1==n2 && n2==n =unsafePerformIO $ do | 563 | | n1==n2 && n2==n =unsafePerformIO $ do |
554 | x <- createMatrix ColumnMajor n m | 564 | x <- createMatrix ColumnMajor n m |
555 | app4 f mat a vec piv' mat b mat x st | 565 | f # a # piv' # b # x #| st |
556 | return x | 566 | return x |
557 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" | 567 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" |
558 | where n1 = rows a | 568 | where n1 = rows a |