summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2008-11-04 09:32:35 +0000
committerAlberto Ruiz <aruiz@um.es>2008-11-04 09:32:35 +0000
commit02805ad64715373347b34bac2f75cbb866563ba2 (patch)
tree4eeb137ce0232d57ce98c0a0ced8fffe7baf7f99 /lib
parent86c7aed1de8efe5988f994867d35addb6b62a655 (diff)
multiply/trans ok
Diffstat (limited to 'lib')
-rw-r--r--lib/Graphics/Plot.hs4
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs80
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs28
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c41
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h8
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs4
6 files changed, 86 insertions, 79 deletions
diff --git a/lib/Graphics/Plot.hs b/lib/Graphics/Plot.hs
index 9352048..bcb1fb3 100644
--- a/lib/Graphics/Plot.hs
+++ b/lib/Graphics/Plot.hs
@@ -9,7 +9,9 @@
9-- Portability : uses gnuplot and ImageMagick 9-- Portability : uses gnuplot and ImageMagick
10-- 10--
11-- Very basic (and provisional) drawing tools using gnuplot and imageMagick. 11-- Very basic (and provisional) drawing tools using gnuplot and imageMagick.
12-- 12--
13-- This module is deprecated. It will be replaced by improved drawing tools based
14-- on the Gnuplot package by Henning Thielemann.
13----------------------------------------------------------------------------- 15-----------------------------------------------------------------------------
14 16
15module Graphics.Plot( 17module Graphics.Plot(
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 75f4ba3..f259db5 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -54,7 +54,6 @@ module Numeric.LinearAlgebra.Algorithms (
54 ctrans, 54 ctrans,
55 eps, i, 55 eps, i,
56 outer, kronecker, 56 outer, kronecker,
57 mulH,
58-- * Util 57-- * Util
59 haussholder, 58 haussholder,
60 unpackQR, unpackHess, 59 unpackQR, unpackHess,
@@ -70,8 +69,8 @@ import Complex
70import Numeric.LinearAlgebra.Linear 69import Numeric.LinearAlgebra.Linear
71import Data.List(foldl1') 70import Data.List(foldl1')
72import Data.Array 71import Data.Array
73import Foreign 72
74import Foreign.C.Types 73
75 74
76-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 75-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
77class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 76class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
@@ -132,7 +131,7 @@ instance Field Double where
132 qr = unpackQR . qrR 131 qr = unpackQR . qrR
133 hess = unpackHess hessR 132 hess = unpackHess hessR
134 schur = schurR 133 schur = schurR
135 multiply = multiplyR3 134 multiply = multiplyR
136 135
137instance Field (Complex Double) where 136instance Field (Complex Double) where
138 svd = svdC 137 svd = svdC
@@ -147,7 +146,7 @@ instance Field (Complex Double) where
147 qr = unpackQR . qrC 146 qr = unpackQR . qrC
148 hess = unpackHess hessC 147 hess = unpackHess hessC
149 schur = schurC 148 schur = schurC
150 multiply = multiplyC3 149 multiply = multiplyC
151 150
152 151
153-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. 152-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev.
@@ -567,74 +566,3 @@ kronecker a b = fromBlocks
567 . map (reshape (cols b)) 566 . map (reshape (cols b))
568 . toRows 567 . toRows
569 $ flatten a `outer` flatten b 568 $ flatten a `outer` flatten b
570
571---------------------------------------------------------------------
572-- reference multiply
573---------------------------------------------------------------------
574
575mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
576 where doth u v = sum $ zipWith (*) (toList u) (toList v)
577
578-----------------------------------------------------------------------------------
579-- workaround
580-----------------------------------------------------------------------------------
581
582mulCW a b = toComplex (rr,ri)
583 where rr = multiply ar br `sub` multiply ai bi
584 ri = multiply ar bi `add` multiply ai br
585 (ar,ai) = fromComplex a
586 (br,bi) = fromComplex b
587
588-----------------------------------------------------------------------------------
589-- Direct CBLAS
590-----------------------------------------------------------------------------------
591
592-- taken from Patrick Perry's BLAS package
593newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show)
594newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show)
595
596rowMajor, colMajor :: CBLASOrder
597rowMajor = CBLASOrder 101
598colMajor = CBLASOrder 102
599
600noTrans, trans', conjTrans :: CBLASTrans
601noTrans = CBLASTrans 111
602trans' = CBLASTrans 112
603conjTrans = CBLASTrans 113
604
605foreign import ccall "cblas.h cblas_dgemm"
606 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
607 -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double
608 -> Ptr Double -> CInt -> IO ()
609
610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
611multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y)
612 where
613 multiply3 f st a b
614 | cols a == rows b = unsafePerformIO $ do
615 s <- createMatrix ColumnMajor (rows a) (cols b)
616 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0
617 app3 g mat a mat b mat s st
618 return s
619 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
620
621
622foreign import ccall "cblas.h cblas_zgemm"
623 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
624 -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double)
625 -> Ptr (Complex Double) -> CInt -> IO ()
626
627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
628multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y)
629 where
630 multiply3 f st a b
631 | cols a == rows b = do
632 s <- createMatrix ColumnMajor (rows a) (cols b)
633 palpha <- new 1
634 pbeta <- new 0
635 let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0
636 app3 g mat a mat b mat s st
637 free palpha
638 free pbeta
639 return s
640 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index 8bc2492..56945d7 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -14,6 +14,7 @@
14----------------------------------------------------------------------------- 14-----------------------------------------------------------------------------
15 15
16module Numeric.LinearAlgebra.LAPACK ( 16module Numeric.LinearAlgebra.LAPACK (
17 multiplyR, multiplyC,
17 svdR, svdRdd, svdC, 18 svdR, svdRdd, svdC,
18 eigC, eigR, eigS, eigH, eigS', eigH', 19 eigC, eigR, eigS, eigH, eigS', eigH',
19 linearSolveR, linearSolveC, 20 linearSolveR, linearSolveC,
@@ -35,6 +36,33 @@ import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale))
35import Complex 36import Complex
36import Foreign 37import Foreign
37import Foreign.C.Types (CInt) 38import Foreign.C.Types (CInt)
39import Control.Monad(when)
40
41-----------------------------------------------------------------------------------
42
43foreign import ccall "LAPACK/lapack-aux.h multiplyR" dgemmc :: CInt -> CInt -> TMMM
44foreign import ccall "LAPACK/lapack-aux.h multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM
45
46isT MF{} = 0
47isT MC{} = 1
48
49tt x@MF{} = x
50tt x@MC{} = trans x
51
52multiplyAux f st a b = unsafePerformIO $ do
53 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++
54 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
55 s <- createMatrix ColumnMajor (rows a) (cols b)
56 app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st
57 return s
58
59-- | Matrix product based on BLAS's /dgemm/.
60multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
61multiplyR a b = multiplyAux dgemmc "dgemmc" a b
62
63-- | Matrix product based on BLAS's /zgemm/.
64multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
65multiplyC a b = multiplyAux zgemmc "zgemmc" a b
38 66
39----------------------------------------------------------------------------- 67-----------------------------------------------------------------------------
40foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM 68foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index 842b5ad..e85c1b7 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -860,3 +860,44 @@ int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) {
860 free(auxipiv); 860 free(auxipiv);
861 OK 861 OK
862} 862}
863
864//////////////////// Matrix Product /////////////////////////
865
866void dgemm_(char *, char *, integer *, integer *, integer *,
867 double *, const double *, integer *, const double *,
868 integer *, double *, double *, integer *);
869
870int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) {
871 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
872 integer m = ta?ac:ar;
873 integer n = tb?br:bc;
874 integer k = ta?ar:ac;
875 integer lda = ar;
876 integer ldb = br;
877 integer ldc = rr;
878 double alpha = 1;
879 double beta = 0;
880 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
881 OK
882}
883
884void zgemm_(char *, char *, integer *, integer *, integer *,
885 doublecomplex *, const doublecomplex *, integer *, const doublecomplex *,
886 integer *, doublecomplex *, doublecomplex *, integer *);
887
888int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) {
889 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
890 integer m = ta?ac:ar;
891 integer n = tb?br:bc;
892 integer k = ta?ar:ac;
893 integer lda = ar;
894 integer ldb = br;
895 integer ldc = rr;
896 doublecomplex alpha = {1,0};
897 doublecomplex beta = {0,0};
898 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
899 (doublecomplex*)ap,&lda,
900 (doublecomplex*)bp,&ldb,&beta,
901 (doublecomplex*)rp,&ldc);
902 OK
903}
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index 23e5e28..3f58243 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -45,11 +45,15 @@ typedef short ftnlen;
45#define DMAT(A) int A##r, int A##c, double* A##p 45#define DMAT(A) int A##r, int A##c, double* A##p
46#define CMAT(A) int A##r, int A##c, double* A##p 46#define CMAT(A) int A##r, int A##c, double* A##p
47 47
48// const pointer versions for the parameters
49#define KDVEC(A) int A##n, const double*A##p 48#define KDVEC(A) int A##n, const double*A##p
50#define KCVEC(A) int A##n, const double*A##p 49#define KCVEC(A) int A##n, const double*A##p
51#define KDMAT(A) int A##r, int A##c, const double* A##p 50#define KDMAT(A) int A##r, int A##c, const double* A##p
52#define KCMAT(A) int A##r, int A##c, const double* A##p 51#define KCMAT(A) int A##r, int A##c, const double* A##p
52
53/********************************************************/
54
55int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r));
56int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r));
53 57
54int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 58int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
55int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 59int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index 45b03a2..ec87ad0 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -152,6 +152,10 @@ cholProp m = m |~| ctrans c <> c && upperTriang c
152expmDiagProp m = expm (logm m) :~ 7 ~: complex m 152expmDiagProp m = expm (logm m) :~ 7 ~: complex m
153 where logm = matFunc log 153 where logm = matFunc log
154 154
155-- reference multiply
156mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
157 where doth u v = sum $ zipWith (*) (toList u) (toList v)
158
155multProp1 (a,b) = a <> b |~| mulH a b 159multProp1 (a,b) = a <> b |~| mulH a b
156 160
157multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a 161multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a