summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra/Algorithms.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs48
1 files changed, 26 insertions, 22 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 00a0ab0..fbefa68 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -589,6 +589,7 @@ mulCW a b = toComplex (rr,ri)
589-- Direct CBLAS 589-- Direct CBLAS
590----------------------------------------------------------------------------------- 590-----------------------------------------------------------------------------------
591 591
592-- taken from Patrick Perry's BLAS package
592newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) 593newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show)
593newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) 594newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show)
594 595
@@ -602,8 +603,9 @@ trans' = CBLASTrans 112
602conjTrans = CBLASTrans 113 603conjTrans = CBLASTrans 113
603 604
604foreign import ccall "cblas.h cblas_dgemm" 605foreign import ccall "cblas.h cblas_dgemm"
605 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () 606 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
606 607 -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double
608 -> Ptr Double -> CInt -> IO ()
607 609
608multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double 610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
609multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) 611multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b)
@@ -618,7 +620,9 @@ multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b)
618 620
619 621
620foreign import ccall "cblas.h cblas_zgemm" 622foreign import ccall "cblas.h cblas_zgemm"
621 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () 623 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
624 -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double)
625 -> Ptr (Complex Double) -> CInt -> IO ()
622 626
623multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
624multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) 628multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b)
@@ -640,27 +644,27 @@ multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat
640-- BLAS via auxiliary C 644-- BLAS via auxiliary C
641----------------------------------------------------------------------------------- 645-----------------------------------------------------------------------------------
642 646
643foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM 647-- foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM
644foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM 648-- foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM
645 649--
646multiply2 f st a b 650-- multiply2 f st a b
647 | cols a == rows b = unsafePerformIO $ do 651-- | cols a == rows b = unsafePerformIO $ do
648 s <- createMatrix ColumnMajor (rows a) (cols b) 652-- s <- createMatrix ColumnMajor (rows a) (cols b)
649 app3 f mat a mat b mat s st 653-- app3 f mat a mat b mat s st
650 if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) 654-- if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s))
651 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" 655-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
652 656--
653multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double 657-- multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double
654multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) 658-- multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b)
655 659--
656multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 660-- multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
657multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) 661-- multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b)
658 662
659----------------------------------------------------------------------------------- 663-----------------------------------------------------------------------------------
660-- direct C multiplication 664-- direct C multiplication, to expose the NaN bug
661----------------------------------------------------------------------------------- 665-----------------------------------------------------------------------------------
662 666
663foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM 667-- foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM
664foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM 668foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM
665 669
666cmultiply f st a b 670cmultiply f st a b
@@ -674,8 +678,8 @@ cmultiply f st a b
674 -- return s 678 -- return s
675-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" 679-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
676 680
677multiplyR :: Matrix Double -> Matrix Double -> Matrix Double 681-- multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
678multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) 682-- multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b)
679 683
680multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 684multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
681multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) 685multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b)