diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs (renamed from packages/base/src/Numeric/LinearAlgebra/LAPACK.hs) | 137 |
1 files changed, 63 insertions, 74 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 6fb2b13..9cab3f8 100644 --- a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | ||
2 | |||
1 | ----------------------------------------------------------------------------- | 3 | ----------------------------------------------------------------------------- |
2 | -- | | 4 | -- | |
3 | -- Module : Numeric.LinearAlgebra.LAPACK | 5 | -- Module : Numeric.LinearAlgebra.LAPACK |
@@ -9,56 +11,36 @@ | |||
9 | -- Functional interface to selected LAPACK functions (<http://www.netlib.org/lapack>). | 11 | -- Functional interface to selected LAPACK functions (<http://www.netlib.org/lapack>). |
10 | -- | 12 | -- |
11 | ----------------------------------------------------------------------------- | 13 | ----------------------------------------------------------------------------- |
12 | {-# OPTIONS_HADDOCK hide #-} | 14 | |
13 | 15 | ||
14 | 16 | module Internal.LAPACK where | |
15 | module Numeric.LinearAlgebra.LAPACK ( | 17 | |
16 | -- * Matrix product | 18 | import Internal.Devel |
17 | multiplyR, multiplyC, multiplyF, multiplyQ, multiplyI, | 19 | import Internal.Vector |
18 | -- * Linear systems | 20 | import Internal.Matrix |
19 | linearSolveR, linearSolveC, | 21 | import Internal.Conversion |
20 | mbLinearSolveR, mbLinearSolveC, | 22 | import Internal.Element |
21 | lusR, lusC, | ||
22 | cholSolveR, cholSolveC, | ||
23 | linearSolveLSR, linearSolveLSC, | ||
24 | linearSolveSVDR, linearSolveSVDC, | ||
25 | -- * SVD | ||
26 | svR, svRd, svC, svCd, | ||
27 | svdR, svdRd, svdC, svdCd, | ||
28 | thinSVDR, thinSVDRd, thinSVDC, thinSVDCd, | ||
29 | rightSVR, rightSVC, leftSVR, leftSVC, | ||
30 | -- * Eigensystems | ||
31 | eigR, eigC, eigS, eigS', eigH, eigH', | ||
32 | eigOnlyR, eigOnlyC, eigOnlyS, eigOnlyH, | ||
33 | -- * LU | ||
34 | luR, luC, | ||
35 | -- * Cholesky | ||
36 | cholS, cholH, mbCholS, mbCholH, | ||
37 | -- * QR | ||
38 | qrR, qrC, qrgrR, qrgrC, | ||
39 | -- * Hessenberg | ||
40 | hessR, hessC, | ||
41 | -- * Schur | ||
42 | schurR, schurC | ||
43 | ) where | ||
44 | |||
45 | import Data.Packed.Development | ||
46 | import Data.Packed | ||
47 | import Data.Packed.Internal | ||
48 | import Numeric.Conversion | ||
49 | 23 | ||
50 | import Foreign.Ptr(nullPtr) | 24 | import Foreign.Ptr(nullPtr) |
51 | import Foreign.C.Types | 25 | import Foreign.C.Types |
52 | import Control.Monad(when) | 26 | import Control.Monad(when) |
53 | import System.IO.Unsafe(unsafePerformIO) | 27 | import System.IO.Unsafe(unsafePerformIO) |
28 | import Data.Vector.Storable(fromList) | ||
54 | 29 | ||
55 | ----------------------------------------------------------------------------------- | 30 | ----------------------------------------------------------------------------------- |
56 | 31 | ||
57 | foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM | 32 | type TMMM t = t ..> t ..> t ..> Ok |
58 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM | 33 | |
59 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM | 34 | type R = Double |
60 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM | 35 | type C = Complex Double |
61 | foreign import ccall unsafe "multiplyI" c_multiplyI :: OM CInt (OM CInt (OM CInt (IO CInt))) | 36 | type F = Float |
37 | type Q = Complex Float | ||
38 | |||
39 | foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R | ||
40 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C | ||
41 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F | ||
42 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q | ||
43 | foreign import ccall unsafe "multiplyI" c_multiplyI :: CInt ::> CInt ::> CInt ::> Ok | ||
62 | 44 | ||
63 | isT Matrix{order = ColumnMajor} = 0 | 45 | isT Matrix{order = ColumnMajor} = 0 |
64 | isT Matrix{order = RowMajor} = 1 | 46 | isT Matrix{order = RowMajor} = 1 |
@@ -98,10 +80,13 @@ multiplyI a b = unsafePerformIO $ do | |||
98 | return s | 80 | return s |
99 | 81 | ||
100 | ----------------------------------------------------------------------------- | 82 | ----------------------------------------------------------------------------- |
101 | foreign import ccall unsafe "svd_l_R" dgesvd :: TMMVM | 83 | |
102 | foreign import ccall unsafe "svd_l_C" zgesvd :: TCMCMVCM | 84 | type TSVD t = t ..> t ..> R :> t ..> Ok |
103 | foreign import ccall unsafe "svd_l_Rdd" dgesdd :: TMMVM | 85 | |
104 | foreign import ccall unsafe "svd_l_Cdd" zgesdd :: TCMCMVCM | 86 | foreign import ccall unsafe "svd_l_R" dgesvd :: TSVD R |
87 | foreign import ccall unsafe "svd_l_C" zgesvd :: TSVD C | ||
88 | foreign import ccall unsafe "svd_l_Rdd" dgesdd :: TSVD R | ||
89 | foreign import ccall unsafe "svd_l_Cdd" zgesdd :: TSVD C | ||
105 | 90 | ||
106 | -- | Full SVD of a real matrix using LAPACK's /dgesvd/. | 91 | -- | Full SVD of a real matrix using LAPACK's /dgesvd/. |
107 | svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | 92 | svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) |
@@ -221,10 +206,10 @@ leftSVAux f st x = unsafePerformIO $ do | |||
221 | 206 | ||
222 | ----------------------------------------------------------------------------- | 207 | ----------------------------------------------------------------------------- |
223 | 208 | ||
224 | foreign import ccall unsafe "eig_l_R" dgeev :: TMMCVM | 209 | foreign import ccall unsafe "eig_l_R" dgeev :: R ..> R ..> C :> R ..> Ok |
225 | foreign import ccall unsafe "eig_l_C" zgeev :: TCMCMCVCM | 210 | foreign import ccall unsafe "eig_l_C" zgeev :: C ..> C ..> C :> C ..> Ok |
226 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> TMVM | 211 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R ..> R :> R ..> Ok |
227 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> TCMVCM | 212 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok |
228 | 213 | ||
229 | eigAux f st m = unsafePerformIO $ do | 214 | eigAux f st m = unsafePerformIO $ do |
230 | l <- createVector r | 215 | l <- createVector r |
@@ -334,10 +319,10 @@ eigOnlyH = vrev . fst. eigSHAux (zheev 0) "eigH'" . fmat | |||
334 | vrev = flatten . flipud . reshape 1 | 319 | vrev = flatten . flipud . reshape 1 |
335 | 320 | ||
336 | ----------------------------------------------------------------------------- | 321 | ----------------------------------------------------------------------------- |
337 | foreign import ccall unsafe "linearSolveR_l" dgesv :: TMMM | 322 | foreign import ccall unsafe "linearSolveR_l" dgesv :: TMMM R |
338 | foreign import ccall unsafe "linearSolveC_l" zgesv :: TCMCMCM | 323 | foreign import ccall unsafe "linearSolveC_l" zgesv :: TMMM C |
339 | foreign import ccall unsafe "cholSolveR_l" dpotrs :: TMMM | 324 | foreign import ccall unsafe "cholSolveR_l" dpotrs :: TMMM R |
340 | foreign import ccall unsafe "cholSolveC_l" zpotrs :: TCMCMCM | 325 | foreign import ccall unsafe "cholSolveC_l" zpotrs :: TMMM C |
341 | 326 | ||
342 | linearSolveSQAux g f st a b | 327 | linearSolveSQAux g f st a b |
343 | | n1==n2 && n1==r = unsafePerformIO . g $ do | 328 | | n1==n2 && n1==r = unsafePerformIO . g $ do |
@@ -374,10 +359,10 @@ cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Comp | |||
374 | cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) | 359 | cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) |
375 | 360 | ||
376 | ----------------------------------------------------------------------------------- | 361 | ----------------------------------------------------------------------------------- |
377 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM | 362 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM R |
378 | foreign import ccall unsafe "linearSolveLSC_l" zgels :: TCMCMCM | 363 | foreign import ccall unsafe "linearSolveLSC_l" zgels :: TMMM C |
379 | foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM | 364 | foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM R |
380 | foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TCMCMCM | 365 | foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C |
381 | 366 | ||
382 | linearSolveAux f st a b = unsafePerformIO $ do | 367 | linearSolveAux f st a b = unsafePerformIO $ do |
383 | r <- createMatrix ColumnMajor (max m n) nrhs | 368 | r <- createMatrix ColumnMajor (max m n) nrhs |
@@ -416,8 +401,9 @@ linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ | |||
416 | linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) | 401 | linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) |
417 | 402 | ||
418 | ----------------------------------------------------------------------------------- | 403 | ----------------------------------------------------------------------------------- |
419 | foreign import ccall unsafe "chol_l_H" zpotrf :: TCMCM | 404 | |
420 | foreign import ccall unsafe "chol_l_S" dpotrf :: TMM | 405 | foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C |
406 | foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R | ||
421 | 407 | ||
422 | cholAux f st a = do | 408 | cholAux f st a = do |
423 | r <- createMatrix ColumnMajor n n | 409 | r <- createMatrix ColumnMajor n n |
@@ -442,8 +428,11 @@ mbCholS :: Matrix Double -> Maybe (Matrix Double) | |||
442 | mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat | 428 | mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat |
443 | 429 | ||
444 | ----------------------------------------------------------------------------------- | 430 | ----------------------------------------------------------------------------------- |
445 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM | 431 | |
446 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: TCMCVCM | 432 | type TMVM t = t ..> t :> t ..> Ok |
433 | |||
434 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R | ||
435 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C | ||
447 | 436 | ||
448 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. | 437 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. |
449 | qrR :: Matrix Double -> (Matrix Double, Vector Double) | 438 | qrR :: Matrix Double -> (Matrix Double, Vector Double) |
@@ -463,8 +452,8 @@ qrAux f st a = unsafePerformIO $ do | |||
463 | n = cols a | 452 | n = cols a |
464 | mn = min m n | 453 | mn = min m n |
465 | 454 | ||
466 | foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM | 455 | foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM R |
467 | foreign import ccall unsafe "c_zungqr" zungqr :: TCMCVCM | 456 | foreign import ccall unsafe "c_zungqr" zungqr :: TMVM C |
468 | 457 | ||
469 | -- | build rotation from reflectors | 458 | -- | build rotation from reflectors |
470 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double | 459 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double |
@@ -481,8 +470,8 @@ qrgrAux f st n (a, tau) = unsafePerformIO $ do | |||
481 | tau' = vjoin [tau, constantD 0 n] | 470 | tau' = vjoin [tau, constantD 0 n] |
482 | 471 | ||
483 | ----------------------------------------------------------------------------------- | 472 | ----------------------------------------------------------------------------------- |
484 | foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM | 473 | foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM R |
485 | foreign import ccall unsafe "hess_l_C" zgehrd :: TCMCVCM | 474 | foreign import ccall unsafe "hess_l_C" zgehrd :: TMVM C |
486 | 475 | ||
487 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. | 476 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. |
488 | hessR :: Matrix Double -> (Matrix Double, Vector Double) | 477 | hessR :: Matrix Double -> (Matrix Double, Vector Double) |
@@ -502,8 +491,8 @@ hessAux f st a = unsafePerformIO $ do | |||
502 | mn = min m n | 491 | mn = min m n |
503 | 492 | ||
504 | ----------------------------------------------------------------------------------- | 493 | ----------------------------------------------------------------------------------- |
505 | foreign import ccall unsafe "schur_l_R" dgees :: TMMM | 494 | foreign import ccall unsafe "schur_l_R" dgees :: TMMM R |
506 | foreign import ccall unsafe "schur_l_C" zgees :: TCMCMCM | 495 | foreign import ccall unsafe "schur_l_C" zgees :: TMMM C |
507 | 496 | ||
508 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. | 497 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. |
509 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) | 498 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) |
@@ -521,8 +510,8 @@ schurAux f st a = unsafePerformIO $ do | |||
521 | where n = rows a | 510 | where n = rows a |
522 | 511 | ||
523 | ----------------------------------------------------------------------------------- | 512 | ----------------------------------------------------------------------------------- |
524 | foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM | 513 | foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R |
525 | foreign import ccall unsafe "lu_l_C" zgetrf :: TCMVCM | 514 | foreign import ccall unsafe "lu_l_C" zgetrf :: C ..> R :> C ..> Ok |
526 | 515 | ||
527 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. | 516 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. |
528 | luR :: Matrix Double -> (Matrix Double, [Int]) | 517 | luR :: Matrix Double -> (Matrix Double, [Int]) |
@@ -541,11 +530,11 @@ luAux f st a = unsafePerformIO $ do | |||
541 | m = cols a | 530 | m = cols a |
542 | 531 | ||
543 | ----------------------------------------------------------------------------------- | 532 | ----------------------------------------------------------------------------------- |
544 | type TW a = CInt -> PD -> a | ||
545 | type TQ a = CInt -> CInt -> PC -> a | ||
546 | 533 | ||
547 | foreign import ccall unsafe "luS_l_R" dgetrs :: TMVMM | 534 | type Tlus t = t ..> Double :> t ..> t ..> Ok |
548 | foreign import ccall unsafe "luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt)))) | 535 | |
536 | foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R | ||
537 | foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C | ||
549 | 538 | ||
550 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. | 539 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. |
551 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | 540 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double |