diff options
Diffstat (limited to 'lib/Numeric/LinearAlgebra/LAPACK.hs')
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK.hs | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index d78b506..8bc2492 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -19,7 +19,7 @@ module Numeric.LinearAlgebra.LAPACK ( | |||
19 | linearSolveR, linearSolveC, | 19 | linearSolveR, linearSolveC, |
20 | linearSolveLSR, linearSolveLSC, | 20 | linearSolveLSR, linearSolveLSC, |
21 | linearSolveSVDR, linearSolveSVDC, | 21 | linearSolveSVDR, linearSolveSVDC, |
22 | luR, luC, lusR, | 22 | luR, luC, lusR, lusC, |
23 | cholS, cholH, | 23 | cholS, cholH, |
24 | qrR, qrC, | 24 | qrR, qrC, |
25 | hessR, hessC, | 25 | hessR, hessC, |
@@ -34,6 +34,7 @@ import Data.Packed.Matrix | |||
34 | import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) | 34 | import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) |
35 | import Complex | 35 | import Complex |
36 | import Foreign | 36 | import Foreign |
37 | import Foreign.C.Types (CInt) | ||
37 | 38 | ||
38 | ----------------------------------------------------------------------------- | 39 | ----------------------------------------------------------------------------- |
39 | foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM | 40 | foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM |
@@ -338,17 +339,29 @@ luAux f st a = unsafePerformIO $ do | |||
338 | where n = rows a | 339 | where n = rows a |
339 | m = cols a | 340 | m = cols a |
340 | 341 | ||
341 | |||
342 | ----------------------------------------------------------------------------------- | 342 | ----------------------------------------------------------------------------------- |
343 | type TW a = CInt -> PD -> a | ||
344 | type TQ a = CInt -> CInt -> PC -> a | ||
345 | |||
343 | foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM | 346 | foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM |
347 | foreign import ccall "LAPACK/lapack-aux.h luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt)))) | ||
344 | 348 | ||
349 | -- | Wrapper for LAPACK's /dgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition. | ||
345 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | 350 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double |
346 | lusR a piv b = lusR' (fmat a) piv (fmat b) | 351 | lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv (fmat b) |
352 | |||
353 | -- | Wrapper for LAPACK's /zgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition. | ||
354 | lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
355 | lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) | ||
347 | 356 | ||
348 | lusR' a piv b = unsafePerformIO $ do | 357 | lusAux f st a piv b |
358 | | n1==n2 && n2==n =unsafePerformIO $ do | ||
349 | x <- createMatrix ColumnMajor n m | 359 | x <- createMatrix ColumnMajor n m |
350 | app4 dgetrs mat a vec piv' mat b mat x "lusR" | 360 | app4 f mat a vec piv' mat b mat x st |
351 | return x | 361 | return x |
352 | where n = rows b | 362 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" |
353 | m = cols b | 363 | where n1 = rows a |
354 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | 364 | n2 = cols a |
365 | n = rows b | ||
366 | m = cols b | ||
367 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | ||