diff options
author | Alberto Ruiz <aruiz@um.es> | 2008-11-04 09:32:35 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2008-11-04 09:32:35 +0000 |
commit | 02805ad64715373347b34bac2f75cbb866563ba2 (patch) | |
tree | 4eeb137ce0232d57ce98c0a0ced8fffe7baf7f99 /lib/Numeric/LinearAlgebra/Algorithms.hs | |
parent | 86c7aed1de8efe5988f994867d35addb6b62a655 (diff) |
multiply/trans ok
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 80 |
1 files changed, 4 insertions, 76 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 75f4ba3..f259db5 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -54,7 +54,6 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
54 | ctrans, | 54 | ctrans, |
55 | eps, i, | 55 | eps, i, |
56 | outer, kronecker, | 56 | outer, kronecker, |
57 | mulH, | ||
58 | -- * Util | 57 | -- * Util |
59 | haussholder, | 58 | haussholder, |
60 | unpackQR, unpackHess, | 59 | unpackQR, unpackHess, |
@@ -70,8 +69,8 @@ import Complex | |||
70 | import Numeric.LinearAlgebra.Linear | 69 | import Numeric.LinearAlgebra.Linear |
71 | import Data.List(foldl1') | 70 | import Data.List(foldl1') |
72 | import Data.Array | 71 | import Data.Array |
73 | import Foreign | 72 | |
74 | import Foreign.C.Types | 73 | |
75 | 74 | ||
76 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. | 75 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. |
77 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | 76 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where |
@@ -132,7 +131,7 @@ instance Field Double where | |||
132 | qr = unpackQR . qrR | 131 | qr = unpackQR . qrR |
133 | hess = unpackHess hessR | 132 | hess = unpackHess hessR |
134 | schur = schurR | 133 | schur = schurR |
135 | multiply = multiplyR3 | 134 | multiply = multiplyR |
136 | 135 | ||
137 | instance Field (Complex Double) where | 136 | instance Field (Complex Double) where |
138 | svd = svdC | 137 | svd = svdC |
@@ -147,7 +146,7 @@ instance Field (Complex Double) where | |||
147 | qr = unpackQR . qrC | 146 | qr = unpackQR . qrC |
148 | hess = unpackHess hessC | 147 | hess = unpackHess hessC |
149 | schur = schurC | 148 | schur = schurC |
150 | multiply = multiplyC3 | 149 | multiply = multiplyC |
151 | 150 | ||
152 | 151 | ||
153 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. | 152 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. |
@@ -567,74 +566,3 @@ kronecker a b = fromBlocks | |||
567 | . map (reshape (cols b)) | 566 | . map (reshape (cols b)) |
568 | . toRows | 567 | . toRows |
569 | $ flatten a `outer` flatten b | 568 | $ flatten a `outer` flatten b |
570 | |||
571 | --------------------------------------------------------------------- | ||
572 | -- reference multiply | ||
573 | --------------------------------------------------------------------- | ||
574 | |||
575 | mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] | ||
576 | where doth u v = sum $ zipWith (*) (toList u) (toList v) | ||
577 | |||
578 | ----------------------------------------------------------------------------------- | ||
579 | -- workaround | ||
580 | ----------------------------------------------------------------------------------- | ||
581 | |||
582 | mulCW a b = toComplex (rr,ri) | ||
583 | where rr = multiply ar br `sub` multiply ai bi | ||
584 | ri = multiply ar bi `add` multiply ai br | ||
585 | (ar,ai) = fromComplex a | ||
586 | (br,bi) = fromComplex b | ||
587 | |||
588 | ----------------------------------------------------------------------------------- | ||
589 | -- Direct CBLAS | ||
590 | ----------------------------------------------------------------------------------- | ||
591 | |||
592 | -- taken from Patrick Perry's BLAS package | ||
593 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) | ||
594 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) | ||
595 | |||
596 | rowMajor, colMajor :: CBLASOrder | ||
597 | rowMajor = CBLASOrder 101 | ||
598 | colMajor = CBLASOrder 102 | ||
599 | |||
600 | noTrans, trans', conjTrans :: CBLASTrans | ||
601 | noTrans = CBLASTrans 111 | ||
602 | trans' = CBLASTrans 112 | ||
603 | conjTrans = CBLASTrans 113 | ||
604 | |||
605 | foreign import ccall "cblas.h cblas_dgemm" | ||
606 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt | ||
607 | -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double | ||
608 | -> Ptr Double -> CInt -> IO () | ||
609 | |||
610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | ||
611 | multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y) | ||
612 | where | ||
613 | multiply3 f st a b | ||
614 | | cols a == rows b = unsafePerformIO $ do | ||
615 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
616 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 | ||
617 | app3 g mat a mat b mat s st | ||
618 | return s | ||
619 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
620 | |||
621 | |||
622 | foreign import ccall "cblas.h cblas_zgemm" | ||
623 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt | ||
624 | -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) | ||
625 | -> Ptr (Complex Double) -> CInt -> IO () | ||
626 | |||
627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
628 | multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y) | ||
629 | where | ||
630 | multiply3 f st a b | ||
631 | | cols a == rows b = do | ||
632 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
633 | palpha <- new 1 | ||
634 | pbeta <- new 0 | ||
635 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 | ||
636 | app3 g mat a mat b mat s st | ||
637 | free palpha | ||
638 | free pbeta | ||
639 | return s | ||
640 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||