summaryrefslogtreecommitdiff
path: root/lib/Data
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2008-10-02 15:53:10 +0000
committerAlberto Ruiz <aruiz@um.es>2008-10-02 15:53:10 +0000
commit192ac5f4b98517862c37ecf161505396ad223cd8 (patch)
tree811312f28bca2bd18d282bc0be732a17cd8dbcd7 /lib/Data
parent9c6b2af0066f7608301ad685ea5e60753fc3b6ff (diff)
alternative multiply versions
Diffstat (limited to 'lib/Data')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs37
-rw-r--r--lib/Data/Packed/Internal/auxi.c90
-rw-r--r--lib/Data/Packed/Internal/auxi.h6
3 files changed, 20 insertions, 113 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index caf3699..45a3955 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -212,7 +212,6 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
212class (Storable a, Floating a) => Element a where 212class (Storable a, Floating a) => Element a where
213 constantD :: a -> Int -> Vector a 213 constantD :: a -> Int -> Vector a
214 transdata :: Int -> Vector a -> Int -> Vector a 214 transdata :: Int -> Vector a -> Int -> Vector a
215 multiplyD :: Matrix a -> Matrix a -> Matrix a
216 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position 215 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
217 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 216 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
218 -> Matrix a -> Matrix a 217 -> Matrix a -> Matrix a
@@ -221,14 +220,12 @@ class (Storable a, Floating a) => Element a where
221instance Element Double where 220instance Element Double where
222 constantD = constantR 221 constantD = constantR
223 transdata = transdataR 222 transdata = transdataR
224 multiplyD = multiplyR
225 subMatrixD = subMatrixR 223 subMatrixD = subMatrixR
226 diagD = diagR 224 diagD = diagR
227 225
228instance Element (Complex Double) where 226instance Element (Complex Double) where
229 constantD = constantC 227 constantD = constantC
230 transdata = transdataC 228 transdata = transdataC
231 multiplyD = multiplyC
232 subMatrixD = subMatrixC 229 subMatrixD = subMatrixC
233 diagD = diagC 230 diagD = diagC
234 231
@@ -266,33 +263,6 @@ transdataAux fun c1 d c2 =
266foreign import ccall "auxi.h transR" ctransR :: TMM 263foreign import ccall "auxi.h transR" ctransR :: TMM
267foreign import ccall "auxi.h transC" ctransC :: TCMCM 264foreign import ccall "auxi.h transC" ctransC :: TCMCM
268 265
269------------------------------------------------------------------
270
271gmatC MF { rows = r, cols = c } p f = f 1 (fi c) (fi r) p
272gmatC MC { rows = r, cols = c } p f = f 0 (fi r) (fi c) p
273
274dtt MC { cdat = d } = d
275dtt MF { fdat = d } = d
276
277multiplyAux fun a b = unsafePerformIO $ do
278 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
279 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
280 r <- createMatrix RowMajor (rows a) (cols b)
281 withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb ->
282 withMatrix r $ \r' ->
283 fun // gmatC a pa // gmatC b pb // r' // check "multiplyAux"
284 return r
285
286multiplyR = multiplyAux cmultiplyR
287foreign import ccall "auxi.h multiplyR" cmultiplyR :: TauxMul Double
288
289multiplyC = multiplyAux cmultiplyC
290foreign import ccall "auxi.h multiplyC" cmultiplyC :: TauxMul (Complex Double)
291
292-- | matrix product
293multiply :: (Element a) => Matrix a -> Matrix a -> Matrix a
294multiply = multiplyD
295
296---------------------------------------------------------------------- 266----------------------------------------------------------------------
297 267
298-- | extraction of a submatrix from a real matrix 268-- | extraction of a submatrix from a real matrix
@@ -370,7 +340,12 @@ constant = constantD
370 340
371-- | obtains the complex conjugate of a complex vector 341-- | obtains the complex conjugate of a complex vector
372conj :: Vector (Complex Double) -> Vector (Complex Double) 342conj :: Vector (Complex Double) -> Vector (Complex Double)
373conj v = asComplex $ flatten $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) 343conj v = unsafePerformIO $ do
344 r <- createVector (dim v)
345 app2 cconjugate vec v vec r "cconjugate"
346 return r
347foreign import ccall "auxi.h conjugate" cconjugate :: TCVCV
348
374 349
375-- | creates a complex vector from vectors with real and imaginary parts 350-- | creates a complex vector from vectors with real and imaginary parts
376toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) 351toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double)
diff --git a/lib/Data/Packed/Internal/auxi.c b/lib/Data/Packed/Internal/auxi.c
index 7f83bcf..04dc7ad 100644
--- a/lib/Data/Packed/Internal/auxi.c
+++ b/lib/Data/Packed/Internal/auxi.c
@@ -4,14 +4,9 @@
4#include <gsl/gsl_matrix.h> 4#include <gsl/gsl_matrix.h>
5#include <gsl/gsl_math.h> 5#include <gsl/gsl_math.h>
6#include <gsl/gsl_errno.h> 6#include <gsl/gsl_errno.h>
7#include <gsl/gsl_fft_complex.h>
8#include <gsl/gsl_eigen.h>
9#include <gsl/gsl_integration.h>
10#include <gsl/gsl_deriv.h>
11#include <gsl/gsl_poly.h>
12#include <gsl/gsl_multimin.h>
13#include <gsl/gsl_complex.h> 7#include <gsl/gsl_complex.h>
14#include <gsl/gsl_complex_math.h> 8#include <gsl/gsl_complex_math.h>
9#include <gsl/gsl_cblas.h>
15#include <string.h> 10#include <string.h>
16#include <stdio.h> 11#include <stdio.h>
17 12
@@ -118,78 +113,6 @@ int constantC(gsl_complex* pval, CVEC(r)) {
118} 113}
119 114
120 115
121int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) {
122 //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc);
123 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
124 DEBUGMSG("multiplyR (gsl_blas_dgemm)");
125 KDMVIEW(a);
126 KDMVIEW(b);
127 DMVIEW(r);
128 int k;
129 for(k=0;k<rr*rc;k++) rp[k]=0;
130 int debug = 0;
131 if(debug) {
132 printf("---------------------------\n");
133 printf("%p: ",ap); for(k=0;k<ar*ac;k++) printf("%f ",ap[k]); printf("\n");
134 printf("%p: ",bp); for(k=0;k<br*bc;k++) printf("%f ",bp[k]); printf("\n");
135 printf("%p: ",rp); for(k=0;k<rr*rc;k++) printf("%f ",rp[k]); printf("\n");
136 }
137 int res = gsl_blas_dgemm(
138 ta?CblasTrans:CblasNoTrans,
139 tb?CblasTrans:CblasNoTrans,
140 1.0, M(a), M(b),
141 0.0, M(r));
142 if(debug) {
143 printf("--------------\n");
144 printf("%p: ",ap); for(k=0;k<ar*ac;k++) printf("%f ",ap[k]); printf("\n");
145 printf("%p: ",bp); for(k=0;k<br*bc;k++) printf("%f ",bp[k]); printf("\n");
146 printf("%p: ",rp); for(k=0;k<rr*rc;k++) printf("%f ",rp[k]); printf("\n");
147 }
148 CHECK(res,res);
149 OK
150}
151
152int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) {
153 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
154 DEBUGMSG("multiplyC (gsl_blas_zgemm)");
155 KCMVIEW(a);
156 KCMVIEW(b);
157 CMVIEW(r);
158 int k;
159 gsl_complex alpha, beta;
160 GSL_SET_COMPLEX(&alpha,1.,0.);
161 GSL_SET_COMPLEX(&beta,0.,0.);
162 //double *TEMP = (double*)malloc(rr*rc*2*sizeof(double));
163 //gsl_matrix_complex_view T = gsl_matrix_complex_view_array(TEMP,rr,rc);
164 for(k=0;k<rr*rc;k++) rp[k]=beta;
165 //for(k=0;k<2*rr*rc;k++) TEMP[k]=0;
166 int debug = 0;
167 if(debug) {
168 printf("---------------------------\n");
169 printf("%p: ",ap); for(k=0;k<2*ar*ac;k++) printf("%f ",((double*)ap)[k]); printf("\n");
170 printf("%p: ",bp); for(k=0;k<2*br*bc;k++) printf("%f ",((double*)bp)[k]); printf("\n");
171 printf("%p: ",rp); for(k=0;k<2*rr*rc;k++) printf("%f ",((double*)rp)[k]); printf("\n");
172 //printf("%p: ",T); for(k=0;k<2*rr*rc;k++) printf("%f ",TEMP[k]); printf("\n");
173 }
174 int res = gsl_blas_zgemm(
175 ta?CblasTrans:CblasNoTrans,
176 tb?CblasTrans:CblasNoTrans,
177 alpha, M(a), M(b),
178 beta, M(r));
179 //&T.matrix);
180 //memcpy(rp,TEMP,2*rr*rc*sizeof(double));
181 if(debug) {
182 printf("--------------\n");
183 printf("%p: ",ap); for(k=0;k<2*ar*ac;k++) printf("%f ",((double*)ap)[k]); printf("\n");
184 printf("%p: ",bp); for(k=0;k<2*br*bc;k++) printf("%f ",((double*)bp)[k]); printf("\n");
185 printf("%p: ",rp); for(k=0;k<2*rr*rc;k++) printf("%f ",((double*)rp)[k]); printf("\n");
186 //printf("%p: ",T); for(k=0;k<2*rr*rc;k++) printf("%f ",TEMP[k]); printf("\n");
187 }
188 CHECK(res,res);
189 OK
190}
191
192
193int diagR(KRVEC(d),RMAT(r)) { 116int diagR(KRVEC(d),RMAT(r)) {
194 REQUIRES(dn==rr && rr==rc,BAD_SIZE); 117 REQUIRES(dn==rr && rr==rc,BAD_SIZE);
195 DEBUGMSG("diagR"); 118 DEBUGMSG("diagR");
@@ -215,3 +138,14 @@ int diagC(KCVEC(d),CMAT(r)) {
215 } 138 }
216 OK 139 OK
217} 140}
141
142int conjugate(KCVEC(x),CVEC(t)) {
143 REQUIRES(xn==tn,BAD_SIZE);
144 DEBUGMSG("conjugate");
145 int k;
146 for (k=0; k<xn; k++) {
147 tp[k].dat[0] = xp[k].dat[0];
148 tp[k].dat[1] = - xp[k].dat[1];
149 }
150 OK
151}
diff --git a/lib/Data/Packed/Internal/auxi.h b/lib/Data/Packed/Internal/auxi.h
index 73334e3..377a4a1 100644
--- a/lib/Data/Packed/Internal/auxi.h
+++ b/lib/Data/Packed/Internal/auxi.h
@@ -10,16 +10,12 @@
10#define KCVEC(A) int A##n, const gsl_complex*A##p 10#define KCVEC(A) int A##n, const gsl_complex*A##p
11#define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p 11#define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p
12 12
13
14int transR(KRMAT(x),RMAT(t)); 13int transR(KRMAT(x),RMAT(t));
15int transC(KCMAT(x),CMAT(t)); 14int transC(KCMAT(x),CMAT(t));
16 15
17int constantR(double *val , RVEC(r)); 16int constantR(double *val , RVEC(r));
18int constantC(gsl_complex *val, CVEC(r)); 17int constantC(gsl_complex *val, CVEC(r));
19 18
20int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r));
21int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r));
22
23int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r)); 19int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r));
24 20
25int diagR(KRVEC(d),RMAT(r)); 21int diagR(KRVEC(d),RMAT(r));
@@ -28,3 +24,5 @@ int diagC(KCVEC(d),CMAT(r));
28const char * gsl_strerror (const int gsl_errno); 24const char * gsl_strerror (const int gsl_errno);
29 25
30int matrix_fscanf(char*filename, RMAT(a)); 26int matrix_fscanf(char*filename, RMAT(a));
27
28int conjugate(KCVEC(x),CVEC(t));