summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/LAPACK.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r--packages/base/src/Internal/LAPACK.hs19
1 files changed, 11 insertions, 8 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 27d1f95..d88ff6b 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -22,9 +22,12 @@ import Data.Bifunctor (first)
22 22
23import Internal.Devel 23import Internal.Devel
24import Internal.Vector 24import Internal.Vector
25import Internal.Vectorized (constantAux)
25import Internal.Matrix hiding ((#), (#!)) 26import Internal.Matrix hiding ((#), (#!))
26import Internal.Conversion 27import Internal.Conversion
27import Internal.Element 28import Internal.Element
29import Internal.ST (setRect)
30import Data.Int
28import Foreign.Ptr(nullPtr) 31import Foreign.Ptr(nullPtr)
29import Foreign.C.Types 32import Foreign.C.Types
30import Control.Monad(when) 33import Control.Monad(when)
@@ -46,10 +49,10 @@ type TMMM t = t ::> t ::> t ::> Ok
46type F = Float 49type F = Float
47type Q = Complex Float 50type Q = Complex Float
48 51
49foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R 52foreign import ccall unsafe "multiplyR" dgemmc :: Int32 -> Int32 -> TMMM R
50foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C 53foreign import ccall unsafe "multiplyC" zgemmc :: Int32 -> Int32 -> TMMM C
51foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F 54foreign import ccall unsafe "multiplyF" sgemmc :: Int32 -> Int32 -> TMMM F
52foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q 55foreign import ccall unsafe "multiplyQ" cgemmc :: Int32 -> Int32 -> TMMM Q
53foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I 56foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I
54foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z 57foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z
55 58
@@ -82,7 +85,7 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b
82multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) 85multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float)
83multiplyQ a b = multiplyAux cgemmc "cgemmc" a b 86multiplyQ a b = multiplyAux cgemmc "cgemmc" a b
84 87
85multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt 88multiplyI :: I -> Matrix Int32 -> Matrix Int32 -> Matrix Int32
86multiplyI m a b = unsafePerformIO $ do 89multiplyI m a b = unsafePerformIO $ do
87 when (cols a /= rows b) $ error $ 90 when (cols a /= rows b) $ error $
88 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 91 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
@@ -239,8 +242,8 @@ foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok
239foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok 242foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok
240foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok 243foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok
241foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok 244foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok
242foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R :> R ::> Ok 245foreign import ccall unsafe "eig_l_S" dsyev :: Int32 -> R :> R ::> Ok
243foreign import ccall unsafe "eig_l_H" zheev :: CInt -> R :> C ::> Ok 246foreign import ccall unsafe "eig_l_H" zheev :: Int32 -> R :> C ::> Ok
244 247
245eigAux f st m = unsafePerformIO $ do 248eigAux f st m = unsafePerformIO $ do
246 a <- copy ColumnMajor m 249 a <- copy ColumnMajor m
@@ -636,7 +639,7 @@ qrgrAux f st n (a, tau) = unsafePerformIO $ do
636 ((subVector 0 n tau') #! res) f #| st 639 ((subVector 0 n tau') #! res) f #| st
637 return res 640 return res
638 where 641 where
639 tau' = vjoin [tau, constantD 0 n] 642 tau' = vjoin [tau, constantAux 0 n]
640 643
641----------------------------------------------------------------------------------- 644-----------------------------------------------------------------------------------
642foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok 645foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok