summaryrefslogtreecommitdiff
path: root/lib/Numeric
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2010-08-26 17:49:45 +0000
committerAlberto Ruiz <aruiz@um.es>2010-08-26 17:49:45 +0000
commit6058e1b17c005be1ea95ebb7d98d9fd15bb538d2 (patch)
treec4277e00c2c92a0ed8f3750255154fa8e2b6fe2d /lib/Numeric
parentf541d7dbdc8338b1dd1c0538751d837a16740bd8 (diff)
Float matrix product
Diffstat (limited to 'lib/Numeric')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs81
-rw-r--r--lib/Numeric/LinearAlgebra/Interface.hs2
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs18
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c81
-rw-r--r--lib/Numeric/LinearAlgebra/Linear.hs90
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs16
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs6
7 files changed, 194 insertions, 100 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index f4b7ee9..8962c60 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -21,9 +21,6 @@ imported from "Numeric.LinearAlgebra.LAPACK".
21module Numeric.LinearAlgebra.Algorithms ( 21module Numeric.LinearAlgebra.Algorithms (
22-- * Supported types 22-- * Supported types
23 Field(), 23 Field(),
24-- * Products
25 multiply, -- dot, moved dot to typeclass
26 outer, kronecker,
27-- * Linear Systems 24-- * Linear Systems
28 linearSolve, 25 linearSolve,
29 luSolve, 26 luSolve,
@@ -64,7 +61,6 @@ module Numeric.LinearAlgebra.Algorithms (
64-- * Norms 61-- * Norms
65 Normed(..), NormType(..), 62 Normed(..), NormType(..),
66-- * Misc 63-- * Misc
67 ctrans,
68 eps, i, 64 eps, i,
69-- * Util 65-- * Util
70 haussholder, 66 haussholder,
@@ -86,7 +82,7 @@ import Data.List(foldl1')
86import Data.Array 82import Data.Array
87 83
88-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 84-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
89class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 85class (Prod t, Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
90 svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) 86 svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t)
91 thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) 87 thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t)
92 sv' :: Matrix t -> Vector Double 88 sv' :: Matrix t -> Vector Double
@@ -105,8 +101,6 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
105 qr' :: Matrix t -> (Matrix t, Matrix t) 101 qr' :: Matrix t -> (Matrix t, Matrix t)
106 hess' :: Matrix t -> (Matrix t, Matrix t) 102 hess' :: Matrix t -> (Matrix t, Matrix t)
107 schur' :: Matrix t -> (Matrix t, Matrix t) 103 schur' :: Matrix t -> (Matrix t, Matrix t)
108 ctrans' :: Matrix t -> Matrix t
109 multiply' :: Matrix t -> Matrix t -> Matrix t
110 104
111 105
112instance Field Double where 106instance Field Double where
@@ -119,7 +113,6 @@ instance Field Double where
119 cholSolve' = cholSolveR 113 cholSolve' = cholSolveR
120 linearSolveLS' = linearSolveLSR 114 linearSolveLS' = linearSolveLSR
121 linearSolveSVD' = linearSolveSVDR Nothing 115 linearSolveSVD' = linearSolveSVDR Nothing
122 ctrans' = trans
123 eig' = eigR 116 eig' = eigR
124 eigSH'' = eigS 117 eigSH'' = eigS
125 eigOnly = eigOnlyR 118 eigOnly = eigOnlyR
@@ -129,7 +122,6 @@ instance Field Double where
129 qr' = unpackQR . qrR 122 qr' = unpackQR . qrR
130 hess' = unpackHess hessR 123 hess' = unpackHess hessR
131 schur' = schurR 124 schur' = schurR
132 multiply' = multiplyR
133 125
134instance Field (Complex Double) where 126instance Field (Complex Double) where
135#ifdef NOZGESDD 127#ifdef NOZGESDD
@@ -146,7 +138,6 @@ instance Field (Complex Double) where
146 cholSolve' = cholSolveC 138 cholSolve' = cholSolveC
147 linearSolveLS' = linearSolveLSC 139 linearSolveLS' = linearSolveLSC
148 linearSolveSVD' = linearSolveSVDC Nothing 140 linearSolveSVD' = linearSolveSVDC Nothing
149 ctrans' = conj . trans
150 eig' = eigC 141 eig' = eigC
151 eigOnly = eigOnlyC 142 eigOnly = eigOnlyC
152 eigSH'' = eigH 143 eigSH'' = eigH
@@ -156,7 +147,6 @@ instance Field (Complex Double) where
156 qr' = unpackQR . qrC 147 qr' = unpackQR . qrC
157 hess' = unpackHess hessC 148 hess' = unpackHess hessC
158 schur' = schurC 149 schur' = schurC
159 multiply' = multiplyC
160 150
161-------------------------------------------------------------- 151--------------------------------------------------------------
162 152
@@ -324,13 +314,6 @@ hess = hess'
324schur :: Field t => Matrix t -> (Matrix t, Matrix t) 314schur :: Field t => Matrix t -> (Matrix t, Matrix t)
325schur = schur' 315schur = schur'
326 316
327-- | Generic conjugate transpose.
328ctrans :: Field t => Matrix t -> Matrix t
329ctrans = ctrans'
330
331-- | Matrix product.
332multiply :: Field t => Matrix t -> Matrix t -> Matrix t
333multiply = {-# SCC "multiply" #-} multiply'
334 317
335-- | Similar to 'cholSH', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'. 318-- | Similar to 'cholSH', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'.
336mbCholSH :: Field t => Matrix t -> Maybe (Matrix t) 319mbCholSH :: Field t => Matrix t -> Maybe (Matrix t)
@@ -404,20 +387,6 @@ peps x = 2.0**(fromIntegral $ 1-floatDigits x)
404i :: Complex Double 387i :: Complex Double
405i = 0:+1 388i = 0:+1
406 389
407
408-- matrix product
409mXm :: (Num t, Field t) => Matrix t -> Matrix t -> Matrix t
410mXm = multiply
411
412-- matrix - vector product
413mXv :: (Num t, Field t) => Matrix t -> Vector t -> Vector t
414mXv m v = flatten $ m `mXm` (asColumn v)
415
416-- vector - matrix product
417vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t
418vXm v m = flatten $ (asRow v) `mXm` m
419
420
421--------------------------------------------------------------------------- 390---------------------------------------------------------------------------
422 391
423norm2 :: Vector Double -> Double 392norm2 :: Vector Double -> Double
@@ -723,51 +692,3 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s)
723 (|*|) = mul 692 (|*|) = mul
724 693
725-------------------------------------------------- 694--------------------------------------------------
726
727{- moved to Numeric.LinearAlgebra.Interface Vector typeclass
728-- | Euclidean inner product.
729dot :: (Field t) => Vector t -> Vector t -> t
730dot u v = multiply r c @@> (0,0)
731 where r = asRow u
732 c = asColumn v
733-}
734
735{- | Outer product of two vectors.
736
737@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
738(3><3)
739 [ 5.0, 2.0, 3.0
740 , 10.0, 4.0, 6.0
741 , 15.0, 6.0, 9.0 ]@
742-}
743outer :: (Field t) => Vector t -> Vector t -> Matrix t
744outer u v = asColumn u `multiply` asRow v
745
746{- | Kronecker product of two matrices.
747
748@m1=(2><3)
749 [ 1.0, 2.0, 0.0
750 , 0.0, -1.0, 3.0 ]
751m2=(4><3)
752 [ 1.0, 2.0, 3.0
753 , 4.0, 5.0, 6.0
754 , 7.0, 8.0, 9.0
755 , 10.0, 11.0, 12.0 ]@
756
757@\> kronecker m1 m2
758(8><9)
759 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
760 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
761 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
762 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
763 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
764 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
765 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
766 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
767-}
768kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t
769kronecker a b = fromBlocks
770 . splitEvery (cols a)
771 . map (reshape (cols b))
772 . toRows
773 $ flatten a `outer` flatten b
diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs
index f8917a0..6df782f 100644
--- a/lib/Numeric/LinearAlgebra/Interface.hs
+++ b/lib/Numeric/LinearAlgebra/Interface.hs
@@ -35,7 +35,7 @@ import Numeric.LinearAlgebra.Linear
35class Mul a b c | a b -> c where 35class Mul a b c | a b -> c where
36 infixl 7 <> 36 infixl 7 <>
37 -- | Matrix-matrix, matrix-vector, and vector-matrix products. 37 -- | Matrix-matrix, matrix-vector, and vector-matrix products.
38 (<>) :: Field t => a t -> b t -> c t 38 (<>) :: Prod t => a t -> b t -> c t
39 39
40instance Mul Matrix Matrix Matrix where 40instance Mul Matrix Matrix Matrix where
41 (<>) = multiply 41 (<>) = multiply
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index 7f057ba..eec3035 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -14,7 +14,7 @@
14 14
15module Numeric.LinearAlgebra.LAPACK ( 15module Numeric.LinearAlgebra.LAPACK (
16 -- * Matrix product 16 -- * Matrix product
17 multiplyR, multiplyC, 17 multiplyR, multiplyC, multiplyF, multiplyQ,
18 -- * Linear systems 18 -- * Linear systems
19 linearSolveR, linearSolveC, 19 linearSolveR, linearSolveC,
20 lusR, lusC, 20 lusR, lusC,
@@ -51,8 +51,10 @@ import Control.Monad(when)
51 51
52----------------------------------------------------------------------------------- 52-----------------------------------------------------------------------------------
53 53
54foreign import ccall "LAPACK/lapack-aux.h multiplyR" dgemmc :: CInt -> CInt -> TMMM 54foreign import ccall "multiplyR" dgemmc :: CInt -> CInt -> TMMM
55foreign import ccall "LAPACK/lapack-aux.h multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM 55foreign import ccall "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM
56foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM
57foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM
56 58
57isT MF{} = 0 59isT MF{} = 0
58isT MC{} = 1 60isT MC{} = 1
@@ -69,12 +71,20 @@ multiplyAux f st a b = unsafePerformIO $ do
69 71
70-- | Matrix product based on BLAS's /dgemm/. 72-- | Matrix product based on BLAS's /dgemm/.
71multiplyR :: Matrix Double -> Matrix Double -> Matrix Double 73multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
72multiplyR a b = multiplyAux dgemmc "dgemmc" a b 74multiplyR a b = {-# SCC "multiplyR" #-} multiplyAux dgemmc "dgemmc" a b
73 75
74-- | Matrix product based on BLAS's /zgemm/. 76-- | Matrix product based on BLAS's /zgemm/.
75multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 77multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
76multiplyC a b = multiplyAux zgemmc "zgemmc" a b 78multiplyC a b = multiplyAux zgemmc "zgemmc" a b
77 79
80-- | Matrix product based on BLAS's /sgemm/.
81multiplyF :: Matrix Float -> Matrix Float -> Matrix Float
82multiplyF a b = multiplyAux sgemmc "sgemmc" a b
83
84-- | Matrix product based on BLAS's /cgemm/.
85multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float)
86multiplyQ a b = multiplyAux cgemmc "cgemmc" a b
87
78----------------------------------------------------------------------------- 88-----------------------------------------------------------------------------
79foreign import ccall "svd_l_R" dgesvd :: TMMVM 89foreign import ccall "svd_l_R" dgesvd :: TMMVM
80foreign import ccall "svd_l_C" zgesvd :: TCMCMVCM 90foreign import ccall "svd_l_C" zgesvd :: TCMCMVCM
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index 7a40991..9e44431 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -11,15 +11,25 @@
11 11
12#define MIN(A,B) ((A)<(B)?(A):(B)) 12#define MIN(A,B) ((A)<(B)?(A):(B))
13#define MAX(A,B) ((A)>(B)?(A):(B)) 13#define MAX(A,B) ((A)>(B)?(A):(B))
14 14
15// #define DBGL
16
15#ifdef DBGL 17#ifdef DBGL
16#define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL); 18#define DEBUGMSG(M) printf("\nLAPACK "M"\n");
17#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
18#else 19#else
19#define DEBUGMSG(M) 20#define DEBUGMSG(M)
20#define OK return 0;
21#endif 21#endif
22 22
23#define OK return 0;
24
25// #ifdef DBGL
26// #define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL);
27// #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
28// #else
29// #define DEBUGMSG(M)
30// #define OK return 0;
31// #endif
32
23#define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ 33#define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \
24 for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");} 34 for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");}
25 35
@@ -1004,6 +1014,7 @@ void dgemm_(char *, char *, integer *, integer *, integer *,
1004 1014
1005int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) { 1015int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) {
1006 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 1016 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1017 DEBUGMSG("dgemm_");
1007 integer m = ta?ac:ar; 1018 integer m = ta?ac:ar;
1008 integer n = tb?br:bc; 1019 integer n = tb?br:bc;
1009 integer k = ta?ar:ac; 1020 integer k = ta?ar:ac;
@@ -1022,6 +1033,7 @@ void zgemm_(char *, char *, integer *, integer *, integer *,
1022 1033
1023int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) { 1034int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) {
1024 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 1035 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1036 DEBUGMSG("zgemm_");
1025 integer m = ta?ac:ar; 1037 integer m = ta?ac:ar;
1026 integer n = tb?br:bc; 1038 integer n = tb?br:bc;
1027 integer k = ta?ar:ac; 1039 integer k = ta?ar:ac;
@@ -1037,6 +1049,47 @@ int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) {
1037 OK 1049 OK
1038} 1050}
1039 1051
1052void sgemm_(char *, char *, integer *, integer *, integer *,
1053 float *, const float *, integer *, const float *,
1054 integer *, float *, float *, integer *);
1055
1056int multiplyF(int ta, int tb, KFMAT(a),KFMAT(b),FMAT(r)) {
1057 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1058 DEBUGMSG("sgemm_");
1059 integer m = ta?ac:ar;
1060 integer n = tb?br:bc;
1061 integer k = ta?ar:ac;
1062 integer lda = ar;
1063 integer ldb = br;
1064 integer ldc = rr;
1065 float alpha = 1;
1066 float beta = 0;
1067 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
1068 OK
1069}
1070
1071void cgemm_(char *, char *, integer *, integer *, integer *,
1072 complex *, const complex *, integer *, const complex *,
1073 integer *, complex *, complex *, integer *);
1074
1075int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) {
1076 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1077 DEBUGMSG("cgemm_");
1078 integer m = ta?ac:ar;
1079 integer n = tb?br:bc;
1080 integer k = ta?ar:ac;
1081 integer lda = ar;
1082 integer ldb = br;
1083 integer ldc = rr;
1084 complex alpha = {1,0};
1085 complex beta = {0,0};
1086 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
1087 (complex*)ap,&lda,
1088 (complex*)bp,&ldb,&beta,
1089 (complex*)rp,&ldc);
1090 OK
1091}
1092
1040//////////////////// transpose ///////////////////////// 1093//////////////////// transpose /////////////////////////
1041 1094
1042int transF(KFMAT(x),FMAT(t)) { 1095int transF(KFMAT(x),FMAT(t)) {
@@ -1128,3 +1181,23 @@ int constantC(doublecomplex* pval, CVEC(r)) {
1128 } 1181 }
1129 OK 1182 OK
1130} 1183}
1184
1185//////////////////// float-double conversion /////////////////////////
1186
1187int float2double(FVEC(x),DVEC(y)) {
1188 DEBUGMSG("float2double")
1189 int k;
1190 for(k=0;k<xn;k++) {
1191 yp[k]=xp[k];
1192 }
1193 OK
1194}
1195
1196int double2float(DVEC(x),FVEC(y)) {
1197 DEBUGMSG("double2float")
1198 int k;
1199 for(k=0;k<xn;k++) {
1200 yp[k]=xp[k];
1201 }
1202 OK
1203}
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs
index 51e93fb..ae48245 100644
--- a/lib/Numeric/LinearAlgebra/Linear.hs
+++ b/lib/Numeric/LinearAlgebra/Linear.hs
@@ -19,15 +19,19 @@ module Numeric.LinearAlgebra.Linear (
19 -- * Linear Algebra Typeclasses 19 -- * Linear Algebra Typeclasses
20 Vectors(..), 20 Vectors(..),
21 Linear(..), 21 Linear(..),
22 -- * Products
23 Prod(..),
24 mXm,mXv,vXm, mulH,
25 outer, kronecker,
22 -- * Creation of numeric vectors 26 -- * Creation of numeric vectors
23 constant, linspace 27 constant, linspace
24) where 28) where
25 29
26import Data.Packed.Internal.Vector 30import Data.Packed.Internal
27import Data.Packed.Internal.Matrix
28import Data.Packed.Matrix 31import Data.Packed.Matrix
29import Data.Complex 32import Data.Complex
30import Numeric.GSL.Vector 33import Numeric.GSL.Vector
34import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
31 35
32import Control.Monad(ap) 36import Control.Monad(ap)
33 37
@@ -86,7 +90,7 @@ instance Vectors Vector (Complex Double) where
86---------------------------------------------------- 90----------------------------------------------------
87 91
88-- | Basic element-by-element functions. 92-- | Basic element-by-element functions.
89class (Element e, AutoReal e, Convert e, Container c) => Linear c e where 93class (Element e, AutoReal e, Container c) => Linear c e where
90 -- | create a structure with a single element 94 -- | create a structure with a single element
91 scalar :: e -> c e 95 scalar :: e -> c e
92 scale :: e -> c e -> c e 96 scale :: e -> c e -> c e
@@ -184,3 +188,83 @@ linspace :: (Enum e, Linear Vector e) => Int -> (e, e) -> Vector e
184linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1] 188linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1]
185 where s = (b-a)/fromIntegral (n-1) 189 where s = (b-a)/fromIntegral (n-1)
186 190
191----------------------------------------------------
192
193-- reference multiply
194mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
195 where doth u v = sum $ zipWith (*) (toList u) (toList v)
196
197class Element t => Prod t where
198 multiply :: Matrix t -> Matrix t -> Matrix t
199 multiply = mulH
200 ctrans :: Matrix t -> Matrix t
201
202instance Prod Double where
203 multiply = multiplyR
204 ctrans = trans
205
206instance Prod (Complex Double) where
207 multiply = multiplyC
208 ctrans = conj . trans
209
210instance Prod Float where
211 multiply = multiplyF
212 ctrans = trans
213
214instance Prod (Complex Float) where
215 multiply = multiplyQ
216 ctrans = conj . trans
217
218----------------------------------------------------------
219
220-- synonym for matrix product
221mXm :: Prod t => Matrix t -> Matrix t -> Matrix t
222mXm = multiply
223
224-- matrix - vector product
225mXv :: Prod t => Matrix t -> Vector t -> Vector t
226mXv m v = flatten $ m `mXm` (asColumn v)
227
228-- vector - matrix product
229vXm :: Prod t => Vector t -> Matrix t -> Vector t
230vXm v m = flatten $ (asRow v) `mXm` m
231
232{- | Outer product of two vectors.
233
234@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
235(3><3)
236 [ 5.0, 2.0, 3.0
237 , 10.0, 4.0, 6.0
238 , 15.0, 6.0, 9.0 ]@
239-}
240outer :: (Prod t) => Vector t -> Vector t -> Matrix t
241outer u v = asColumn u `multiply` asRow v
242
243{- | Kronecker product of two matrices.
244
245@m1=(2><3)
246 [ 1.0, 2.0, 0.0
247 , 0.0, -1.0, 3.0 ]
248m2=(4><3)
249 [ 1.0, 2.0, 3.0
250 , 4.0, 5.0, 6.0
251 , 7.0, 8.0, 9.0
252 , 10.0, 11.0, 12.0 ]@
253
254@\> kronecker m1 m2
255(8><9)
256 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
257 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
258 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
259 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
260 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
261 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
262 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
263 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
264-}
265kronecker :: (Prod t) => Matrix t -> Matrix t -> Matrix t
266kronecker a b = fromBlocks
267 . splitEvery (cols a)
268 . map (reshape (cols b))
269 . toRows
270 $ flatten a `outer` flatten b
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index e3b6e1f..91f6742 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -34,6 +34,7 @@ import qualified Prelude
34import System.CPUTime 34import System.CPUTime
35import Text.Printf 35import Text.Printf
36import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) 36import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr)
37import Control.Arrow((***))
37 38
38#include "Tests/quickCheckCompat.h" 39#include "Tests/quickCheckCompat.h"
39 40
@@ -224,11 +225,16 @@ runTests :: Int -- ^ maximum dimension
224runTests n = do 225runTests n = do
225 setErrorHandlerOff 226 setErrorHandlerOff
226 let test p = qCheck n p 227 let test p = qCheck n p
227 putStrLn "------ mult" 228 putStrLn "------ mult Double"
228 test (multProp1 . rConsist) 229 test (multProp1 10 . rConsist)
229 test (multProp1 . cConsist) 230 test (multProp1 10 . cConsist)
230 test (multProp2 . rConsist) 231 test (multProp2 10 . rConsist)
231 test (multProp2 . cConsist) 232 test (multProp2 10 . cConsist)
233 putStrLn "------ mult Float"
234 test (multProp1 6 . (single *** single) . rConsist)
235 test (multProp1 6 . (single *** single) . cConsist)
236 test (multProp2 6 . (single *** single) . rConsist)
237 test (multProp2 6 . (single *** single) . cConsist)
232 putStrLn "------ sub-trans" 238 putStrLn "------ sub-trans"
233 test (subProp . rM) 239 test (subProp . rM)
234 test (subProp . cM) 240 test (subProp . cM)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index d29e19a..f7a948e 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -42,7 +42,7 @@ module Numeric.LinearAlgebra.Tests.Properties (
42 linearSolveProp, linearSolveProp2 42 linearSolveProp, linearSolveProp2
43) where 43) where
44 44
45import Numeric.LinearAlgebra 45import Numeric.LinearAlgebra hiding (mulH)
46import Numeric.LinearAlgebra.LAPACK 46import Numeric.LinearAlgebra.LAPACK
47import Debug.Trace 47import Debug.Trace
48#include "quickCheckCompat.h" 48#include "quickCheckCompat.h"
@@ -237,9 +237,9 @@ expmDiagProp m = expm (logm m) :~ 7 ~: complex m
237mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] 237mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ]
238 where doth u v = sum $ zipWith (*) (toList u) (toList v) 238 where doth u v = sum $ zipWith (*) (toList u) (toList v)
239 239
240multProp1 (a,b) = a <> b |~| mulH a b 240multProp1 p (a,b) = (a <> b) :~p~: (mulH a b)
241 241
242multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a 242multProp2 p (a,b) = (ctrans (a <> b)) :~p~: (ctrans b <> ctrans a)
243 243
244linearSolveProp f m = f m m |~| ident (rows m) 244linearSolveProp f m = f m m |~| ident (rows m)
245 245