diff options
author | Alberto Ruiz <aruiz@um.es> | 2008-11-04 09:32:35 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2008-11-04 09:32:35 +0000 |
commit | 02805ad64715373347b34bac2f75cbb866563ba2 (patch) | |
tree | 4eeb137ce0232d57ce98c0a0ced8fffe7baf7f99 /lib | |
parent | 86c7aed1de8efe5988f994867d35addb6b62a655 (diff) |
multiply/trans ok
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Graphics/Plot.hs | 4 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 80 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK.hs | 28 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 41 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 8 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Properties.hs | 4 |
6 files changed, 86 insertions, 79 deletions
diff --git a/lib/Graphics/Plot.hs b/lib/Graphics/Plot.hs index 9352048..bcb1fb3 100644 --- a/lib/Graphics/Plot.hs +++ b/lib/Graphics/Plot.hs | |||
@@ -9,7 +9,9 @@ | |||
9 | -- Portability : uses gnuplot and ImageMagick | 9 | -- Portability : uses gnuplot and ImageMagick |
10 | -- | 10 | -- |
11 | -- Very basic (and provisional) drawing tools using gnuplot and imageMagick. | 11 | -- Very basic (and provisional) drawing tools using gnuplot and imageMagick. |
12 | -- | 12 | -- |
13 | -- This module is deprecated. It will be replaced by improved drawing tools based | ||
14 | -- on the Gnuplot package by Henning Thielemann. | ||
13 | ----------------------------------------------------------------------------- | 15 | ----------------------------------------------------------------------------- |
14 | 16 | ||
15 | module Graphics.Plot( | 17 | module Graphics.Plot( |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 75f4ba3..f259db5 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -54,7 +54,6 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
54 | ctrans, | 54 | ctrans, |
55 | eps, i, | 55 | eps, i, |
56 | outer, kronecker, | 56 | outer, kronecker, |
57 | mulH, | ||
58 | -- * Util | 57 | -- * Util |
59 | haussholder, | 58 | haussholder, |
60 | unpackQR, unpackHess, | 59 | unpackQR, unpackHess, |
@@ -70,8 +69,8 @@ import Complex | |||
70 | import Numeric.LinearAlgebra.Linear | 69 | import Numeric.LinearAlgebra.Linear |
71 | import Data.List(foldl1') | 70 | import Data.List(foldl1') |
72 | import Data.Array | 71 | import Data.Array |
73 | import Foreign | 72 | |
74 | import Foreign.C.Types | 73 | |
75 | 74 | ||
76 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. | 75 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. |
77 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | 76 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where |
@@ -132,7 +131,7 @@ instance Field Double where | |||
132 | qr = unpackQR . qrR | 131 | qr = unpackQR . qrR |
133 | hess = unpackHess hessR | 132 | hess = unpackHess hessR |
134 | schur = schurR | 133 | schur = schurR |
135 | multiply = multiplyR3 | 134 | multiply = multiplyR |
136 | 135 | ||
137 | instance Field (Complex Double) where | 136 | instance Field (Complex Double) where |
138 | svd = svdC | 137 | svd = svdC |
@@ -147,7 +146,7 @@ instance Field (Complex Double) where | |||
147 | qr = unpackQR . qrC | 146 | qr = unpackQR . qrC |
148 | hess = unpackHess hessC | 147 | hess = unpackHess hessC |
149 | schur = schurC | 148 | schur = schurC |
150 | multiply = multiplyC3 | 149 | multiply = multiplyC |
151 | 150 | ||
152 | 151 | ||
153 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. | 152 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. |
@@ -567,74 +566,3 @@ kronecker a b = fromBlocks | |||
567 | . map (reshape (cols b)) | 566 | . map (reshape (cols b)) |
568 | . toRows | 567 | . toRows |
569 | $ flatten a `outer` flatten b | 568 | $ flatten a `outer` flatten b |
570 | |||
571 | --------------------------------------------------------------------- | ||
572 | -- reference multiply | ||
573 | --------------------------------------------------------------------- | ||
574 | |||
575 | mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] | ||
576 | where doth u v = sum $ zipWith (*) (toList u) (toList v) | ||
577 | |||
578 | ----------------------------------------------------------------------------------- | ||
579 | -- workaround | ||
580 | ----------------------------------------------------------------------------------- | ||
581 | |||
582 | mulCW a b = toComplex (rr,ri) | ||
583 | where rr = multiply ar br `sub` multiply ai bi | ||
584 | ri = multiply ar bi `add` multiply ai br | ||
585 | (ar,ai) = fromComplex a | ||
586 | (br,bi) = fromComplex b | ||
587 | |||
588 | ----------------------------------------------------------------------------------- | ||
589 | -- Direct CBLAS | ||
590 | ----------------------------------------------------------------------------------- | ||
591 | |||
592 | -- taken from Patrick Perry's BLAS package | ||
593 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) | ||
594 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) | ||
595 | |||
596 | rowMajor, colMajor :: CBLASOrder | ||
597 | rowMajor = CBLASOrder 101 | ||
598 | colMajor = CBLASOrder 102 | ||
599 | |||
600 | noTrans, trans', conjTrans :: CBLASTrans | ||
601 | noTrans = CBLASTrans 111 | ||
602 | trans' = CBLASTrans 112 | ||
603 | conjTrans = CBLASTrans 113 | ||
604 | |||
605 | foreign import ccall "cblas.h cblas_dgemm" | ||
606 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt | ||
607 | -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double | ||
608 | -> Ptr Double -> CInt -> IO () | ||
609 | |||
610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | ||
611 | multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y) | ||
612 | where | ||
613 | multiply3 f st a b | ||
614 | | cols a == rows b = unsafePerformIO $ do | ||
615 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
616 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 | ||
617 | app3 g mat a mat b mat s st | ||
618 | return s | ||
619 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
620 | |||
621 | |||
622 | foreign import ccall "cblas.h cblas_zgemm" | ||
623 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt | ||
624 | -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) | ||
625 | -> Ptr (Complex Double) -> CInt -> IO () | ||
626 | |||
627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
628 | multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y) | ||
629 | where | ||
630 | multiply3 f st a b | ||
631 | | cols a == rows b = do | ||
632 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
633 | palpha <- new 1 | ||
634 | pbeta <- new 0 | ||
635 | let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 | ||
636 | app3 g mat a mat b mat s st | ||
637 | free palpha | ||
638 | free pbeta | ||
639 | return s | ||
640 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 8bc2492..56945d7 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -14,6 +14,7 @@ | |||
14 | ----------------------------------------------------------------------------- | 14 | ----------------------------------------------------------------------------- |
15 | 15 | ||
16 | module Numeric.LinearAlgebra.LAPACK ( | 16 | module Numeric.LinearAlgebra.LAPACK ( |
17 | multiplyR, multiplyC, | ||
17 | svdR, svdRdd, svdC, | 18 | svdR, svdRdd, svdC, |
18 | eigC, eigR, eigS, eigH, eigS', eigH', | 19 | eigC, eigR, eigS, eigH, eigS', eigH', |
19 | linearSolveR, linearSolveC, | 20 | linearSolveR, linearSolveC, |
@@ -35,6 +36,33 @@ import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) | |||
35 | import Complex | 36 | import Complex |
36 | import Foreign | 37 | import Foreign |
37 | import Foreign.C.Types (CInt) | 38 | import Foreign.C.Types (CInt) |
39 | import Control.Monad(when) | ||
40 | |||
41 | ----------------------------------------------------------------------------------- | ||
42 | |||
43 | foreign import ccall "LAPACK/lapack-aux.h multiplyR" dgemmc :: CInt -> CInt -> TMMM | ||
44 | foreign import ccall "LAPACK/lapack-aux.h multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM | ||
45 | |||
46 | isT MF{} = 0 | ||
47 | isT MC{} = 1 | ||
48 | |||
49 | tt x@MF{} = x | ||
50 | tt x@MC{} = trans x | ||
51 | |||
52 | multiplyAux f st a b = unsafePerformIO $ do | ||
53 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ | ||
54 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
55 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
56 | app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st | ||
57 | return s | ||
58 | |||
59 | -- | Matrix product based on BLAS's /dgemm/. | ||
60 | multiplyR :: Matrix Double -> Matrix Double -> Matrix Double | ||
61 | multiplyR a b = multiplyAux dgemmc "dgemmc" a b | ||
62 | |||
63 | -- | Matrix product based on BLAS's /zgemm/. | ||
64 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
65 | multiplyC a b = multiplyAux zgemmc "zgemmc" a b | ||
38 | 66 | ||
39 | ----------------------------------------------------------------------------- | 67 | ----------------------------------------------------------------------------- |
40 | foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM | 68 | foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM |
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 842b5ad..e85c1b7 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -860,3 +860,44 @@ int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { | |||
860 | free(auxipiv); | 860 | free(auxipiv); |
861 | OK | 861 | OK |
862 | } | 862 | } |
863 | |||
864 | //////////////////// Matrix Product ///////////////////////// | ||
865 | |||
866 | void dgemm_(char *, char *, integer *, integer *, integer *, | ||
867 | double *, const double *, integer *, const double *, | ||
868 | integer *, double *, double *, integer *); | ||
869 | |||
870 | int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) { | ||
871 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
872 | integer m = ta?ac:ar; | ||
873 | integer n = tb?br:bc; | ||
874 | integer k = ta?ar:ac; | ||
875 | integer lda = ar; | ||
876 | integer ldb = br; | ||
877 | integer ldc = rr; | ||
878 | double alpha = 1; | ||
879 | double beta = 0; | ||
880 | dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); | ||
881 | OK | ||
882 | } | ||
883 | |||
884 | void zgemm_(char *, char *, integer *, integer *, integer *, | ||
885 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, | ||
886 | integer *, doublecomplex *, doublecomplex *, integer *); | ||
887 | |||
888 | int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) { | ||
889 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
890 | integer m = ta?ac:ar; | ||
891 | integer n = tb?br:bc; | ||
892 | integer k = ta?ar:ac; | ||
893 | integer lda = ar; | ||
894 | integer ldb = br; | ||
895 | integer ldc = rr; | ||
896 | doublecomplex alpha = {1,0}; | ||
897 | doublecomplex beta = {0,0}; | ||
898 | zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, | ||
899 | (doublecomplex*)ap,&lda, | ||
900 | (doublecomplex*)bp,&ldb,&beta, | ||
901 | (doublecomplex*)rp,&ldc); | ||
902 | OK | ||
903 | } | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 23e5e28..3f58243 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -45,11 +45,15 @@ typedef short ftnlen; | |||
45 | #define DMAT(A) int A##r, int A##c, double* A##p | 45 | #define DMAT(A) int A##r, int A##c, double* A##p |
46 | #define CMAT(A) int A##r, int A##c, double* A##p | 46 | #define CMAT(A) int A##r, int A##c, double* A##p |
47 | 47 | ||
48 | // const pointer versions for the parameters | ||
49 | #define KDVEC(A) int A##n, const double*A##p | 48 | #define KDVEC(A) int A##n, const double*A##p |
50 | #define KCVEC(A) int A##n, const double*A##p | 49 | #define KCVEC(A) int A##n, const double*A##p |
51 | #define KDMAT(A) int A##r, int A##c, const double* A##p | 50 | #define KDMAT(A) int A##r, int A##c, const double* A##p |
52 | #define KCMAT(A) int A##r, int A##c, const double* A##p | 51 | #define KCMAT(A) int A##r, int A##c, const double* A##p |
52 | |||
53 | /********************************************************/ | ||
54 | |||
55 | int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)); | ||
56 | int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)); | ||
53 | 57 | ||
54 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 58 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
55 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 59 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index 45b03a2..ec87ad0 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -152,6 +152,10 @@ cholProp m = m |~| ctrans c <> c && upperTriang c | |||
152 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m | 152 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m |
153 | where logm = matFunc log | 153 | where logm = matFunc log |
154 | 154 | ||
155 | -- reference multiply | ||
156 | mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] | ||
157 | where doth u v = sum $ zipWith (*) (toList u) (toList v) | ||
158 | |||
155 | multProp1 (a,b) = a <> b |~| mulH a b | 159 | multProp1 (a,b) = a <> b |~| mulH a b |
156 | 160 | ||
157 | multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a | 161 | multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a |