summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra/LAPACK.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/LinearAlgebra/LAPACK.hs')
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs29
1 files changed, 21 insertions, 8 deletions
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index d78b506..8bc2492 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -19,7 +19,7 @@ module Numeric.LinearAlgebra.LAPACK (
19 linearSolveR, linearSolveC, 19 linearSolveR, linearSolveC,
20 linearSolveLSR, linearSolveLSC, 20 linearSolveLSR, linearSolveLSC,
21 linearSolveSVDR, linearSolveSVDC, 21 linearSolveSVDR, linearSolveSVDC,
22 luR, luC, lusR, 22 luR, luC, lusR, lusC,
23 cholS, cholH, 23 cholS, cholH,
24 qrR, qrC, 24 qrR, qrC,
25 hessR, hessC, 25 hessR, hessC,
@@ -34,6 +34,7 @@ import Data.Packed.Matrix
34import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) 34import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale))
35import Complex 35import Complex
36import Foreign 36import Foreign
37import Foreign.C.Types (CInt)
37 38
38----------------------------------------------------------------------------- 39-----------------------------------------------------------------------------
39foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM 40foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM
@@ -338,17 +339,29 @@ luAux f st a = unsafePerformIO $ do
338 where n = rows a 339 where n = rows a
339 m = cols a 340 m = cols a
340 341
341
342----------------------------------------------------------------------------------- 342-----------------------------------------------------------------------------------
343type TW a = CInt -> PD -> a
344type TQ a = CInt -> CInt -> PC -> a
345
343foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM 346foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM
347foreign import ccall "LAPACK/lapack-aux.h luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt))))
344 348
349-- | Wrapper for LAPACK's /dgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition.
345lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double 350lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
346lusR a piv b = lusR' (fmat a) piv (fmat b) 351lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv (fmat b)
352
353-- | Wrapper for LAPACK's /zgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition.
354lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
355lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b)
347 356
348lusR' a piv b = unsafePerformIO $ do 357lusAux f st a piv b
358 | n1==n2 && n2==n =unsafePerformIO $ do
349 x <- createMatrix ColumnMajor n m 359 x <- createMatrix ColumnMajor n m
350 app4 dgetrs mat a vec piv' mat b mat x "lusR" 360 app4 f mat a vec piv' mat b mat x st
351 return x 361 return x
352 where n = rows b 362 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix"
353 m = cols b 363 where n1 = rows a
354 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double 364 n2 = cols a
365 n = rows b
366 m = cols b
367 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double