summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/LAPACK.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-27 09:15:27 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-27 09:15:27 +0200
commit4d96b90c4cfd38cdb51f3dc66a8a644bd87cdbff (patch)
treed7b82283f08e5947b06fdec4f403a5bc87c09f35 /packages/base/src/Internal/LAPACK.hs
parent624046d6b55d37104f950e8888ab68c53a2e6bf0 (diff)
use slice interface for lapack funcs (wip)
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r--packages/base/src/Internal/LAPACK.hs165
1 files changed, 89 insertions, 76 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 5319e95..2c7148b 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -29,16 +29,12 @@ import System.IO.Unsafe(unsafePerformIO)
29----------------------------------------------------------------------------------- 29-----------------------------------------------------------------------------------
30 30
31infixl 1 # 31infixl 1 #
32a # b = applyRaw a b 32a # b = apply a b
33{-# INLINE (#) #-} 33{-# INLINE (#) #-}
34 34
35infixl 1 #!
36a #! b = apply a b
37{-# INLINE (#!) #-}
38
39----------------------------------------------------------------------------------- 35-----------------------------------------------------------------------------------
40 36
41type TMMM t = t ..> t ..> t ..> Ok 37type TMMM t = t ::> t ::> t ::> Ok
42 38
43type F = Float 39type F = Float
44type Q = Complex Float 40type Q = Complex Float
@@ -47,8 +43,8 @@ foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R
47foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C 43foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C
48foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F 44foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F
49foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q 45foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q
50foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok 46foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I
51foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok 47foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z
52 48
53isT (rowOrder -> False) = 0 49isT (rowOrder -> False) = 0
54isT _ = 1 50isT _ = 1
@@ -84,7 +80,7 @@ multiplyI m a b = unsafePerformIO $ do
84 when (cols a /= rows b) $ error $ 80 when (cols a /= rows b) $ error $
85 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 81 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
86 s <- createMatrix ColumnMajor (rows a) (cols b) 82 s <- createMatrix ColumnMajor (rows a) (cols b)
87 c_multiplyI m #! a #! b #! s #|"c_multiplyI" 83 c_multiplyI m # a # b # s #|"c_multiplyI"
88 return s 84 return s
89 85
90multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z 86multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z
@@ -92,12 +88,12 @@ multiplyL m a b = unsafePerformIO $ do
92 when (cols a /= rows b) $ error $ 88 when (cols a /= rows b) $ error $
93 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 89 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
94 s <- createMatrix ColumnMajor (rows a) (cols b) 90 s <- createMatrix ColumnMajor (rows a) (cols b)
95 c_multiplyL m #! a #! b #! s #|"c_multiplyL" 91 c_multiplyL m # a # b # s #|"c_multiplyL"
96 return s 92 return s
97 93
98----------------------------------------------------------------------------- 94-----------------------------------------------------------------------------
99 95
100type TSVD t = t ..> t ..> R :> t ..> Ok 96type TSVD t = t ::> t ::> R :> t ::> Ok
101 97
102foreign import ccall unsafe "svd_l_R" dgesvd :: TSVD R 98foreign import ccall unsafe "svd_l_R" dgesvd :: TSVD R
103foreign import ccall unsafe "svd_l_C" zgesvd :: TSVD C 99foreign import ccall unsafe "svd_l_C" zgesvd :: TSVD C
@@ -126,8 +122,9 @@ svdAux f st x = unsafePerformIO $ do
126 v <- createMatrix ColumnMajor c c 122 v <- createMatrix ColumnMajor c c
127 f # x # u # s # v #| st 123 f # x # u # s # v #| st
128 return (u,s,v) 124 return (u,s,v)
129 where r = rows x 125 where
130 c = cols x 126 r = rows x
127 c = cols x
131 128
132 129
133-- | Thin SVD of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'S\'. 130-- | Thin SVD of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'S\'.
@@ -152,9 +149,10 @@ thinSVDAux f st x = unsafePerformIO $ do
152 v <- createMatrix ColumnMajor q c 149 v <- createMatrix ColumnMajor q c
153 f # x # u # s # v #| st 150 f # x # u # s # v #| st
154 return (u,s,v) 151 return (u,s,v)
155 where r = rows x 152 where
156 c = cols x 153 r = rows x
157 q = min r c 154 c = cols x
155 q = min r c
158 156
159 157
160-- | Singular values of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'N\'. 158-- | Singular values of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'N\'.
@@ -177,10 +175,11 @@ svAux f st x = unsafePerformIO $ do
177 s <- createVector q 175 s <- createVector q
178 g # x # s #| st 176 g # x # s #| st
179 return s 177 return s
180 where r = rows x 178 where
181 c = cols x 179 r = rows x
182 q = min r c 180 c = cols x
183 g ra ca pa nb pb = f ra ca pa 0 0 nullPtr nb pb 0 0 nullPtr 181 q = min r c
182 g ra ca xra xca pa nb pb = f ra ca xra xca pa 0 0 0 0 nullPtr nb pb 0 0 0 0 nullPtr
184 183
185 184
186-- | Singular values and all right singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'N\' and jobvt == \'A\'. 185-- | Singular values and all right singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'N\' and jobvt == \'A\'.
@@ -196,10 +195,11 @@ rightSVAux f st x = unsafePerformIO $ do
196 v <- createMatrix ColumnMajor c c 195 v <- createMatrix ColumnMajor c c
197 g # x # s # v #| st 196 g # x # s # v #| st
198 return (s,v) 197 return (s,v)
199 where r = rows x 198 where
200 c = cols x 199 r = rows x
201 q = min r c 200 c = cols x
202 g ra ca pa = f ra ca pa 0 0 nullPtr 201 q = min r c
202 g ra ca xra xca pa = f ra ca xra xca pa 0 0 0 0 nullPtr
203 203
204 204
205-- | Singular values and all left singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'A\' and jobvt == \'N\'. 205-- | Singular values and all left singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'A\' and jobvt == \'N\'.
@@ -215,25 +215,27 @@ leftSVAux f st x = unsafePerformIO $ do
215 s <- createVector q 215 s <- createVector q
216 g # x # u # s #| st 216 g # x # u # s #| st
217 return (u,s) 217 return (u,s)
218 where r = rows x 218 where
219 c = cols x 219 r = rows x
220 q = min r c 220 c = cols x
221 g ra ca pa ru cu pu nb pb = f ra ca pa ru cu pu nb pb 0 0 nullPtr 221 q = min r c
222 g ra ca xra xca pa ru cu xru xcu pu nb pb = f ra ca xra xca pa ru cu xru xcu pu nb pb 0 0 0 0 nullPtr
222 223
223----------------------------------------------------------------------------- 224-----------------------------------------------------------------------------
224 225
225foreign import ccall unsafe "eig_l_R" dgeev :: R ..> R ..> C :> R ..> Ok 226foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok
226foreign import ccall unsafe "eig_l_C" zgeev :: C ..> C ..> C :> C ..> Ok 227foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok
227foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R ..> R :> R ..> Ok 228foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R ::> R :> R ::> Ok
228foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok 229foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ::> R :> C ::> Ok
229 230
230eigAux f st m = unsafePerformIO $ do 231eigAux f st m = unsafePerformIO $ do
231 l <- createVector r 232 l <- createVector r
232 v <- createMatrix ColumnMajor r r 233 v <- createMatrix ColumnMajor r r
233 g # m # l # v #| st 234 g # m # l # v #| st
234 return (l,v) 235 return (l,v)
235 where r = rows m 236 where
236 g ra ca pa = f ra ca pa 0 0 nullPtr 237 r = rows m
238 g ra ca xra xca pa = f ra ca xra xca pa 0 0 0 0 nullPtr
237 239
238 240
239-- | Eigenvalues and right eigenvectors of a general complex matrix, using LAPACK's /zgeev/. 241-- | Eigenvalues and right eigenvectors of a general complex matrix, using LAPACK's /zgeev/.
@@ -242,11 +244,12 @@ eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Dou
242eigC = eigAux zgeev "eigC" . fmat 244eigC = eigAux zgeev "eigC" . fmat
243 245
244eigOnlyAux f st m = unsafePerformIO $ do 246eigOnlyAux f st m = unsafePerformIO $ do
245 l <- createVector r 247 l <- createVector r
246 g # m # l #| st 248 g # m # l #| st
247 return l 249 return l
248 where r = rows m 250 where
249 g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr 251 r = rows m
252 g ra ca xra xca pa nl pl = f ra ca xra xca pa 0 0 0 0 nullPtr nl pl 0 0 0 0 nullPtr
250 253
251-- | Eigenvalues of a general complex matrix, using LAPACK's /zgeev/ with jobz == \'N\'. 254-- | Eigenvalues of a general complex matrix, using LAPACK's /zgeev/ with jobz == \'N\'.
252-- The eigenvalues are not sorted. 255-- The eigenvalues are not sorted.
@@ -264,12 +267,13 @@ eigR m = (s', v'')
264 267
265eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) 268eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double)
266eigRaux m = unsafePerformIO $ do 269eigRaux m = unsafePerformIO $ do
267 l <- createVector r 270 l <- createVector r
268 v <- createMatrix ColumnMajor r r 271 v <- createMatrix ColumnMajor r r
269 g # m # l # v #| "eigR" 272 g # m # l # v #| "eigR"
270 return (l,v) 273 return (l,v)
271 where r = rows m 274 where
272 g ra ca pa = dgeev ra ca pa 0 0 nullPtr 275 r = rows m
276 g ra ca xra xca pa = dgeev ra ca xra xca pa 0 0 0 0 nullPtr
273 277
274fixeig1 s = toComplex' (subVector 0 r (asReal s), subVector r r (asReal s)) 278fixeig1 s = toComplex' (subVector 0 r (asReal s), subVector r r (asReal s))
275 where r = dim s 279 where r = dim s
@@ -291,11 +295,12 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat
291----------------------------------------------------------------------------- 295-----------------------------------------------------------------------------
292 296
293eigSHAux f st m = unsafePerformIO $ do 297eigSHAux f st m = unsafePerformIO $ do
294 l <- createVector r 298 l <- createVector r
295 v <- createMatrix ColumnMajor r r 299 v <- createMatrix ColumnMajor r r
296 f # m # l # v #| st 300 f # m # l # v #| st
297 return (l,v) 301 return (l,v)
298 where r = rows m 302 where
303 r = rows m
299 304
300-- | Eigenvalues and right eigenvectors of a symmetric real matrix, using LAPACK's /dsyev/. 305-- | Eigenvalues and right eigenvectors of a symmetric real matrix, using LAPACK's /dsyev/.
301-- The eigenvectors are the columns of v. 306-- The eigenvectors are the columns of v.
@@ -314,8 +319,9 @@ eigS' = eigSHAux (dsyev 1) "eigS'" . fmat
314-- The eigenvalues are sorted in descending order (use 'eigH'' for ascending order). 319-- The eigenvalues are sorted in descending order (use 'eigH'' for ascending order).
315eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) 320eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
316eigH m = (s', fliprl v) 321eigH m = (s', fliprl v)
317 where (s,v) = eigH' (fmat m) 322 where
318 s' = fromList . reverse . toList $ s 323 (s,v) = eigH' (fmat m)
324 s' = fromList . reverse . toList $ s
319 325
320-- | 'eigH' in ascending order 326-- | 'eigH' in ascending order
321eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) 327eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
@@ -346,10 +352,11 @@ linearSolveSQAux g f st a b
346 f # a # b # s #| st 352 f # a # b # s #| st
347 return s 353 return s
348 | otherwise = error $ st ++ " of nonsquare matrix" 354 | otherwise = error $ st ++ " of nonsquare matrix"
349 where n1 = rows a 355 where
350 n2 = cols a 356 n1 = rows a
351 r = rows b 357 n2 = cols a
352 c = cols b 358 r = rows b
359 c = cols b
353 360
354-- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'. 361-- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'.
355linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double 362linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double
@@ -375,6 +382,7 @@ cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Comp
375cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) 382cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b)
376 383
377----------------------------------------------------------------------------------- 384-----------------------------------------------------------------------------------
385
378foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM R 386foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM R
379foreign import ccall unsafe "linearSolveLSC_l" zgels :: TMMM C 387foreign import ccall unsafe "linearSolveLSC_l" zgels :: TMMM C
380foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM R 388foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM R
@@ -384,9 +392,10 @@ linearSolveAux f st a b = unsafePerformIO $ do
384 r <- createMatrix ColumnMajor (max m n) nrhs 392 r <- createMatrix ColumnMajor (max m n) nrhs
385 f # a # b # r #| st 393 f # a # b # r #| st
386 return r 394 return r
387 where m = rows a 395 where
388 n = cols a 396 m = rows a
389 nrhs = cols b 397 n = cols a
398 nrhs = cols b
390 399
391-- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'. 400-- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'.
392linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double 401linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double
@@ -418,7 +427,7 @@ linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b)
418 427
419----------------------------------------------------------------------------------- 428-----------------------------------------------------------------------------------
420 429
421type TMM t = t ..> t ..> Ok 430type TMM t = t ::> t ::> Ok
422 431
423foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C 432foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C
424foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R 433foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R
@@ -427,7 +436,8 @@ cholAux f st a = do
427 r <- createMatrix ColumnMajor n n 436 r <- createMatrix ColumnMajor n n
428 f # a # r #| st 437 f # a # r #| st
429 return r 438 return r
430 where n = rows a 439 where
440 n = rows a
431 441
432-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. 442-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/.
433cholH :: Matrix (Complex Double) -> Matrix (Complex Double) 443cholH :: Matrix (Complex Double) -> Matrix (Complex Double)
@@ -447,7 +457,7 @@ mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat
447 457
448----------------------------------------------------------------------------------- 458-----------------------------------------------------------------------------------
449 459
450type TMVM t = t ..> t :> t ..> Ok 460type TMVM t = t ::> t :> t ::> Ok
451 461
452foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R 462foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R
453foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C 463foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C
@@ -504,13 +514,14 @@ hessAux f st a = unsafePerformIO $ do
504 tau <- createVector (mn-1) 514 tau <- createVector (mn-1)
505 f # a # tau # r #| st 515 f # a # tau # r #| st
506 return (r,tau) 516 return (r,tau)
507 where m = rows a 517 where
508 n = cols a 518 m = rows a
509 mn = min m n 519 n = cols a
520 mn = min m n
510 521
511----------------------------------------------------------------------------------- 522-----------------------------------------------------------------------------------
512foreign import ccall unsafe "schur_l_R" dgees :: TMMM R 523foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> R ::> Ok
513foreign import ccall unsafe "schur_l_C" zgees :: TMMM C 524foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> C ::> Ok
514 525
515-- | Schur factorization of a square real matrix, using LAPACK's /dgees/. 526-- | Schur factorization of a square real matrix, using LAPACK's /dgees/.
516schurR :: Matrix Double -> (Matrix Double, Matrix Double) 527schurR :: Matrix Double -> (Matrix Double, Matrix Double)
@@ -525,11 +536,12 @@ schurAux f st a = unsafePerformIO $ do
525 s <- createMatrix ColumnMajor n n 536 s <- createMatrix ColumnMajor n n
526 f # a # u # s #| st 537 f # a # u # s #| st
527 return (u,s) 538 return (u,s)
528 where n = rows a 539 where
540 n = rows a
529 541
530----------------------------------------------------------------------------------- 542-----------------------------------------------------------------------------------
531foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R 543foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R
532foreign import ccall unsafe "lu_l_C" zgetrf :: C ..> R :> C ..> Ok 544foreign import ccall unsafe "lu_l_C" zgetrf :: C ::> R :> C ::> Ok
533 545
534-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. 546-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/.
535luR :: Matrix Double -> (Matrix Double, [Int]) 547luR :: Matrix Double -> (Matrix Double, [Int])
@@ -544,12 +556,13 @@ luAux f st a = unsafePerformIO $ do
544 piv <- createVector (min n m) 556 piv <- createVector (min n m)
545 f # a # piv # lu #| st 557 f # a # piv # lu #| st
546 return (lu, map (pred.round) (toList piv)) 558 return (lu, map (pred.round) (toList piv))
547 where n = rows a 559 where
548 m = cols a 560 n = rows a
561 m = cols a
549 562
550----------------------------------------------------------------------------------- 563-----------------------------------------------------------------------------------
551 564
552type Tlus t = t ..> Double :> t ..> t ..> Ok 565type Tlus t = t ::> Double :> t ::> t ::> Ok
553 566
554foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R 567foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R
555foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C 568foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C