diff options
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 27d1f95..d88ff6b 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -22,9 +22,12 @@ import Data.Bifunctor (first) | |||
22 | 22 | ||
23 | import Internal.Devel | 23 | import Internal.Devel |
24 | import Internal.Vector | 24 | import Internal.Vector |
25 | import Internal.Vectorized (constantAux) | ||
25 | import Internal.Matrix hiding ((#), (#!)) | 26 | import Internal.Matrix hiding ((#), (#!)) |
26 | import Internal.Conversion | 27 | import Internal.Conversion |
27 | import Internal.Element | 28 | import Internal.Element |
29 | import Internal.ST (setRect) | ||
30 | import Data.Int | ||
28 | import Foreign.Ptr(nullPtr) | 31 | import Foreign.Ptr(nullPtr) |
29 | import Foreign.C.Types | 32 | import Foreign.C.Types |
30 | import Control.Monad(when) | 33 | import Control.Monad(when) |
@@ -46,10 +49,10 @@ type TMMM t = t ::> t ::> t ::> Ok | |||
46 | type F = Float | 49 | type F = Float |
47 | type Q = Complex Float | 50 | type Q = Complex Float |
48 | 51 | ||
49 | foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R | 52 | foreign import ccall unsafe "multiplyR" dgemmc :: Int32 -> Int32 -> TMMM R |
50 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C | 53 | foreign import ccall unsafe "multiplyC" zgemmc :: Int32 -> Int32 -> TMMM C |
51 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F | 54 | foreign import ccall unsafe "multiplyF" sgemmc :: Int32 -> Int32 -> TMMM F |
52 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q | 55 | foreign import ccall unsafe "multiplyQ" cgemmc :: Int32 -> Int32 -> TMMM Q |
53 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I | 56 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I |
54 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z | 57 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z |
55 | 58 | ||
@@ -82,7 +85,7 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b | |||
82 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) | 85 | multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) |
83 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b | 86 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b |
84 | 87 | ||
85 | multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt | 88 | multiplyI :: I -> Matrix Int32 -> Matrix Int32 -> Matrix Int32 |
86 | multiplyI m a b = unsafePerformIO $ do | 89 | multiplyI m a b = unsafePerformIO $ do |
87 | when (cols a /= rows b) $ error $ | 90 | when (cols a /= rows b) $ error $ |
88 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 91 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
@@ -239,8 +242,8 @@ foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok | |||
239 | foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok | 242 | foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok |
240 | foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok | 243 | foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok |
241 | foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok | 244 | foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok |
242 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R :> R ::> Ok | 245 | foreign import ccall unsafe "eig_l_S" dsyev :: Int32 -> R :> R ::> Ok |
243 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> R :> C ::> Ok | 246 | foreign import ccall unsafe "eig_l_H" zheev :: Int32 -> R :> C ::> Ok |
244 | 247 | ||
245 | eigAux f st m = unsafePerformIO $ do | 248 | eigAux f st m = unsafePerformIO $ do |
246 | a <- copy ColumnMajor m | 249 | a <- copy ColumnMajor m |
@@ -636,7 +639,7 @@ qrgrAux f st n (a, tau) = unsafePerformIO $ do | |||
636 | ((subVector 0 n tau') #! res) f #| st | 639 | ((subVector 0 n tau') #! res) f #| st |
637 | return res | 640 | return res |
638 | where | 641 | where |
639 | tau' = vjoin [tau, constantD 0 n] | 642 | tau' = vjoin [tau, constantAux 0 n] |
640 | 643 | ||
641 | ----------------------------------------------------------------------------------- | 644 | ----------------------------------------------------------------------------------- |
642 | foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok | 645 | foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok |