summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/LAPACK.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-19 13:55:39 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-19 13:55:39 +0200
commitdb50bc11dafa6834a4367427156306674063ed6b (patch)
tree721e9d0235168be1d0ebb2bd1dd254a66251f274 /packages/base/src/Internal/LAPACK.hs
parent7f9c7b5adf8f05653d15f19358f41c1916e8db70 (diff)
removed the annoying appN adapter for the foreign functions.
replaced by several overloaded app variants in the style of the module Internal.Foreign contributed by Mike Ledger.
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r--packages/base/src/Internal/LAPACK.hs54
1 files changed, 32 insertions, 22 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 8df568d..3a9abbb 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -17,7 +17,7 @@ module Internal.LAPACK where
17 17
18import Internal.Devel 18import Internal.Devel
19import Internal.Vector 19import Internal.Vector
20import Internal.Matrix 20import Internal.Matrix hiding ((#))
21import Internal.Conversion 21import Internal.Conversion
22import Internal.Element 22import Internal.Element
23import Foreign.Ptr(nullPtr) 23import Foreign.Ptr(nullPtr)
@@ -27,6 +27,16 @@ import System.IO.Unsafe(unsafePerformIO)
27 27
28----------------------------------------------------------------------------------- 28-----------------------------------------------------------------------------------
29 29
30infixl 1 #
31a # b = applyRaw a b
32{-# INLINE (#) #-}
33
34infixl 1 #!
35a #! b = apply a b
36{-# INLINE (#!) #-}
37
38-----------------------------------------------------------------------------------
39
30type TMMM t = t ..> t ..> t ..> Ok 40type TMMM t = t ..> t ..> t ..> Ok
31 41
32type F = Float 42type F = Float
@@ -49,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do
49 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ 59 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++
50 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 60 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
51 s <- createMatrix ColumnMajor (rows a) (cols b) 61 s <- createMatrix ColumnMajor (rows a) (cols b)
52 app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st 62 f (isT a) (isT b) # (tt a) # (tt b) # s #| st
53 return s 63 return s
54 64
55-- | Matrix product based on BLAS's /dgemm/. 65-- | Matrix product based on BLAS's /dgemm/.
@@ -73,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do
73 when (cols a /= rows b) $ error $ 83 when (cols a /= rows b) $ error $
74 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 84 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
75 s <- createMatrix ColumnMajor (rows a) (cols b) 85 s <- createMatrix ColumnMajor (rows a) (cols b)
76 app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI" 86 c_multiplyI m #! a #! b #! s #|"c_multiplyI"
77 return s 87 return s
78 88
79multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z 89multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z
@@ -81,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do
81 when (cols a /= rows b) $ error $ 91 when (cols a /= rows b) $ error $
82 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 92 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
83 s <- createMatrix ColumnMajor (rows a) (cols b) 93 s <- createMatrix ColumnMajor (rows a) (cols b)
84 app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL" 94 c_multiplyL m #! a #! b #! s #|"c_multiplyL"
85 return s 95 return s
86 96
87----------------------------------------------------------------------------- 97-----------------------------------------------------------------------------
@@ -113,7 +123,7 @@ svdAux f st x = unsafePerformIO $ do
113 u <- createMatrix ColumnMajor r r 123 u <- createMatrix ColumnMajor r r
114 s <- createVector (min r c) 124 s <- createVector (min r c)
115 v <- createMatrix ColumnMajor c c 125 v <- createMatrix ColumnMajor c c
116 app4 f mat x mat u vec s mat v st 126 f # x # u # s # v #| st
117 return (u,s,v) 127 return (u,s,v)
118 where r = rows x 128 where r = rows x
119 c = cols x 129 c = cols x
@@ -139,7 +149,7 @@ thinSVDAux f st x = unsafePerformIO $ do
139 u <- createMatrix ColumnMajor r q 149 u <- createMatrix ColumnMajor r q
140 s <- createVector q 150 s <- createVector q
141 v <- createMatrix ColumnMajor q c 151 v <- createMatrix ColumnMajor q c
142 app4 f mat x mat u vec s mat v st 152 f # x # u # s # v #| st
143 return (u,s,v) 153 return (u,s,v)
144 where r = rows x 154 where r = rows x
145 c = cols x 155 c = cols x
@@ -164,7 +174,7 @@ svCd = svAux zgesdd "svCd" . fmat
164 174
165svAux f st x = unsafePerformIO $ do 175svAux f st x = unsafePerformIO $ do
166 s <- createVector q 176 s <- createVector q
167 app2 g mat x vec s st 177 g # x # s #| st
168 return s 178 return s
169 where r = rows x 179 where r = rows x
170 c = cols x 180 c = cols x
@@ -183,7 +193,7 @@ rightSVC = rightSVAux zgesvd "rightSVC" . fmat
183rightSVAux f st x = unsafePerformIO $ do 193rightSVAux f st x = unsafePerformIO $ do
184 s <- createVector q 194 s <- createVector q
185 v <- createMatrix ColumnMajor c c 195 v <- createMatrix ColumnMajor c c
186 app3 g mat x vec s mat v st 196 g # x # s # v #| st
187 return (s,v) 197 return (s,v)
188 where r = rows x 198 where r = rows x
189 c = cols x 199 c = cols x
@@ -202,7 +212,7 @@ leftSVC = leftSVAux zgesvd "leftSVC" . fmat
202leftSVAux f st x = unsafePerformIO $ do 212leftSVAux f st x = unsafePerformIO $ do
203 u <- createMatrix ColumnMajor r r 213 u <- createMatrix ColumnMajor r r
204 s <- createVector q 214 s <- createVector q
205 app3 g mat x mat u vec s st 215 g # x # u # s #| st
206 return (u,s) 216 return (u,s)
207 where r = rows x 217 where r = rows x
208 c = cols x 218 c = cols x
@@ -219,7 +229,7 @@ foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok
219eigAux f st m = unsafePerformIO $ do 229eigAux f st m = unsafePerformIO $ do
220 l <- createVector r 230 l <- createVector r
221 v <- createMatrix ColumnMajor r r 231 v <- createMatrix ColumnMajor r r
222 app3 g mat m vec l mat v st 232 g # m # l # v #| st
223 return (l,v) 233 return (l,v)
224 where r = rows m 234 where r = rows m
225 g ra ca pa = f ra ca pa 0 0 nullPtr 235 g ra ca pa = f ra ca pa 0 0 nullPtr
@@ -232,7 +242,7 @@ eigC = eigAux zgeev "eigC" . fmat
232 242
233eigOnlyAux f st m = unsafePerformIO $ do 243eigOnlyAux f st m = unsafePerformIO $ do
234 l <- createVector r 244 l <- createVector r
235 app2 g mat m vec l st 245 g # m # l #| st
236 return l 246 return l
237 where r = rows m 247 where r = rows m
238 g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr 248 g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr
@@ -255,7 +265,7 @@ eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double)
255eigRaux m = unsafePerformIO $ do 265eigRaux m = unsafePerformIO $ do
256 l <- createVector r 266 l <- createVector r
257 v <- createMatrix ColumnMajor r r 267 v <- createMatrix ColumnMajor r r
258 app3 g mat m vec l mat v "eigR" 268 g # m # l # v #| "eigR"
259 return (l,v) 269 return (l,v)
260 where r = rows m 270 where r = rows m
261 g ra ca pa = dgeev ra ca pa 0 0 nullPtr 271 g ra ca pa = dgeev ra ca pa 0 0 nullPtr
@@ -282,7 +292,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat
282eigSHAux f st m = unsafePerformIO $ do 292eigSHAux f st m = unsafePerformIO $ do
283 l <- createVector r 293 l <- createVector r
284 v <- createMatrix ColumnMajor r r 294 v <- createMatrix ColumnMajor r r
285 app3 f mat m vec l mat v st 295 f # m # l # v #| st
286 return (l,v) 296 return (l,v)
287 where r = rows m 297 where r = rows m
288 298
@@ -332,7 +342,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: TMMM C
332linearSolveSQAux g f st a b 342linearSolveSQAux g f st a b
333 | n1==n2 && n1==r = unsafePerformIO . g $ do 343 | n1==n2 && n1==r = unsafePerformIO . g $ do
334 s <- createMatrix ColumnMajor r c 344 s <- createMatrix ColumnMajor r c
335 app3 f mat a mat b mat s st 345 f # a # b # s #| st
336 return s 346 return s
337 | otherwise = error $ st ++ " of nonsquare matrix" 347 | otherwise = error $ st ++ " of nonsquare matrix"
338 where n1 = rows a 348 where n1 = rows a
@@ -371,7 +381,7 @@ foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C
371 381
372linearSolveAux f st a b = unsafePerformIO $ do 382linearSolveAux f st a b = unsafePerformIO $ do
373 r <- createMatrix ColumnMajor (max m n) nrhs 383 r <- createMatrix ColumnMajor (max m n) nrhs
374 app3 f mat a mat b mat r st 384 f # a # b # r #| st
375 return r 385 return r
376 where m = rows a 386 where m = rows a
377 n = cols a 387 n = cols a
@@ -412,7 +422,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R
412 422
413cholAux f st a = do 423cholAux f st a = do
414 r <- createMatrix ColumnMajor n n 424 r <- createMatrix ColumnMajor n n
415 app2 f mat a mat r st 425 f # a # r #| st
416 return r 426 return r
417 where n = rows a 427 where n = rows a
418 428
@@ -450,7 +460,7 @@ qrC = qrAux zgeqr2 "qrC" . fmat
450qrAux f st a = unsafePerformIO $ do 460qrAux f st a = unsafePerformIO $ do
451 r <- createMatrix ColumnMajor m n 461 r <- createMatrix ColumnMajor m n
452 tau <- createVector mn 462 tau <- createVector mn
453 app3 f mat a vec tau mat r st 463 f # a # tau # r #| st
454 return (r,tau) 464 return (r,tau)
455 where 465 where
456 m = rows a 466 m = rows a
@@ -469,7 +479,7 @@ qrgrC = qrgrAux zungqr "qrgrC"
469 479
470qrgrAux f st n (a, tau) = unsafePerformIO $ do 480qrgrAux f st n (a, tau) = unsafePerformIO $ do
471 res <- createMatrix ColumnMajor (rows a) n 481 res <- createMatrix ColumnMajor (rows a) n
472 app3 f mat (fmat a) vec (subVector 0 n tau') mat res st 482 f # (fmat a) # (subVector 0 n tau') # res #| st
473 return res 483 return res
474 where 484 where
475 tau' = vjoin [tau, constantD 0 n] 485 tau' = vjoin [tau, constantD 0 n]
@@ -489,7 +499,7 @@ hessC = hessAux zgehrd "hessC" . fmat
489hessAux f st a = unsafePerformIO $ do 499hessAux f st a = unsafePerformIO $ do
490 r <- createMatrix ColumnMajor m n 500 r <- createMatrix ColumnMajor m n
491 tau <- createVector (mn-1) 501 tau <- createVector (mn-1)
492 app3 f mat a vec tau mat r st 502 f # a # tau # r #| st
493 return (r,tau) 503 return (r,tau)
494 where m = rows a 504 where m = rows a
495 n = cols a 505 n = cols a
@@ -510,7 +520,7 @@ schurC = schurAux zgees "schurC" . fmat
510schurAux f st a = unsafePerformIO $ do 520schurAux f st a = unsafePerformIO $ do
511 u <- createMatrix ColumnMajor n n 521 u <- createMatrix ColumnMajor n n
512 s <- createMatrix ColumnMajor n n 522 s <- createMatrix ColumnMajor n n
513 app3 f mat a mat u mat s st 523 f # a # u # s #| st
514 return (u,s) 524 return (u,s)
515 where n = rows a 525 where n = rows a
516 526
@@ -529,7 +539,7 @@ luC = luAux zgetrf "luC" . fmat
529luAux f st a = unsafePerformIO $ do 539luAux f st a = unsafePerformIO $ do
530 lu <- createMatrix ColumnMajor n m 540 lu <- createMatrix ColumnMajor n m
531 piv <- createVector (min n m) 541 piv <- createVector (min n m)
532 app3 f mat a vec piv mat lu st 542 f # a # piv # lu #| st
533 return (lu, map (pred.round) (toList piv)) 543 return (lu, map (pred.round) (toList piv))
534 where n = rows a 544 where n = rows a
535 m = cols a 545 m = cols a
@@ -552,7 +562,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b)
552lusAux f st a piv b 562lusAux f st a piv b
553 | n1==n2 && n2==n =unsafePerformIO $ do 563 | n1==n2 && n2==n =unsafePerformIO $ do
554 x <- createMatrix ColumnMajor n m 564 x <- createMatrix ColumnMajor n m
555 app4 f mat a vec piv' mat b mat x st 565 f # a # piv' # b # x #| st
556 return x 566 return x
557 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" 567 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix"
558 where n1 = rows a 568 where n1 = rows a