diff options
author | Alberto Ruiz <aruiz@um.es> | 2008-10-02 15:53:10 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2008-10-02 15:53:10 +0000 |
commit | 192ac5f4b98517862c37ecf161505396ad223cd8 (patch) | |
tree | 811312f28bca2bd18d282bc0be732a17cd8dbcd7 /lib/Data/Packed/Internal | |
parent | 9c6b2af0066f7608301ad685ea5e60753fc3b6ff (diff) |
alternative multiply versions
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 37 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/auxi.c | 90 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/auxi.h | 6 |
3 files changed, 20 insertions, 113 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index caf3699..45a3955 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -212,7 +212,6 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | |||
212 | class (Storable a, Floating a) => Element a where | 212 | class (Storable a, Floating a) => Element a where |
213 | constantD :: a -> Int -> Vector a | 213 | constantD :: a -> Int -> Vector a |
214 | transdata :: Int -> Vector a -> Int -> Vector a | 214 | transdata :: Int -> Vector a -> Int -> Vector a |
215 | multiplyD :: Matrix a -> Matrix a -> Matrix a | ||
216 | subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position | 215 | subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position |
217 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 216 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
218 | -> Matrix a -> Matrix a | 217 | -> Matrix a -> Matrix a |
@@ -221,14 +220,12 @@ class (Storable a, Floating a) => Element a where | |||
221 | instance Element Double where | 220 | instance Element Double where |
222 | constantD = constantR | 221 | constantD = constantR |
223 | transdata = transdataR | 222 | transdata = transdataR |
224 | multiplyD = multiplyR | ||
225 | subMatrixD = subMatrixR | 223 | subMatrixD = subMatrixR |
226 | diagD = diagR | 224 | diagD = diagR |
227 | 225 | ||
228 | instance Element (Complex Double) where | 226 | instance Element (Complex Double) where |
229 | constantD = constantC | 227 | constantD = constantC |
230 | transdata = transdataC | 228 | transdata = transdataC |
231 | multiplyD = multiplyC | ||
232 | subMatrixD = subMatrixC | 229 | subMatrixD = subMatrixC |
233 | diagD = diagC | 230 | diagD = diagC |
234 | 231 | ||
@@ -266,33 +263,6 @@ transdataAux fun c1 d c2 = | |||
266 | foreign import ccall "auxi.h transR" ctransR :: TMM | 263 | foreign import ccall "auxi.h transR" ctransR :: TMM |
267 | foreign import ccall "auxi.h transC" ctransC :: TCMCM | 264 | foreign import ccall "auxi.h transC" ctransC :: TCMCM |
268 | 265 | ||
269 | ------------------------------------------------------------------ | ||
270 | |||
271 | gmatC MF { rows = r, cols = c } p f = f 1 (fi c) (fi r) p | ||
272 | gmatC MC { rows = r, cols = c } p f = f 0 (fi r) (fi c) p | ||
273 | |||
274 | dtt MC { cdat = d } = d | ||
275 | dtt MF { fdat = d } = d | ||
276 | |||
277 | multiplyAux fun a b = unsafePerformIO $ do | ||
278 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
279 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
280 | r <- createMatrix RowMajor (rows a) (cols b) | ||
281 | withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb -> | ||
282 | withMatrix r $ \r' -> | ||
283 | fun // gmatC a pa // gmatC b pb // r' // check "multiplyAux" | ||
284 | return r | ||
285 | |||
286 | multiplyR = multiplyAux cmultiplyR | ||
287 | foreign import ccall "auxi.h multiplyR" cmultiplyR :: TauxMul Double | ||
288 | |||
289 | multiplyC = multiplyAux cmultiplyC | ||
290 | foreign import ccall "auxi.h multiplyC" cmultiplyC :: TauxMul (Complex Double) | ||
291 | |||
292 | -- | matrix product | ||
293 | multiply :: (Element a) => Matrix a -> Matrix a -> Matrix a | ||
294 | multiply = multiplyD | ||
295 | |||
296 | ---------------------------------------------------------------------- | 266 | ---------------------------------------------------------------------- |
297 | 267 | ||
298 | -- | extraction of a submatrix from a real matrix | 268 | -- | extraction of a submatrix from a real matrix |
@@ -370,7 +340,12 @@ constant = constantD | |||
370 | 340 | ||
371 | -- | obtains the complex conjugate of a complex vector | 341 | -- | obtains the complex conjugate of a complex vector |
372 | conj :: Vector (Complex Double) -> Vector (Complex Double) | 342 | conj :: Vector (Complex Double) -> Vector (Complex Double) |
373 | conj v = asComplex $ flatten $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) | 343 | conj v = unsafePerformIO $ do |
344 | r <- createVector (dim v) | ||
345 | app2 cconjugate vec v vec r "cconjugate" | ||
346 | return r | ||
347 | foreign import ccall "auxi.h conjugate" cconjugate :: TCVCV | ||
348 | |||
374 | 349 | ||
375 | -- | creates a complex vector from vectors with real and imaginary parts | 350 | -- | creates a complex vector from vectors with real and imaginary parts |
376 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) | 351 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) |
diff --git a/lib/Data/Packed/Internal/auxi.c b/lib/Data/Packed/Internal/auxi.c index 7f83bcf..04dc7ad 100644 --- a/lib/Data/Packed/Internal/auxi.c +++ b/lib/Data/Packed/Internal/auxi.c | |||
@@ -4,14 +4,9 @@ | |||
4 | #include <gsl/gsl_matrix.h> | 4 | #include <gsl/gsl_matrix.h> |
5 | #include <gsl/gsl_math.h> | 5 | #include <gsl/gsl_math.h> |
6 | #include <gsl/gsl_errno.h> | 6 | #include <gsl/gsl_errno.h> |
7 | #include <gsl/gsl_fft_complex.h> | ||
8 | #include <gsl/gsl_eigen.h> | ||
9 | #include <gsl/gsl_integration.h> | ||
10 | #include <gsl/gsl_deriv.h> | ||
11 | #include <gsl/gsl_poly.h> | ||
12 | #include <gsl/gsl_multimin.h> | ||
13 | #include <gsl/gsl_complex.h> | 7 | #include <gsl/gsl_complex.h> |
14 | #include <gsl/gsl_complex_math.h> | 8 | #include <gsl/gsl_complex_math.h> |
9 | #include <gsl/gsl_cblas.h> | ||
15 | #include <string.h> | 10 | #include <string.h> |
16 | #include <stdio.h> | 11 | #include <stdio.h> |
17 | 12 | ||
@@ -118,78 +113,6 @@ int constantC(gsl_complex* pval, CVEC(r)) { | |||
118 | } | 113 | } |
119 | 114 | ||
120 | 115 | ||
121 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { | ||
122 | //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); | ||
123 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
124 | DEBUGMSG("multiplyR (gsl_blas_dgemm)"); | ||
125 | KDMVIEW(a); | ||
126 | KDMVIEW(b); | ||
127 | DMVIEW(r); | ||
128 | int k; | ||
129 | for(k=0;k<rr*rc;k++) rp[k]=0; | ||
130 | int debug = 0; | ||
131 | if(debug) { | ||
132 | printf("---------------------------\n"); | ||
133 | printf("%p: ",ap); for(k=0;k<ar*ac;k++) printf("%f ",ap[k]); printf("\n"); | ||
134 | printf("%p: ",bp); for(k=0;k<br*bc;k++) printf("%f ",bp[k]); printf("\n"); | ||
135 | printf("%p: ",rp); for(k=0;k<rr*rc;k++) printf("%f ",rp[k]); printf("\n"); | ||
136 | } | ||
137 | int res = gsl_blas_dgemm( | ||
138 | ta?CblasTrans:CblasNoTrans, | ||
139 | tb?CblasTrans:CblasNoTrans, | ||
140 | 1.0, M(a), M(b), | ||
141 | 0.0, M(r)); | ||
142 | if(debug) { | ||
143 | printf("--------------\n"); | ||
144 | printf("%p: ",ap); for(k=0;k<ar*ac;k++) printf("%f ",ap[k]); printf("\n"); | ||
145 | printf("%p: ",bp); for(k=0;k<br*bc;k++) printf("%f ",bp[k]); printf("\n"); | ||
146 | printf("%p: ",rp); for(k=0;k<rr*rc;k++) printf("%f ",rp[k]); printf("\n"); | ||
147 | } | ||
148 | CHECK(res,res); | ||
149 | OK | ||
150 | } | ||
151 | |||
152 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) { | ||
153 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
154 | DEBUGMSG("multiplyC (gsl_blas_zgemm)"); | ||
155 | KCMVIEW(a); | ||
156 | KCMVIEW(b); | ||
157 | CMVIEW(r); | ||
158 | int k; | ||
159 | gsl_complex alpha, beta; | ||
160 | GSL_SET_COMPLEX(&alpha,1.,0.); | ||
161 | GSL_SET_COMPLEX(&beta,0.,0.); | ||
162 | //double *TEMP = (double*)malloc(rr*rc*2*sizeof(double)); | ||
163 | //gsl_matrix_complex_view T = gsl_matrix_complex_view_array(TEMP,rr,rc); | ||
164 | for(k=0;k<rr*rc;k++) rp[k]=beta; | ||
165 | //for(k=0;k<2*rr*rc;k++) TEMP[k]=0; | ||
166 | int debug = 0; | ||
167 | if(debug) { | ||
168 | printf("---------------------------\n"); | ||
169 | printf("%p: ",ap); for(k=0;k<2*ar*ac;k++) printf("%f ",((double*)ap)[k]); printf("\n"); | ||
170 | printf("%p: ",bp); for(k=0;k<2*br*bc;k++) printf("%f ",((double*)bp)[k]); printf("\n"); | ||
171 | printf("%p: ",rp); for(k=0;k<2*rr*rc;k++) printf("%f ",((double*)rp)[k]); printf("\n"); | ||
172 | //printf("%p: ",T); for(k=0;k<2*rr*rc;k++) printf("%f ",TEMP[k]); printf("\n"); | ||
173 | } | ||
174 | int res = gsl_blas_zgemm( | ||
175 | ta?CblasTrans:CblasNoTrans, | ||
176 | tb?CblasTrans:CblasNoTrans, | ||
177 | alpha, M(a), M(b), | ||
178 | beta, M(r)); | ||
179 | //&T.matrix); | ||
180 | //memcpy(rp,TEMP,2*rr*rc*sizeof(double)); | ||
181 | if(debug) { | ||
182 | printf("--------------\n"); | ||
183 | printf("%p: ",ap); for(k=0;k<2*ar*ac;k++) printf("%f ",((double*)ap)[k]); printf("\n"); | ||
184 | printf("%p: ",bp); for(k=0;k<2*br*bc;k++) printf("%f ",((double*)bp)[k]); printf("\n"); | ||
185 | printf("%p: ",rp); for(k=0;k<2*rr*rc;k++) printf("%f ",((double*)rp)[k]); printf("\n"); | ||
186 | //printf("%p: ",T); for(k=0;k<2*rr*rc;k++) printf("%f ",TEMP[k]); printf("\n"); | ||
187 | } | ||
188 | CHECK(res,res); | ||
189 | OK | ||
190 | } | ||
191 | |||
192 | |||
193 | int diagR(KRVEC(d),RMAT(r)) { | 116 | int diagR(KRVEC(d),RMAT(r)) { |
194 | REQUIRES(dn==rr && rr==rc,BAD_SIZE); | 117 | REQUIRES(dn==rr && rr==rc,BAD_SIZE); |
195 | DEBUGMSG("diagR"); | 118 | DEBUGMSG("diagR"); |
@@ -215,3 +138,14 @@ int diagC(KCVEC(d),CMAT(r)) { | |||
215 | } | 138 | } |
216 | OK | 139 | OK |
217 | } | 140 | } |
141 | |||
142 | int conjugate(KCVEC(x),CVEC(t)) { | ||
143 | REQUIRES(xn==tn,BAD_SIZE); | ||
144 | DEBUGMSG("conjugate"); | ||
145 | int k; | ||
146 | for (k=0; k<xn; k++) { | ||
147 | tp[k].dat[0] = xp[k].dat[0]; | ||
148 | tp[k].dat[1] = - xp[k].dat[1]; | ||
149 | } | ||
150 | OK | ||
151 | } | ||
diff --git a/lib/Data/Packed/Internal/auxi.h b/lib/Data/Packed/Internal/auxi.h index 73334e3..377a4a1 100644 --- a/lib/Data/Packed/Internal/auxi.h +++ b/lib/Data/Packed/Internal/auxi.h | |||
@@ -10,16 +10,12 @@ | |||
10 | #define KCVEC(A) int A##n, const gsl_complex*A##p | 10 | #define KCVEC(A) int A##n, const gsl_complex*A##p |
11 | #define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p | 11 | #define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p |
12 | 12 | ||
13 | |||
14 | int transR(KRMAT(x),RMAT(t)); | 13 | int transR(KRMAT(x),RMAT(t)); |
15 | int transC(KCMAT(x),CMAT(t)); | 14 | int transC(KCMAT(x),CMAT(t)); |
16 | 15 | ||
17 | int constantR(double *val , RVEC(r)); | 16 | int constantR(double *val , RVEC(r)); |
18 | int constantC(gsl_complex *val, CVEC(r)); | 17 | int constantC(gsl_complex *val, CVEC(r)); |
19 | 18 | ||
20 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)); | ||
21 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); | ||
22 | |||
23 | int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r)); | 19 | int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r)); |
24 | 20 | ||
25 | int diagR(KRVEC(d),RMAT(r)); | 21 | int diagR(KRVEC(d),RMAT(r)); |
@@ -28,3 +24,5 @@ int diagC(KCVEC(d),CMAT(r)); | |||
28 | const char * gsl_strerror (const int gsl_errno); | 24 | const char * gsl_strerror (const int gsl_errno); |
29 | 25 | ||
30 | int matrix_fscanf(char*filename, RMAT(a)); | 26 | int matrix_fscanf(char*filename, RMAT(a)); |
27 | |||
28 | int conjugate(KCVEC(x),CVEC(t)); | ||