summaryrefslogtreecommitdiff
path: root/packages/base
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-05 16:45:23 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-05 16:45:23 +0200
commit5d967adae38d7fe80443b57b40ae89cd15db5e18 (patch)
treef6cce41d547d539f7f0cb8b92e0273d3ecab5d45 /packages/base
parent64df799c68817054705a99e9ee02723603fae29e (diff)
move lapack
Diffstat (limited to 'packages/base')
-rw-r--r--packages/base/src/Internal/LAPACK.hs (renamed from packages/base/src/Numeric/LinearAlgebra/LAPACK.hs)137
1 files changed, 63 insertions, 74 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 6fb2b13..9cab3f8 100644
--- a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -1,3 +1,5 @@
1{-# LANGUAGE TypeOperators #-}
2
1----------------------------------------------------------------------------- 3-----------------------------------------------------------------------------
2-- | 4-- |
3-- Module : Numeric.LinearAlgebra.LAPACK 5-- Module : Numeric.LinearAlgebra.LAPACK
@@ -9,56 +11,36 @@
9-- Functional interface to selected LAPACK functions (<http://www.netlib.org/lapack>). 11-- Functional interface to selected LAPACK functions (<http://www.netlib.org/lapack>).
10-- 12--
11----------------------------------------------------------------------------- 13-----------------------------------------------------------------------------
12{-# OPTIONS_HADDOCK hide #-} 14
13 15
14 16module Internal.LAPACK where
15module Numeric.LinearAlgebra.LAPACK ( 17
16 -- * Matrix product 18import Internal.Devel
17 multiplyR, multiplyC, multiplyF, multiplyQ, multiplyI, 19import Internal.Vector
18 -- * Linear systems 20import Internal.Matrix
19 linearSolveR, linearSolveC, 21import Internal.Conversion
20 mbLinearSolveR, mbLinearSolveC, 22import Internal.Element
21 lusR, lusC,
22 cholSolveR, cholSolveC,
23 linearSolveLSR, linearSolveLSC,
24 linearSolveSVDR, linearSolveSVDC,
25 -- * SVD
26 svR, svRd, svC, svCd,
27 svdR, svdRd, svdC, svdCd,
28 thinSVDR, thinSVDRd, thinSVDC, thinSVDCd,
29 rightSVR, rightSVC, leftSVR, leftSVC,
30 -- * Eigensystems
31 eigR, eigC, eigS, eigS', eigH, eigH',
32 eigOnlyR, eigOnlyC, eigOnlyS, eigOnlyH,
33 -- * LU
34 luR, luC,
35 -- * Cholesky
36 cholS, cholH, mbCholS, mbCholH,
37 -- * QR
38 qrR, qrC, qrgrR, qrgrC,
39 -- * Hessenberg
40 hessR, hessC,
41 -- * Schur
42 schurR, schurC
43) where
44
45import Data.Packed.Development
46import Data.Packed
47import Data.Packed.Internal
48import Numeric.Conversion
49 23
50import Foreign.Ptr(nullPtr) 24import Foreign.Ptr(nullPtr)
51import Foreign.C.Types 25import Foreign.C.Types
52import Control.Monad(when) 26import Control.Monad(when)
53import System.IO.Unsafe(unsafePerformIO) 27import System.IO.Unsafe(unsafePerformIO)
28import Data.Vector.Storable(fromList)
54 29
55----------------------------------------------------------------------------------- 30-----------------------------------------------------------------------------------
56 31
57foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM 32type TMMM t = t ..> t ..> t ..> Ok
58foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM 33
59foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM 34type R = Double
60foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM 35type C = Complex Double
61foreign import ccall unsafe "multiplyI" c_multiplyI :: OM CInt (OM CInt (OM CInt (IO CInt))) 36type F = Float
37type Q = Complex Float
38
39foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R
40foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C
41foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F
42foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q
43foreign import ccall unsafe "multiplyI" c_multiplyI :: CInt ::> CInt ::> CInt ::> Ok
62 44
63isT Matrix{order = ColumnMajor} = 0 45isT Matrix{order = ColumnMajor} = 0
64isT Matrix{order = RowMajor} = 1 46isT Matrix{order = RowMajor} = 1
@@ -98,10 +80,13 @@ multiplyI a b = unsafePerformIO $ do
98 return s 80 return s
99 81
100----------------------------------------------------------------------------- 82-----------------------------------------------------------------------------
101foreign import ccall unsafe "svd_l_R" dgesvd :: TMMVM 83
102foreign import ccall unsafe "svd_l_C" zgesvd :: TCMCMVCM 84type TSVD t = t ..> t ..> R :> t ..> Ok
103foreign import ccall unsafe "svd_l_Rdd" dgesdd :: TMMVM 85
104foreign import ccall unsafe "svd_l_Cdd" zgesdd :: TCMCMVCM 86foreign import ccall unsafe "svd_l_R" dgesvd :: TSVD R
87foreign import ccall unsafe "svd_l_C" zgesvd :: TSVD C
88foreign import ccall unsafe "svd_l_Rdd" dgesdd :: TSVD R
89foreign import ccall unsafe "svd_l_Cdd" zgesdd :: TSVD C
105 90
106-- | Full SVD of a real matrix using LAPACK's /dgesvd/. 91-- | Full SVD of a real matrix using LAPACK's /dgesvd/.
107svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) 92svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
@@ -221,10 +206,10 @@ leftSVAux f st x = unsafePerformIO $ do
221 206
222----------------------------------------------------------------------------- 207-----------------------------------------------------------------------------
223 208
224foreign import ccall unsafe "eig_l_R" dgeev :: TMMCVM 209foreign import ccall unsafe "eig_l_R" dgeev :: R ..> R ..> C :> R ..> Ok
225foreign import ccall unsafe "eig_l_C" zgeev :: TCMCMCVCM 210foreign import ccall unsafe "eig_l_C" zgeev :: C ..> C ..> C :> C ..> Ok
226foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> TMVM 211foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R ..> R :> R ..> Ok
227foreign import ccall unsafe "eig_l_H" zheev :: CInt -> TCMVCM 212foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok
228 213
229eigAux f st m = unsafePerformIO $ do 214eigAux f st m = unsafePerformIO $ do
230 l <- createVector r 215 l <- createVector r
@@ -334,10 +319,10 @@ eigOnlyH = vrev . fst. eigSHAux (zheev 0) "eigH'" . fmat
334vrev = flatten . flipud . reshape 1 319vrev = flatten . flipud . reshape 1
335 320
336----------------------------------------------------------------------------- 321-----------------------------------------------------------------------------
337foreign import ccall unsafe "linearSolveR_l" dgesv :: TMMM 322foreign import ccall unsafe "linearSolveR_l" dgesv :: TMMM R
338foreign import ccall unsafe "linearSolveC_l" zgesv :: TCMCMCM 323foreign import ccall unsafe "linearSolveC_l" zgesv :: TMMM C
339foreign import ccall unsafe "cholSolveR_l" dpotrs :: TMMM 324foreign import ccall unsafe "cholSolveR_l" dpotrs :: TMMM R
340foreign import ccall unsafe "cholSolveC_l" zpotrs :: TCMCMCM 325foreign import ccall unsafe "cholSolveC_l" zpotrs :: TMMM C
341 326
342linearSolveSQAux g f st a b 327linearSolveSQAux g f st a b
343 | n1==n2 && n1==r = unsafePerformIO . g $ do 328 | n1==n2 && n1==r = unsafePerformIO . g $ do
@@ -374,10 +359,10 @@ cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Comp
374cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) 359cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b)
375 360
376----------------------------------------------------------------------------------- 361-----------------------------------------------------------------------------------
377foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM 362foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM R
378foreign import ccall unsafe "linearSolveLSC_l" zgels :: TCMCMCM 363foreign import ccall unsafe "linearSolveLSC_l" zgels :: TMMM C
379foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM 364foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM R
380foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TCMCMCM 365foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C
381 366
382linearSolveAux f st a b = unsafePerformIO $ do 367linearSolveAux f st a b = unsafePerformIO $ do
383 r <- createMatrix ColumnMajor (max m n) nrhs 368 r <- createMatrix ColumnMajor (max m n) nrhs
@@ -416,8 +401,9 @@ linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $
416linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) 401linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b)
417 402
418----------------------------------------------------------------------------------- 403-----------------------------------------------------------------------------------
419foreign import ccall unsafe "chol_l_H" zpotrf :: TCMCM 404
420foreign import ccall unsafe "chol_l_S" dpotrf :: TMM 405foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C
406foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R
421 407
422cholAux f st a = do 408cholAux f st a = do
423 r <- createMatrix ColumnMajor n n 409 r <- createMatrix ColumnMajor n n
@@ -442,8 +428,11 @@ mbCholS :: Matrix Double -> Maybe (Matrix Double)
442mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat 428mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat
443 429
444----------------------------------------------------------------------------------- 430-----------------------------------------------------------------------------------
445foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM 431
446foreign import ccall unsafe "qr_l_C" zgeqr2 :: TCMCVCM 432type TMVM t = t ..> t :> t ..> Ok
433
434foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R
435foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C
447 436
448-- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. 437-- | QR factorization of a real matrix, using LAPACK's /dgeqr2/.
449qrR :: Matrix Double -> (Matrix Double, Vector Double) 438qrR :: Matrix Double -> (Matrix Double, Vector Double)
@@ -463,8 +452,8 @@ qrAux f st a = unsafePerformIO $ do
463 n = cols a 452 n = cols a
464 mn = min m n 453 mn = min m n
465 454
466foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM 455foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM R
467foreign import ccall unsafe "c_zungqr" zungqr :: TCMCVCM 456foreign import ccall unsafe "c_zungqr" zungqr :: TMVM C
468 457
469-- | build rotation from reflectors 458-- | build rotation from reflectors
470qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double 459qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double
@@ -481,8 +470,8 @@ qrgrAux f st n (a, tau) = unsafePerformIO $ do
481 tau' = vjoin [tau, constantD 0 n] 470 tau' = vjoin [tau, constantD 0 n]
482 471
483----------------------------------------------------------------------------------- 472-----------------------------------------------------------------------------------
484foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM 473foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM R
485foreign import ccall unsafe "hess_l_C" zgehrd :: TCMCVCM 474foreign import ccall unsafe "hess_l_C" zgehrd :: TMVM C
486 475
487-- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. 476-- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/.
488hessR :: Matrix Double -> (Matrix Double, Vector Double) 477hessR :: Matrix Double -> (Matrix Double, Vector Double)
@@ -502,8 +491,8 @@ hessAux f st a = unsafePerformIO $ do
502 mn = min m n 491 mn = min m n
503 492
504----------------------------------------------------------------------------------- 493-----------------------------------------------------------------------------------
505foreign import ccall unsafe "schur_l_R" dgees :: TMMM 494foreign import ccall unsafe "schur_l_R" dgees :: TMMM R
506foreign import ccall unsafe "schur_l_C" zgees :: TCMCMCM 495foreign import ccall unsafe "schur_l_C" zgees :: TMMM C
507 496
508-- | Schur factorization of a square real matrix, using LAPACK's /dgees/. 497-- | Schur factorization of a square real matrix, using LAPACK's /dgees/.
509schurR :: Matrix Double -> (Matrix Double, Matrix Double) 498schurR :: Matrix Double -> (Matrix Double, Matrix Double)
@@ -521,8 +510,8 @@ schurAux f st a = unsafePerformIO $ do
521 where n = rows a 510 where n = rows a
522 511
523----------------------------------------------------------------------------------- 512-----------------------------------------------------------------------------------
524foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM 513foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R
525foreign import ccall unsafe "lu_l_C" zgetrf :: TCMVCM 514foreign import ccall unsafe "lu_l_C" zgetrf :: C ..> R :> C ..> Ok
526 515
527-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. 516-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/.
528luR :: Matrix Double -> (Matrix Double, [Int]) 517luR :: Matrix Double -> (Matrix Double, [Int])
@@ -541,11 +530,11 @@ luAux f st a = unsafePerformIO $ do
541 m = cols a 530 m = cols a
542 531
543----------------------------------------------------------------------------------- 532-----------------------------------------------------------------------------------
544type TW a = CInt -> PD -> a
545type TQ a = CInt -> CInt -> PC -> a
546 533
547foreign import ccall unsafe "luS_l_R" dgetrs :: TMVMM 534type Tlus t = t ..> Double :> t ..> t ..> Ok
548foreign import ccall unsafe "luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt)))) 535
536foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R
537foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C
549 538
550-- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. 539-- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/.
551lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double 540lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double