diff options
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 48 |
1 files changed, 26 insertions, 22 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 00a0ab0..fbefa68 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -589,6 +589,7 @@ mulCW a b = toComplex (rr,ri) | |||
589 | -- Direct CBLAS | 589 | -- Direct CBLAS |
590 | ----------------------------------------------------------------------------------- | 590 | ----------------------------------------------------------------------------------- |
591 | 591 | ||
592 | -- taken from Patrick Perry's BLAS package | ||
592 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) | 593 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) |
593 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) | 594 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) |
594 | 595 | ||
@@ -602,8 +603,9 @@ trans' = CBLASTrans 112 | |||
602 | conjTrans = CBLASTrans 113 | 603 | conjTrans = CBLASTrans 113 |
603 | 604 | ||
604 | foreign import ccall "cblas.h cblas_dgemm" | 605 | foreign import ccall "cblas.h cblas_dgemm" |
605 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () | 606 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt |
606 | 607 | -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double | |
608 | -> Ptr Double -> CInt -> IO () | ||
607 | 609 | ||
608 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | 610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double |
609 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | 611 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) |
@@ -618,7 +620,9 @@ multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | |||
618 | 620 | ||
619 | 621 | ||
620 | foreign import ccall "cblas.h cblas_zgemm" | 622 | foreign import ccall "cblas.h cblas_zgemm" |
621 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () | 623 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt |
624 | -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) | ||
625 | -> Ptr (Complex Double) -> CInt -> IO () | ||
622 | 626 | ||
623 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
624 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) | 628 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) |
@@ -640,27 +644,27 @@ multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat | |||
640 | -- BLAS via auxiliary C | 644 | -- BLAS via auxiliary C |
641 | ----------------------------------------------------------------------------------- | 645 | ----------------------------------------------------------------------------------- |
642 | 646 | ||
643 | foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM | 647 | -- foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM |
644 | foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM | 648 | -- foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM |
645 | 649 | -- | |
646 | multiply2 f st a b | 650 | -- multiply2 f st a b |
647 | | cols a == rows b = unsafePerformIO $ do | 651 | -- | cols a == rows b = unsafePerformIO $ do |
648 | s <- createMatrix ColumnMajor (rows a) (cols b) | 652 | -- s <- createMatrix ColumnMajor (rows a) (cols b) |
649 | app3 f mat a mat b mat s st | 653 | -- app3 f mat a mat b mat s st |
650 | if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) | 654 | -- if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) |
651 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | 655 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" |
652 | 656 | -- | |
653 | multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double | 657 | -- multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double |
654 | multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) | 658 | -- multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) |
655 | 659 | -- | |
656 | multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 660 | -- multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
657 | multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) | 661 | -- multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) |
658 | 662 | ||
659 | ----------------------------------------------------------------------------------- | 663 | ----------------------------------------------------------------------------------- |
660 | -- direct C multiplication | 664 | -- direct C multiplication, to expose the NaN bug |
661 | ----------------------------------------------------------------------------------- | 665 | ----------------------------------------------------------------------------------- |
662 | 666 | ||
663 | foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM | 667 | -- foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM |
664 | foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM | 668 | foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM |
665 | 669 | ||
666 | cmultiply f st a b | 670 | cmultiply f st a b |
@@ -674,8 +678,8 @@ cmultiply f st a b | |||
674 | -- return s | 678 | -- return s |
675 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | 679 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" |
676 | 680 | ||
677 | multiplyR :: Matrix Double -> Matrix Double -> Matrix Double | 681 | -- multiplyR :: Matrix Double -> Matrix Double -> Matrix Double |
678 | multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) | 682 | -- multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) |
679 | 683 | ||
680 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 684 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
681 | multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) | 685 | multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) |