summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/LAPACK.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r--packages/base/src/Internal/LAPACK.hs54
1 files changed, 32 insertions, 22 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 8df568d..3a9abbb 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -17,7 +17,7 @@ module Internal.LAPACK where
17 17
18import Internal.Devel 18import Internal.Devel
19import Internal.Vector 19import Internal.Vector
20import Internal.Matrix 20import Internal.Matrix hiding ((#))
21import Internal.Conversion 21import Internal.Conversion
22import Internal.Element 22import Internal.Element
23import Foreign.Ptr(nullPtr) 23import Foreign.Ptr(nullPtr)
@@ -27,6 +27,16 @@ import System.IO.Unsafe(unsafePerformIO)
27 27
28----------------------------------------------------------------------------------- 28-----------------------------------------------------------------------------------
29 29
30infixl 1 #
31a # b = applyRaw a b
32{-# INLINE (#) #-}
33
34infixl 1 #!
35a #! b = apply a b
36{-# INLINE (#!) #-}
37
38-----------------------------------------------------------------------------------
39
30type TMMM t = t ..> t ..> t ..> Ok 40type TMMM t = t ..> t ..> t ..> Ok
31 41
32type F = Float 42type F = Float
@@ -49,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do
49 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ 59 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++
50 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 60 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
51 s <- createMatrix ColumnMajor (rows a) (cols b) 61 s <- createMatrix ColumnMajor (rows a) (cols b)
52 app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st 62 f (isT a) (isT b) # (tt a) # (tt b) # s #| st
53 return s 63 return s
54 64
55-- | Matrix product based on BLAS's /dgemm/. 65-- | Matrix product based on BLAS's /dgemm/.
@@ -73,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do
73 when (cols a /= rows b) $ error $ 83 when (cols a /= rows b) $ error $
74 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 84 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
75 s <- createMatrix ColumnMajor (rows a) (cols b) 85 s <- createMatrix ColumnMajor (rows a) (cols b)
76 app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI" 86 c_multiplyI m #! a #! b #! s #|"c_multiplyI"
77 return s 87 return s
78 88
79multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z 89multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z
@@ -81,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do
81 when (cols a /= rows b) $ error $ 91 when (cols a /= rows b) $ error $
82 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 92 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
83 s <- createMatrix ColumnMajor (rows a) (cols b) 93 s <- createMatrix ColumnMajor (rows a) (cols b)
84 app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL" 94 c_multiplyL m #! a #! b #! s #|"c_multiplyL"
85 return s 95 return s
86 96
87----------------------------------------------------------------------------- 97-----------------------------------------------------------------------------
@@ -113,7 +123,7 @@ svdAux f st x = unsafePerformIO $ do
113 u <- createMatrix ColumnMajor r r 123 u <- createMatrix ColumnMajor r r
114 s <- createVector (min r c) 124 s <- createVector (min r c)
115 v <- createMatrix ColumnMajor c c 125 v <- createMatrix ColumnMajor c c
116 app4 f mat x mat u vec s mat v st 126 f # x # u # s # v #| st
117 return (u,s,v) 127 return (u,s,v)
118 where r = rows x 128 where r = rows x
119 c = cols x 129 c = cols x
@@ -139,7 +149,7 @@ thinSVDAux f st x = unsafePerformIO $ do
139 u <- createMatrix ColumnMajor r q 149 u <- createMatrix ColumnMajor r q
140 s <- createVector q 150 s <- createVector q
141 v <- createMatrix ColumnMajor q c 151 v <- createMatrix ColumnMajor q c
142 app4 f mat x mat u vec s mat v st 152 f # x # u # s # v #| st
143 return (u,s,v) 153 return (u,s,v)
144 where r = rows x 154 where r = rows x
145 c = cols x 155 c = cols x
@@ -164,7 +174,7 @@ svCd = svAux zgesdd "svCd" . fmat
164 174
165svAux f st x = unsafePerformIO $ do 175svAux f st x = unsafePerformIO $ do
166 s <- createVector q 176 s <- createVector q
167 app2 g mat x vec s st 177 g # x # s #| st
168 return s 178 return s
169 where r = rows x 179 where r = rows x
170 c = cols x 180 c = cols x
@@ -183,7 +193,7 @@ rightSVC = rightSVAux zgesvd "rightSVC" . fmat
183rightSVAux f st x = unsafePerformIO $ do 193rightSVAux f st x = unsafePerformIO $ do
184 s <- createVector q 194 s <- createVector q
185 v <- createMatrix ColumnMajor c c 195 v <- createMatrix ColumnMajor c c
186 app3 g mat x vec s mat v st 196 g # x # s # v #| st
187 return (s,v) 197 return (s,v)
188 where r = rows x 198 where r = rows x
189 c = cols x 199 c = cols x
@@ -202,7 +212,7 @@ leftSVC = leftSVAux zgesvd "leftSVC" . fmat
202leftSVAux f st x = unsafePerformIO $ do 212leftSVAux f st x = unsafePerformIO $ do
203 u <- createMatrix ColumnMajor r r 213 u <- createMatrix ColumnMajor r r
204 s <- createVector q 214 s <- createVector q
205 app3 g mat x mat u vec s st 215 g # x # u # s #| st
206 return (u,s) 216 return (u,s)
207 where r = rows x 217 where r = rows x
208 c = cols x 218 c = cols x
@@ -219,7 +229,7 @@ foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok
219eigAux f st m = unsafePerformIO $ do 229eigAux f st m = unsafePerformIO $ do
220 l <- createVector r 230 l <- createVector r
221 v <- createMatrix ColumnMajor r r 231 v <- createMatrix ColumnMajor r r
222 app3 g mat m vec l mat v st 232 g # m # l # v #| st
223 return (l,v) 233 return (l,v)
224 where r = rows m 234 where r = rows m
225 g ra ca pa = f ra ca pa 0 0 nullPtr 235 g ra ca pa = f ra ca pa 0 0 nullPtr
@@ -232,7 +242,7 @@ eigC = eigAux zgeev "eigC" . fmat
232 242
233eigOnlyAux f st m = unsafePerformIO $ do 243eigOnlyAux f st m = unsafePerformIO $ do
234 l <- createVector r 244 l <- createVector r
235 app2 g mat m vec l st 245 g # m # l #| st
236 return l 246 return l
237 where r = rows m 247 where r = rows m
238 g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr 248 g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr
@@ -255,7 +265,7 @@ eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double)
255eigRaux m = unsafePerformIO $ do 265eigRaux m = unsafePerformIO $ do
256 l <- createVector r 266 l <- createVector r
257 v <- createMatrix ColumnMajor r r 267 v <- createMatrix ColumnMajor r r
258 app3 g mat m vec l mat v "eigR" 268 g # m # l # v #| "eigR"
259 return (l,v) 269 return (l,v)
260 where r = rows m 270 where r = rows m
261 g ra ca pa = dgeev ra ca pa 0 0 nullPtr 271 g ra ca pa = dgeev ra ca pa 0 0 nullPtr
@@ -282,7 +292,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat
282eigSHAux f st m = unsafePerformIO $ do 292eigSHAux f st m = unsafePerformIO $ do
283 l <- createVector r 293 l <- createVector r
284 v <- createMatrix ColumnMajor r r 294 v <- createMatrix ColumnMajor r r
285 app3 f mat m vec l mat v st 295 f # m # l # v #| st
286 return (l,v) 296 return (l,v)
287 where r = rows m 297 where r = rows m
288 298
@@ -332,7 +342,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: TMMM C
332linearSolveSQAux g f st a b 342linearSolveSQAux g f st a b
333 | n1==n2 && n1==r = unsafePerformIO . g $ do 343 | n1==n2 && n1==r = unsafePerformIO . g $ do
334 s <- createMatrix ColumnMajor r c 344 s <- createMatrix ColumnMajor r c
335 app3 f mat a mat b mat s st 345 f # a # b # s #| st
336 return s 346 return s
337 | otherwise = error $ st ++ " of nonsquare matrix" 347 | otherwise = error $ st ++ " of nonsquare matrix"
338 where n1 = rows a 348 where n1 = rows a
@@ -371,7 +381,7 @@ foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C
371 381
372linearSolveAux f st a b = unsafePerformIO $ do 382linearSolveAux f st a b = unsafePerformIO $ do
373 r <- createMatrix ColumnMajor (max m n) nrhs 383 r <- createMatrix ColumnMajor (max m n) nrhs
374 app3 f mat a mat b mat r st 384 f # a # b # r #| st
375 return r 385 return r
376 where m = rows a 386 where m = rows a
377 n = cols a 387 n = cols a
@@ -412,7 +422,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R
412 422
413cholAux f st a = do 423cholAux f st a = do
414 r <- createMatrix ColumnMajor n n 424 r <- createMatrix ColumnMajor n n
415 app2 f mat a mat r st 425 f # a # r #| st
416 return r 426 return r
417 where n = rows a 427 where n = rows a
418 428
@@ -450,7 +460,7 @@ qrC = qrAux zgeqr2 "qrC" . fmat
450qrAux f st a = unsafePerformIO $ do 460qrAux f st a = unsafePerformIO $ do
451 r <- createMatrix ColumnMajor m n 461 r <- createMatrix ColumnMajor m n
452 tau <- createVector mn 462 tau <- createVector mn
453 app3 f mat a vec tau mat r st 463 f # a # tau # r #| st
454 return (r,tau) 464 return (r,tau)
455 where 465 where
456 m = rows a 466 m = rows a
@@ -469,7 +479,7 @@ qrgrC = qrgrAux zungqr "qrgrC"
469 479
470qrgrAux f st n (a, tau) = unsafePerformIO $ do 480qrgrAux f st n (a, tau) = unsafePerformIO $ do
471 res <- createMatrix ColumnMajor (rows a) n 481 res <- createMatrix ColumnMajor (rows a) n
472 app3 f mat (fmat a) vec (subVector 0 n tau') mat res st 482 f # (fmat a) # (subVector 0 n tau') # res #| st
473 return res 483 return res
474 where 484 where
475 tau' = vjoin [tau, constantD 0 n] 485 tau' = vjoin [tau, constantD 0 n]
@@ -489,7 +499,7 @@ hessC = hessAux zgehrd "hessC" . fmat
489hessAux f st a = unsafePerformIO $ do 499hessAux f st a = unsafePerformIO $ do
490 r <- createMatrix ColumnMajor m n 500 r <- createMatrix ColumnMajor m n
491 tau <- createVector (mn-1) 501 tau <- createVector (mn-1)
492 app3 f mat a vec tau mat r st 502 f # a # tau # r #| st
493 return (r,tau) 503 return (r,tau)
494 where m = rows a 504 where m = rows a
495 n = cols a 505 n = cols a
@@ -510,7 +520,7 @@ schurC = schurAux zgees "schurC" . fmat
510schurAux f st a = unsafePerformIO $ do 520schurAux f st a = unsafePerformIO $ do
511 u <- createMatrix ColumnMajor n n 521 u <- createMatrix ColumnMajor n n
512 s <- createMatrix ColumnMajor n n 522 s <- createMatrix ColumnMajor n n
513 app3 f mat a mat u mat s st 523 f # a # u # s #| st
514 return (u,s) 524 return (u,s)
515 where n = rows a 525 where n = rows a
516 526
@@ -529,7 +539,7 @@ luC = luAux zgetrf "luC" . fmat
529luAux f st a = unsafePerformIO $ do 539luAux f st a = unsafePerformIO $ do
530 lu <- createMatrix ColumnMajor n m 540 lu <- createMatrix ColumnMajor n m
531 piv <- createVector (min n m) 541 piv <- createVector (min n m)
532 app3 f mat a vec piv mat lu st 542 f # a # piv # lu #| st
533 return (lu, map (pred.round) (toList piv)) 543 return (lu, map (pred.round) (toList piv))
534 where n = rows a 544 where n = rows a
535 m = cols a 545 m = cols a
@@ -552,7 +562,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b)
552lusAux f st a piv b 562lusAux f st a piv b
553 | n1==n2 && n2==n =unsafePerformIO $ do 563 | n1==n2 && n2==n =unsafePerformIO $ do
554 x <- createMatrix ColumnMajor n m 564 x <- createMatrix ColumnMajor n m
555 app4 f mat a vec piv' mat b mat x st 565 f # a # piv' # b # x #| st
556 return x 566 return x
557 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" 567 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix"
558 where n1 = rows a 568 where n1 = rows a