summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra/Algorithms.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2008-11-04 09:32:35 +0000
committerAlberto Ruiz <aruiz@um.es>2008-11-04 09:32:35 +0000
commit02805ad64715373347b34bac2f75cbb866563ba2 (patch)
tree4eeb137ce0232d57ce98c0a0ced8fffe7baf7f99 /lib/Numeric/LinearAlgebra/Algorithms.hs
parent86c7aed1de8efe5988f994867d35addb6b62a655 (diff)
multiply/trans ok
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs80
1 files changed, 4 insertions, 76 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 75f4ba3..f259db5 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -54,7 +54,6 @@ module Numeric.LinearAlgebra.Algorithms (
54 ctrans, 54 ctrans,
55 eps, i, 55 eps, i,
56 outer, kronecker, 56 outer, kronecker,
57 mulH,
58-- * Util 57-- * Util
59 haussholder, 58 haussholder,
60 unpackQR, unpackHess, 59 unpackQR, unpackHess,
@@ -70,8 +69,8 @@ import Complex
70import Numeric.LinearAlgebra.Linear 69import Numeric.LinearAlgebra.Linear
71import Data.List(foldl1') 70import Data.List(foldl1')
72import Data.Array 71import Data.Array
73import Foreign 72
74import Foreign.C.Types 73
75 74
76-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 75-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
77class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 76class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
@@ -132,7 +131,7 @@ instance Field Double where
132 qr = unpackQR . qrR 131 qr = unpackQR . qrR
133 hess = unpackHess hessR 132 hess = unpackHess hessR
134 schur = schurR 133 schur = schurR
135 multiply = multiplyR3 134 multiply = multiplyR
136 135
137instance Field (Complex Double) where 136instance Field (Complex Double) where
138 svd = svdC 137 svd = svdC
@@ -147,7 +146,7 @@ instance Field (Complex Double) where
147 qr = unpackQR . qrC 146 qr = unpackQR . qrC
148 hess = unpackHess hessC 147 hess = unpackHess hessC
149 schur = schurC 148 schur = schurC
150 multiply = multiplyC3 149 multiply = multiplyC
151 150
152 151
153-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. 152-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev.
@@ -567,74 +566,3 @@ kronecker a b = fromBlocks
567 . map (reshape (cols b)) 566 . map (reshape (cols b))
568 . toRows 567 . toRows
569 $ flatten a `outer` flatten b 568 $ flatten a `outer` flatten b
570
571---------------------------------------------------------------------
572-- reference multiply
573---------------------------------------------------------------------
574
575mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
576 where doth u v = sum $ zipWith (*) (toList u) (toList v)
577
578-----------------------------------------------------------------------------------
579-- workaround
580-----------------------------------------------------------------------------------
581
582mulCW a b = toComplex (rr,ri)
583 where rr = multiply ar br `sub` multiply ai bi
584 ri = multiply ar bi `add` multiply ai br
585 (ar,ai) = fromComplex a
586 (br,bi) = fromComplex b
587
588-----------------------------------------------------------------------------------
589-- Direct CBLAS
590-----------------------------------------------------------------------------------
591
592-- taken from Patrick Perry's BLAS package
593newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show)
594newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show)
595
596rowMajor, colMajor :: CBLASOrder
597rowMajor = CBLASOrder 101
598colMajor = CBLASOrder 102
599
600noTrans, trans', conjTrans :: CBLASTrans
601noTrans = CBLASTrans 111
602trans' = CBLASTrans 112
603conjTrans = CBLASTrans 113
604
605foreign import ccall "cblas.h cblas_dgemm"
606 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
607 -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double
608 -> Ptr Double -> CInt -> IO ()
609
610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
611multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y)
612 where
613 multiply3 f st a b
614 | cols a == rows b = unsafePerformIO $ do
615 s <- createMatrix ColumnMajor (rows a) (cols b)
616 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0
617 app3 g mat a mat b mat s st
618 return s
619 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
620
621
622foreign import ccall "cblas.h cblas_zgemm"
623 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
624 -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double)
625 -> Ptr (Complex Double) -> CInt -> IO ()
626
627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
628multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y)
629 where
630 multiply3 f st a b
631 | cols a == rows b = do
632 s <- createMatrix ColumnMajor (rows a) (cols b)
633 palpha <- new 1
634 pbeta <- new 0
635 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0
636 app3 g mat a mat b mat s st
637 free palpha
638 free pbeta
639 return s
640 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"