diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/C/lapack-aux.c | 1489 | ||||
-rw-r--r-- | packages/base/src/C/lapack-aux.h | 60 | ||||
-rw-r--r-- | packages/base/src/Data/Packed.hs | 25 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Development.hs | 31 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Foreign.hs | 99 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal.hs | 26 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Common.hs | 160 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 422 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Signatures.hs | 70 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Vector.hs | 471 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Matrix.hs | 490 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/ST.hs | 178 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Vector.hs | 96 |
13 files changed, 3617 insertions, 0 deletions
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c new file mode 100644 index 0000000..e5e45ef --- /dev/null +++ b/packages/base/src/C/lapack-aux.c | |||
@@ -0,0 +1,1489 @@ | |||
1 | #include <stdio.h> | ||
2 | #include <stdlib.h> | ||
3 | #include <string.h> | ||
4 | #include <math.h> | ||
5 | #include <time.h> | ||
6 | #include "lapack-aux.h" | ||
7 | |||
8 | #define MACRO(B) do {B} while (0) | ||
9 | #define ERROR(CODE) MACRO(return CODE;) | ||
10 | #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) | ||
11 | |||
12 | #define MIN(A,B) ((A)<(B)?(A):(B)) | ||
13 | #define MAX(A,B) ((A)>(B)?(A):(B)) | ||
14 | |||
15 | // #define DBGL | ||
16 | |||
17 | #ifdef DBGL | ||
18 | #define DEBUGMSG(M) printf("\nLAPACK "M"\n"); | ||
19 | #else | ||
20 | #define DEBUGMSG(M) | ||
21 | #endif | ||
22 | |||
23 | #define OK return 0; | ||
24 | |||
25 | // #ifdef DBGL | ||
26 | // #define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL); | ||
27 | // #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); | ||
28 | // #else | ||
29 | // #define DEBUGMSG(M) | ||
30 | // #define OK return 0; | ||
31 | // #endif | ||
32 | |||
33 | #define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ | ||
34 | for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");} | ||
35 | |||
36 | #define CHECK(RES,CODE) MACRO(if(RES) return CODE;) | ||
37 | |||
38 | #define BAD_SIZE 2000 | ||
39 | #define BAD_CODE 2001 | ||
40 | #define MEM 2002 | ||
41 | #define BAD_FILE 2003 | ||
42 | #define SINGULAR 2004 | ||
43 | #define NOCONVER 2005 | ||
44 | #define NODEFPOS 2006 | ||
45 | #define NOSPRTD 2007 | ||
46 | |||
47 | //--------------------------------------- | ||
// Saves the x87 FPU state into a scratch buffer on builds where the `i386`
// macro is defined; on every other target the function body is empty.
// NOTE(review): per the x86 instruction set FSAVE also reinitializes the FPU
// after storing its state, which is presumably why it stands in for the
// commented-out "finit" -- confirm against the Intel manual.
void asm_finit() {
#ifdef i386

//  asm("finit");

    static unsigned char buf[108];
    // volatile: the statement is executed for its effect on the FPU, so the
    // compiler must not drop it when buf is never read afterwards.
    asm volatile("FSAVE %0" : "=m" (buf));

    #if FPUDEBUG
    if(buf[8]!=255 || buf[9]!=255) {  // print warning in red
        // the original format contained "\%", an undefined escape sequence
        printf("%c[;31mWarning: FPU TAG = %x %x%c[0m\n",0x1B,buf[8],buf[9],0x1B);
    }
    #endif

    #if NANDEBUG
    // FRSTOR reads buf, so buf is an input operand here (it was declared as
    // an output, "=m", which tells the compiler the opposite).
    asm volatile("FRSTOR %0" : : "m" (buf));
    #endif

#endif
}
68 | |||
69 | //--------------------------------------- | ||
70 | |||
71 | #if NANDEBUG | ||
72 | |||
73 | #define CHECKNANR(M,msg) \ | ||
74 | { int k; \ | ||
75 | for(k=0; k<(M##r * M##c); k++) { \ | ||
76 | if(M##p[k] != M##p[k]) { \ | ||
77 | printf(msg); \ | ||
78 | TRACEMAT(M) \ | ||
79 | /*exit(1);*/ \ | ||
80 | } \ | ||
81 | } \ | ||
82 | } | ||
83 | |||
84 | #define CHECKNANC(M,msg) \ | ||
85 | { int k; \ | ||
86 | for(k=0; k<(M##r * M##c); k++) { \ | ||
87 | if( M##p[k].r != M##p[k].r \ | ||
88 | || M##p[k].i != M##p[k].i) { \ | ||
89 | printf(msg); \ | ||
90 | /*exit(1);*/ \ | ||
91 | } \ | ||
92 | } \ | ||
93 | } | ||
94 | |||
95 | #else | ||
96 | #define CHECKNANC(M,msg) | ||
97 | #define CHECKNANR(M,msg) | ||
98 | #endif | ||
99 | |||
100 | //--------------------------------------- | ||
101 | |||
102 | //////////////////// real svd //////////////////////////////////// | ||
103 | |||
104 | /* Subroutine */ int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n, | ||
105 | doublereal *a, integer *lda, doublereal *s, doublereal *u, integer * | ||
106 | ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, | ||
107 | integer *info); | ||
108 | |||
109 | int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | ||
110 | integer m = ar; | ||
111 | integer n = ac; | ||
112 | integer q = MIN(m,n); | ||
113 | REQUIRES(sn==q,BAD_SIZE); | ||
114 | REQUIRES(up==NULL || (ur==m && (uc==m || uc==q)),BAD_SIZE); | ||
115 | char* jobu = "A"; | ||
116 | if (up==NULL) { | ||
117 | jobu = "N"; | ||
118 | } else { | ||
119 | if (uc==q) { | ||
120 | jobu = "S"; | ||
121 | } | ||
122 | } | ||
123 | REQUIRES(vp==NULL || (vc==n && (vr==n || vr==q)),BAD_SIZE); | ||
124 | char* jobvt = "A"; | ||
125 | integer ldvt = n; | ||
126 | if (vp==NULL) { | ||
127 | jobvt = "N"; | ||
128 | } else { | ||
129 | if (vr==q) { | ||
130 | jobvt = "S"; | ||
131 | ldvt = q; | ||
132 | } | ||
133 | } | ||
134 | DEBUGMSG("svd_l_R"); | ||
135 | double *B = (double*)malloc(m*n*sizeof(double)); | ||
136 | CHECK(!B,MEM); | ||
137 | memcpy(B,ap,m*n*sizeof(double)); | ||
138 | integer lwork = -1; | ||
139 | integer res; | ||
140 | // ask for optimal lwork | ||
141 | double ans; | ||
142 | dgesvd_ (jobu,jobvt, | ||
143 | &m,&n,B,&m, | ||
144 | sp, | ||
145 | up,&m, | ||
146 | vp,&ldvt, | ||
147 | &ans, &lwork, | ||
148 | &res); | ||
149 | lwork = ceil(ans); | ||
150 | double * work = (double*)malloc(lwork*sizeof(double)); | ||
151 | CHECK(!work,MEM); | ||
152 | dgesvd_ (jobu,jobvt, | ||
153 | &m,&n,B,&m, | ||
154 | sp, | ||
155 | up,&m, | ||
156 | vp,&ldvt, | ||
157 | work, &lwork, | ||
158 | &res); | ||
159 | CHECK(res,res); | ||
160 | free(work); | ||
161 | free(B); | ||
162 | OK | ||
163 | } | ||
164 | |||
165 | // (alternative version) | ||
166 | |||
167 | /* Subroutine */ int dgesdd_(char *jobz, integer *m, integer *n, doublereal * | ||
168 | a, integer *lda, doublereal *s, doublereal *u, integer *ldu, | ||
169 | doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, | ||
170 | integer *iwork, integer *info); | ||
171 | |||
172 | int svd_l_Rdd(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | ||
173 | integer m = ar; | ||
174 | integer n = ac; | ||
175 | integer q = MIN(m,n); | ||
176 | REQUIRES(sn==q,BAD_SIZE); | ||
177 | REQUIRES((up == NULL && vp == NULL) | ||
178 | || (ur==m && vc==n | ||
179 | && ((uc == q && vr == q) | ||
180 | || (uc == m && vc==n))),BAD_SIZE); | ||
181 | char* jobz = "A"; | ||
182 | integer ldvt = n; | ||
183 | if (up==NULL) { | ||
184 | jobz = "N"; | ||
185 | } else { | ||
186 | if (uc==q && vr == q) { | ||
187 | jobz = "S"; | ||
188 | ldvt = q; | ||
189 | } | ||
190 | } | ||
191 | DEBUGMSG("svd_l_Rdd"); | ||
192 | double *B = (double*)malloc(m*n*sizeof(double)); | ||
193 | CHECK(!B,MEM); | ||
194 | memcpy(B,ap,m*n*sizeof(double)); | ||
195 | integer* iwk = (integer*) malloc(8*q*sizeof(integer)); | ||
196 | CHECK(!iwk,MEM); | ||
197 | integer lwk = -1; | ||
198 | integer res; | ||
199 | // ask for optimal lwk | ||
200 | double ans; | ||
201 | dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,iwk,&res); | ||
202 | lwk = ans; | ||
203 | double * workv = (double*)malloc(lwk*sizeof(double)); | ||
204 | CHECK(!workv,MEM); | ||
205 | dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,iwk,&res); | ||
206 | CHECK(res,res); | ||
207 | free(iwk); | ||
208 | free(workv); | ||
209 | free(B); | ||
210 | OK | ||
211 | } | ||
212 | |||
213 | //////////////////// complex svd //////////////////////////////////// | ||
214 | |||
215 | // not in clapack.h | ||
216 | |||
217 | int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n, | ||
218 | doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u, | ||
219 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, | ||
220 | integer *lwork, doublereal *rwork, integer *info); | ||
221 | |||
222 | int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | ||
223 | integer m = ar; | ||
224 | integer n = ac; | ||
225 | integer q = MIN(m,n); | ||
226 | REQUIRES(sn==q,BAD_SIZE); | ||
227 | REQUIRES(up==NULL || (ur==m && (uc==m || uc==q)),BAD_SIZE); | ||
228 | char* jobu = "A"; | ||
229 | if (up==NULL) { | ||
230 | jobu = "N"; | ||
231 | } else { | ||
232 | if (uc==q) { | ||
233 | jobu = "S"; | ||
234 | } | ||
235 | } | ||
236 | REQUIRES(vp==NULL || (vc==n && (vr==n || vr==q)),BAD_SIZE); | ||
237 | char* jobvt = "A"; | ||
238 | integer ldvt = n; | ||
239 | if (vp==NULL) { | ||
240 | jobvt = "N"; | ||
241 | } else { | ||
242 | if (vr==q) { | ||
243 | jobvt = "S"; | ||
244 | ldvt = q; | ||
245 | } | ||
246 | }DEBUGMSG("svd_l_C"); | ||
247 | doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
248 | CHECK(!B,MEM); | ||
249 | memcpy(B,ap,m*n*sizeof(doublecomplex)); | ||
250 | |||
251 | double *rwork = (double*) malloc(5*q*sizeof(double)); | ||
252 | CHECK(!rwork,MEM); | ||
253 | integer lwork = -1; | ||
254 | integer res; | ||
255 | // ask for optimal lwork | ||
256 | doublecomplex ans; | ||
257 | zgesvd_ (jobu,jobvt, | ||
258 | &m,&n,B,&m, | ||
259 | sp, | ||
260 | up,&m, | ||
261 | vp,&ldvt, | ||
262 | &ans, &lwork, | ||
263 | rwork, | ||
264 | &res); | ||
265 | lwork = ceil(ans.r); | ||
266 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
267 | CHECK(!work,MEM); | ||
268 | zgesvd_ (jobu,jobvt, | ||
269 | &m,&n,B,&m, | ||
270 | sp, | ||
271 | up,&m, | ||
272 | vp,&ldvt, | ||
273 | work, &lwork, | ||
274 | rwork, | ||
275 | &res); | ||
276 | CHECK(res,res); | ||
277 | free(work); | ||
278 | free(rwork); | ||
279 | free(B); | ||
280 | OK | ||
281 | } | ||
282 | |||
283 | int zgesdd_ (char *jobz, integer *m, integer *n, | ||
284 | doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u, | ||
285 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, | ||
286 | integer *lwork, doublereal *rwork, integer* iwork, integer *info); | ||
287 | |||
288 | int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | ||
289 | //printf("entro\n"); | ||
290 | integer m = ar; | ||
291 | integer n = ac; | ||
292 | integer q = MIN(m,n); | ||
293 | REQUIRES(sn==q,BAD_SIZE); | ||
294 | REQUIRES((up == NULL && vp == NULL) | ||
295 | || (ur==m && vc==n | ||
296 | && ((uc == q && vr == q) | ||
297 | || (uc == m && vc==n))),BAD_SIZE); | ||
298 | char* jobz = "A"; | ||
299 | integer ldvt = n; | ||
300 | if (up==NULL) { | ||
301 | jobz = "N"; | ||
302 | } else { | ||
303 | if (uc==q && vr == q) { | ||
304 | jobz = "S"; | ||
305 | ldvt = q; | ||
306 | } | ||
307 | } | ||
308 | DEBUGMSG("svd_l_Cdd"); | ||
309 | doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
310 | CHECK(!B,MEM); | ||
311 | memcpy(B,ap,m*n*sizeof(doublecomplex)); | ||
312 | integer* iwk = (integer*) malloc(8*q*sizeof(integer)); | ||
313 | CHECK(!iwk,MEM); | ||
314 | int lrwk; | ||
315 | if (0 && *jobz == 'N') { | ||
316 | lrwk = 5*q; // does not work, crash at free below | ||
317 | } else { | ||
318 | lrwk = 5*q*q + 7*q; | ||
319 | } | ||
320 | double *rwk = (double*)malloc(lrwk*sizeof(double));; | ||
321 | CHECK(!rwk,MEM); | ||
322 | //printf("%s %ld %d\n",jobz,q,lrwk); | ||
323 | integer lwk = -1; | ||
324 | integer res; | ||
325 | // ask for optimal lwk | ||
326 | doublecomplex ans; | ||
327 | zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,rwk,iwk,&res); | ||
328 | lwk = ans.r; | ||
329 | //printf("lwk = %ld\n",lwk); | ||
330 | doublecomplex * workv = (doublecomplex*)malloc(lwk*sizeof(doublecomplex)); | ||
331 | CHECK(!workv,MEM); | ||
332 | zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,rwk,iwk,&res); | ||
333 | //printf("res = %ld\n",res); | ||
334 | CHECK(res,res); | ||
335 | free(workv); // printf("freed workv\n"); | ||
336 | free(rwk); // printf("freed rwk\n"); | ||
337 | free(iwk); // printf("freed iwk\n"); | ||
338 | free(B); // printf("freed B, salgo\n"); | ||
339 | OK | ||
340 | } | ||
341 | |||
342 | //////////////////// general complex eigensystem //////////// | ||
343 | |||
344 | /* Subroutine */ int zgeev_(char *jobvl, char *jobvr, integer *n, | ||
345 | doublecomplex *a, integer *lda, doublecomplex *w, doublecomplex *vl, | ||
346 | integer *ldvl, doublecomplex *vr, integer *ldvr, doublecomplex *work, | ||
347 | integer *lwork, doublereal *rwork, integer *info); | ||
348 | |||
349 | int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | ||
350 | integer n = ar; | ||
351 | REQUIRES(ac==n && sn==n, BAD_SIZE); | ||
352 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); | ||
353 | char jobvl = up==NULL?'N':'V'; | ||
354 | REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE); | ||
355 | char jobvr = vp==NULL?'N':'V'; | ||
356 | DEBUGMSG("eig_l_C"); | ||
357 | doublecomplex *B = (doublecomplex*)malloc(n*n*sizeof(doublecomplex)); | ||
358 | CHECK(!B,MEM); | ||
359 | memcpy(B,ap,n*n*sizeof(doublecomplex)); | ||
360 | double *rwork = (double*) malloc(2*n*sizeof(double)); | ||
361 | CHECK(!rwork,MEM); | ||
362 | integer lwork = -1; | ||
363 | integer res; | ||
364 | // ask for optimal lwork | ||
365 | doublecomplex ans; | ||
366 | //printf("ask zgeev\n"); | ||
367 | zgeev_ (&jobvl,&jobvr, | ||
368 | &n,B,&n, | ||
369 | sp, | ||
370 | up,&n, | ||
371 | vp,&n, | ||
372 | &ans, &lwork, | ||
373 | rwork, | ||
374 | &res); | ||
375 | lwork = ceil(ans.r); | ||
376 | //printf("ans = %d\n",lwork); | ||
377 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
378 | CHECK(!work,MEM); | ||
379 | //printf("zgeev\n"); | ||
380 | zgeev_ (&jobvl,&jobvr, | ||
381 | &n,B,&n, | ||
382 | sp, | ||
383 | up,&n, | ||
384 | vp,&n, | ||
385 | work, &lwork, | ||
386 | rwork, | ||
387 | &res); | ||
388 | CHECK(res,res); | ||
389 | free(work); | ||
390 | free(rwork); | ||
391 | free(B); | ||
392 | OK | ||
393 | } | ||
394 | |||
395 | |||
396 | |||
397 | //////////////////// general real eigensystem //////////// | ||
398 | |||
399 | /* Subroutine */ int dgeev_(char *jobvl, char *jobvr, integer *n, doublereal * | ||
400 | a, integer *lda, doublereal *wr, doublereal *wi, doublereal *vl, | ||
401 | integer *ldvl, doublereal *vr, integer *ldvr, doublereal *work, | ||
402 | integer *lwork, integer *info); | ||
403 | |||
404 | int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) { | ||
405 | integer n = ar; | ||
406 | REQUIRES(ac==n && sn==n, BAD_SIZE); | ||
407 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); | ||
408 | char jobvl = up==NULL?'N':'V'; | ||
409 | REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE); | ||
410 | char jobvr = vp==NULL?'N':'V'; | ||
411 | DEBUGMSG("eig_l_R"); | ||
412 | double *B = (double*)malloc(n*n*sizeof(double)); | ||
413 | CHECK(!B,MEM); | ||
414 | memcpy(B,ap,n*n*sizeof(double)); | ||
415 | integer lwork = -1; | ||
416 | integer res; | ||
417 | // ask for optimal lwork | ||
418 | double ans; | ||
419 | //printf("ask dgeev\n"); | ||
420 | dgeev_ (&jobvl,&jobvr, | ||
421 | &n,B,&n, | ||
422 | (double*)sp, (double*)sp+n, | ||
423 | up,&n, | ||
424 | vp,&n, | ||
425 | &ans, &lwork, | ||
426 | &res); | ||
427 | lwork = ceil(ans); | ||
428 | //printf("ans = %d\n",lwork); | ||
429 | double * work = (double*)malloc(lwork*sizeof(double)); | ||
430 | CHECK(!work,MEM); | ||
431 | //printf("dgeev\n"); | ||
432 | dgeev_ (&jobvl,&jobvr, | ||
433 | &n,B,&n, | ||
434 | (double*)sp, (double*)sp+n, | ||
435 | up,&n, | ||
436 | vp,&n, | ||
437 | work, &lwork, | ||
438 | &res); | ||
439 | CHECK(res,res); | ||
440 | free(work); | ||
441 | free(B); | ||
442 | OK | ||
443 | } | ||
444 | |||
445 | |||
446 | //////////////////// symmetric real eigensystem //////////// | ||
447 | |||
448 | /* Subroutine */ int dsyev_(char *jobz, char *uplo, integer *n, doublereal *a, | ||
449 | integer *lda, doublereal *w, doublereal *work, integer *lwork, | ||
450 | integer *info); | ||
451 | |||
452 | int eig_l_S(int wantV,KDMAT(a),DVEC(s),DMAT(v)) { | ||
453 | integer n = ar; | ||
454 | REQUIRES(ac==n && sn==n, BAD_SIZE); | ||
455 | REQUIRES(vr==n && vc==n, BAD_SIZE); | ||
456 | char jobz = wantV?'V':'N'; | ||
457 | DEBUGMSG("eig_l_S"); | ||
458 | memcpy(vp,ap,n*n*sizeof(double)); | ||
459 | integer lwork = -1; | ||
460 | char uplo = 'U'; | ||
461 | integer res; | ||
462 | // ask for optimal lwork | ||
463 | double ans; | ||
464 | //printf("ask dsyev\n"); | ||
465 | dsyev_ (&jobz,&uplo, | ||
466 | &n,vp,&n, | ||
467 | sp, | ||
468 | &ans, &lwork, | ||
469 | &res); | ||
470 | lwork = ceil(ans); | ||
471 | //printf("ans = %d\n",lwork); | ||
472 | double * work = (double*)malloc(lwork*sizeof(double)); | ||
473 | CHECK(!work,MEM); | ||
474 | dsyev_ (&jobz,&uplo, | ||
475 | &n,vp,&n, | ||
476 | sp, | ||
477 | work, &lwork, | ||
478 | &res); | ||
479 | CHECK(res,res); | ||
480 | free(work); | ||
481 | OK | ||
482 | } | ||
483 | |||
484 | //////////////////// hermitian complex eigensystem //////////// | ||
485 | |||
486 | /* Subroutine */ int zheev_(char *jobz, char *uplo, integer *n, doublecomplex | ||
487 | *a, integer *lda, doublereal *w, doublecomplex *work, integer *lwork, | ||
488 | doublereal *rwork, integer *info); | ||
489 | |||
490 | int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) { | ||
491 | integer n = ar; | ||
492 | REQUIRES(ac==n && sn==n, BAD_SIZE); | ||
493 | REQUIRES(vr==n && vc==n, BAD_SIZE); | ||
494 | char jobz = wantV?'V':'N'; | ||
495 | DEBUGMSG("eig_l_H"); | ||
496 | memcpy(vp,ap,2*n*n*sizeof(double)); | ||
497 | double *rwork = (double*) malloc((3*n-2)*sizeof(double)); | ||
498 | CHECK(!rwork,MEM); | ||
499 | integer lwork = -1; | ||
500 | char uplo = 'U'; | ||
501 | integer res; | ||
502 | // ask for optimal lwork | ||
503 | doublecomplex ans; | ||
504 | //printf("ask zheev\n"); | ||
505 | zheev_ (&jobz,&uplo, | ||
506 | &n,vp,&n, | ||
507 | sp, | ||
508 | &ans, &lwork, | ||
509 | rwork, | ||
510 | &res); | ||
511 | lwork = ceil(ans.r); | ||
512 | //printf("ans = %d\n",lwork); | ||
513 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
514 | CHECK(!work,MEM); | ||
515 | zheev_ (&jobz,&uplo, | ||
516 | &n,vp,&n, | ||
517 | sp, | ||
518 | work, &lwork, | ||
519 | rwork, | ||
520 | &res); | ||
521 | CHECK(res,res); | ||
522 | free(work); | ||
523 | free(rwork); | ||
524 | OK | ||
525 | } | ||
526 | |||
527 | //////////////////// general real linear system //////////// | ||
528 | |||
529 | /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer | ||
530 | *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info); | ||
531 | |||
532 | int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | ||
533 | integer n = ar; | ||
534 | integer nhrs = bc; | ||
535 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
536 | DEBUGMSG("linearSolveR_l"); | ||
537 | double*AC = (double*)malloc(n*n*sizeof(double)); | ||
538 | memcpy(AC,ap,n*n*sizeof(double)); | ||
539 | memcpy(xp,bp,n*nhrs*sizeof(double)); | ||
540 | integer * ipiv = (integer*)malloc(n*sizeof(integer)); | ||
541 | integer res; | ||
542 | dgesv_ (&n,&nhrs, | ||
543 | AC, &n, | ||
544 | ipiv, | ||
545 | xp, &n, | ||
546 | &res); | ||
547 | if(res>0) { | ||
548 | return SINGULAR; | ||
549 | } | ||
550 | CHECK(res,res); | ||
551 | free(ipiv); | ||
552 | free(AC); | ||
553 | OK | ||
554 | } | ||
555 | |||
556 | //////////////////// general complex linear system //////////// | ||
557 | |||
558 | /* Subroutine */ int zgesv_(integer *n, integer *nrhs, doublecomplex *a, | ||
559 | integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer * | ||
560 | info); | ||
561 | |||
562 | int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | ||
563 | integer n = ar; | ||
564 | integer nhrs = bc; | ||
565 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
566 | DEBUGMSG("linearSolveC_l"); | ||
567 | doublecomplex*AC = (doublecomplex*)malloc(n*n*sizeof(doublecomplex)); | ||
568 | memcpy(AC,ap,n*n*sizeof(doublecomplex)); | ||
569 | memcpy(xp,bp,n*nhrs*sizeof(doublecomplex)); | ||
570 | integer * ipiv = (integer*)malloc(n*sizeof(integer)); | ||
571 | integer res; | ||
572 | zgesv_ (&n,&nhrs, | ||
573 | AC, &n, | ||
574 | ipiv, | ||
575 | xp, &n, | ||
576 | &res); | ||
577 | if(res>0) { | ||
578 | return SINGULAR; | ||
579 | } | ||
580 | CHECK(res,res); | ||
581 | free(ipiv); | ||
582 | free(AC); | ||
583 | OK | ||
584 | } | ||
585 | |||
586 | //////// symmetric positive definite real linear system using Cholesky //////////// | ||
587 | |||
588 | /* Subroutine */ int dpotrs_(char *uplo, integer *n, integer *nrhs, | ||
589 | doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * | ||
590 | info); | ||
591 | |||
592 | int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | ||
593 | integer n = ar; | ||
594 | integer nhrs = bc; | ||
595 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
596 | DEBUGMSG("cholSolveR_l"); | ||
597 | memcpy(xp,bp,n*nhrs*sizeof(double)); | ||
598 | integer res; | ||
599 | dpotrs_ ("U", | ||
600 | &n,&nhrs, | ||
601 | (double*)ap, &n, | ||
602 | xp, &n, | ||
603 | &res); | ||
604 | CHECK(res,res); | ||
605 | OK | ||
606 | } | ||
607 | |||
608 | //////// Hermitian positive definite real linear system using Cholesky //////////// | ||
609 | |||
610 | /* Subroutine */ int zpotrs_(char *uplo, integer *n, integer *nrhs, | ||
611 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, | ||
612 | integer *info); | ||
613 | |||
614 | int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | ||
615 | integer n = ar; | ||
616 | integer nhrs = bc; | ||
617 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
618 | DEBUGMSG("cholSolveC_l"); | ||
619 | memcpy(xp,bp,n*nhrs*sizeof(doublecomplex)); | ||
620 | integer res; | ||
621 | zpotrs_ ("U", | ||
622 | &n,&nhrs, | ||
623 | (doublecomplex*)ap, &n, | ||
624 | xp, &n, | ||
625 | &res); | ||
626 | CHECK(res,res); | ||
627 | OK | ||
628 | } | ||
629 | |||
630 | //////////////////// least squares real linear system //////////// | ||
631 | |||
632 | /* Subroutine */ int dgels_(char *trans, integer *m, integer *n, integer * | ||
633 | nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, | ||
634 | doublereal *work, integer *lwork, integer *info); | ||
635 | |||
636 | int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | ||
637 | integer m = ar; | ||
638 | integer n = ac; | ||
639 | integer nrhs = bc; | ||
640 | integer ldb = xr; | ||
641 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | ||
642 | DEBUGMSG("linearSolveLSR_l"); | ||
643 | double*AC = (double*)malloc(m*n*sizeof(double)); | ||
644 | memcpy(AC,ap,m*n*sizeof(double)); | ||
645 | if (m>=n) { | ||
646 | memcpy(xp,bp,m*nrhs*sizeof(double)); | ||
647 | } else { | ||
648 | int k; | ||
649 | for(k = 0; k<nrhs; k++) { | ||
650 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(double)); | ||
651 | } | ||
652 | } | ||
653 | integer res; | ||
654 | integer lwork = -1; | ||
655 | double ans; | ||
656 | dgels_ ("N",&m,&n,&nrhs, | ||
657 | AC,&m, | ||
658 | xp,&ldb, | ||
659 | &ans,&lwork, | ||
660 | &res); | ||
661 | lwork = ceil(ans); | ||
662 | //printf("ans = %d\n",lwork); | ||
663 | double * work = (double*)malloc(lwork*sizeof(double)); | ||
664 | dgels_ ("N",&m,&n,&nrhs, | ||
665 | AC,&m, | ||
666 | xp,&ldb, | ||
667 | work,&lwork, | ||
668 | &res); | ||
669 | if(res>0) { | ||
670 | return SINGULAR; | ||
671 | } | ||
672 | CHECK(res,res); | ||
673 | free(work); | ||
674 | free(AC); | ||
675 | OK | ||
676 | } | ||
677 | |||
678 | //////////////////// least squares complex linear system //////////// | ||
679 | |||
680 | /* Subroutine */ int zgels_(char *trans, integer *m, integer *n, integer * | ||
681 | nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, | ||
682 | doublecomplex *work, integer *lwork, integer *info); | ||
683 | |||
684 | int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | ||
685 | integer m = ar; | ||
686 | integer n = ac; | ||
687 | integer nrhs = bc; | ||
688 | integer ldb = xr; | ||
689 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | ||
690 | DEBUGMSG("linearSolveLSC_l"); | ||
691 | doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
692 | memcpy(AC,ap,m*n*sizeof(doublecomplex)); | ||
693 | if (m>=n) { | ||
694 | memcpy(xp,bp,m*nrhs*sizeof(doublecomplex)); | ||
695 | } else { | ||
696 | int k; | ||
697 | for(k = 0; k<nrhs; k++) { | ||
698 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex)); | ||
699 | } | ||
700 | } | ||
701 | integer res; | ||
702 | integer lwork = -1; | ||
703 | doublecomplex ans; | ||
704 | zgels_ ("N",&m,&n,&nrhs, | ||
705 | AC,&m, | ||
706 | xp,&ldb, | ||
707 | &ans,&lwork, | ||
708 | &res); | ||
709 | lwork = ceil(ans.r); | ||
710 | //printf("ans = %d\n",lwork); | ||
711 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
712 | zgels_ ("N",&m,&n,&nrhs, | ||
713 | AC,&m, | ||
714 | xp,&ldb, | ||
715 | work,&lwork, | ||
716 | &res); | ||
717 | if(res>0) { | ||
718 | return SINGULAR; | ||
719 | } | ||
720 | CHECK(res,res); | ||
721 | free(work); | ||
722 | free(AC); | ||
723 | OK | ||
724 | } | ||
725 | |||
726 | //////////////////// least squares real linear system using SVD //////////// | ||
727 | |||
728 | /* Subroutine */ int dgelss_(integer *m, integer *n, integer *nrhs, | ||
729 | doublereal *a, integer *lda, doublereal *b, integer *ldb, doublereal * | ||
730 | s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork, | ||
731 | integer *info); | ||
732 | |||
733 | int linearSolveSVDR_l(double rcond,KDMAT(a),KDMAT(b),DMAT(x)) { | ||
734 | integer m = ar; | ||
735 | integer n = ac; | ||
736 | integer nrhs = bc; | ||
737 | integer ldb = xr; | ||
738 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | ||
739 | DEBUGMSG("linearSolveSVDR_l"); | ||
740 | double*AC = (double*)malloc(m*n*sizeof(double)); | ||
741 | double*S = (double*)malloc(MIN(m,n)*sizeof(double)); | ||
742 | memcpy(AC,ap,m*n*sizeof(double)); | ||
743 | if (m>=n) { | ||
744 | memcpy(xp,bp,m*nrhs*sizeof(double)); | ||
745 | } else { | ||
746 | int k; | ||
747 | for(k = 0; k<nrhs; k++) { | ||
748 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(double)); | ||
749 | } | ||
750 | } | ||
751 | integer res; | ||
752 | integer lwork = -1; | ||
753 | integer rank; | ||
754 | double ans; | ||
755 | dgelss_ (&m,&n,&nrhs, | ||
756 | AC,&m, | ||
757 | xp,&ldb, | ||
758 | S, | ||
759 | &rcond,&rank, | ||
760 | &ans,&lwork, | ||
761 | &res); | ||
762 | lwork = ceil(ans); | ||
763 | //printf("ans = %d\n",lwork); | ||
764 | double * work = (double*)malloc(lwork*sizeof(double)); | ||
765 | dgelss_ (&m,&n,&nrhs, | ||
766 | AC,&m, | ||
767 | xp,&ldb, | ||
768 | S, | ||
769 | &rcond,&rank, | ||
770 | work,&lwork, | ||
771 | &res); | ||
772 | if(res>0) { | ||
773 | return NOCONVER; | ||
774 | } | ||
775 | CHECK(res,res); | ||
776 | free(work); | ||
777 | free(S); | ||
778 | free(AC); | ||
779 | OK | ||
780 | } | ||
781 | |||
782 | //////////////////// least squares complex linear system using SVD //////////// | ||
783 | |||
784 | // not in clapack.h | ||
785 | |||
786 | int zgelss_(integer *m, integer *n, integer *nhrs, | ||
787 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, doublereal *s, | ||
788 | doublereal *rcond, integer* rank, | ||
789 | doublecomplex *work, integer* lwork, doublereal* rwork, | ||
790 | integer *info); | ||
791 | |||
792 | int linearSolveSVDC_l(double rcond, KCMAT(a),KCMAT(b),CMAT(x)) { | ||
793 | integer m = ar; | ||
794 | integer n = ac; | ||
795 | integer nrhs = bc; | ||
796 | integer ldb = xr; | ||
797 | REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); | ||
798 | DEBUGMSG("linearSolveSVDC_l"); | ||
799 | doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); | ||
800 | double*S = (double*)malloc(MIN(m,n)*sizeof(double)); | ||
801 | double*RWORK = (double*)malloc(5*MIN(m,n)*sizeof(double)); | ||
802 | memcpy(AC,ap,m*n*sizeof(doublecomplex)); | ||
803 | if (m>=n) { | ||
804 | memcpy(xp,bp,m*nrhs*sizeof(doublecomplex)); | ||
805 | } else { | ||
806 | int k; | ||
807 | for(k = 0; k<nrhs; k++) { | ||
808 | memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex)); | ||
809 | } | ||
810 | } | ||
811 | integer res; | ||
812 | integer lwork = -1; | ||
813 | integer rank; | ||
814 | doublecomplex ans; | ||
815 | zgelss_ (&m,&n,&nrhs, | ||
816 | AC,&m, | ||
817 | xp,&ldb, | ||
818 | S, | ||
819 | &rcond,&rank, | ||
820 | &ans,&lwork, | ||
821 | RWORK, | ||
822 | &res); | ||
823 | lwork = ceil(ans.r); | ||
824 | //printf("ans = %d\n",lwork); | ||
825 | doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
826 | zgelss_ (&m,&n,&nrhs, | ||
827 | AC,&m, | ||
828 | xp,&ldb, | ||
829 | S, | ||
830 | &rcond,&rank, | ||
831 | work,&lwork, | ||
832 | RWORK, | ||
833 | &res); | ||
834 | if(res>0) { | ||
835 | return NOCONVER; | ||
836 | } | ||
837 | CHECK(res,res); | ||
838 | free(work); | ||
839 | free(RWORK); | ||
840 | free(S); | ||
841 | free(AC); | ||
842 | OK | ||
843 | } | ||
844 | |||
845 | //////////////////// Cholesky factorization ///////////////////////// | ||
846 | |||
847 | /* Subroutine */ int zpotrf_(char *uplo, integer *n, doublecomplex *a, | ||
848 | integer *lda, integer *info); | ||
849 | |||
850 | int chol_l_H(KCMAT(a),CMAT(l)) { | ||
851 | integer n = ar; | ||
852 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); | ||
853 | DEBUGMSG("chol_l_H"); | ||
854 | memcpy(lp,ap,n*n*sizeof(doublecomplex)); | ||
855 | char uplo = 'U'; | ||
856 | integer res; | ||
857 | zpotrf_ (&uplo,&n,lp,&n,&res); | ||
858 | CHECK(res>0,NODEFPOS); | ||
859 | CHECK(res,res); | ||
860 | doublecomplex zero = {0.,0.}; | ||
861 | int r,c; | ||
862 | for (r=0; r<lr-1; r++) { | ||
863 | for(c=r+1; c<lc; c++) { | ||
864 | lp[r*lc+c] = zero; | ||
865 | } | ||
866 | } | ||
867 | OK | ||
868 | } | ||
869 | |||
870 | |||
871 | /* Subroutine */ int dpotrf_(char *uplo, integer *n, doublereal *a, integer * | ||
872 | lda, integer *info); | ||
873 | |||
874 | int chol_l_S(KDMAT(a),DMAT(l)) { | ||
875 | integer n = ar; | ||
876 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); | ||
877 | DEBUGMSG("chol_l_S"); | ||
878 | memcpy(lp,ap,n*n*sizeof(double)); | ||
879 | char uplo = 'U'; | ||
880 | integer res; | ||
881 | dpotrf_ (&uplo,&n,lp,&n,&res); | ||
882 | CHECK(res>0,NODEFPOS); | ||
883 | CHECK(res,res); | ||
884 | int r,c; | ||
885 | for (r=0; r<lr-1; r++) { | ||
886 | for(c=r+1; c<lc; c++) { | ||
887 | lp[r*lc+c] = 0.; | ||
888 | } | ||
889 | } | ||
890 | OK | ||
891 | } | ||
892 | |||
893 | //////////////////// QR factorization ///////////////////////// | ||
894 | |||
895 | /* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer * | ||
896 | lda, doublereal *tau, doublereal *work, integer *info); | ||
897 | |||
898 | int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | ||
899 | integer m = ar; | ||
900 | integer n = ac; | ||
901 | integer mn = MIN(m,n); | ||
902 | REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); | ||
903 | DEBUGMSG("qr_l_R"); | ||
904 | double *WORK = (double*)malloc(n*sizeof(double)); | ||
905 | CHECK(!WORK,MEM); | ||
906 | memcpy(rp,ap,m*n*sizeof(double)); | ||
907 | integer res; | ||
908 | dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); | ||
909 | CHECK(res,res); | ||
910 | free(WORK); | ||
911 | OK | ||
912 | } | ||
913 | |||
914 | /* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a, | ||
915 | integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); | ||
916 | |||
917 | int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | ||
918 | integer m = ar; | ||
919 | integer n = ac; | ||
920 | integer mn = MIN(m,n); | ||
921 | REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); | ||
922 | DEBUGMSG("qr_l_C"); | ||
923 | doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex)); | ||
924 | CHECK(!WORK,MEM); | ||
925 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
926 | integer res; | ||
927 | zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); | ||
928 | CHECK(res,res); | ||
929 | free(WORK); | ||
930 | OK | ||
931 | } | ||
932 | |||
933 | /* Subroutine */ int dorgqr_(integer *m, integer *n, integer *k, doublereal * | ||
934 | a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, | ||
935 | integer *info); | ||
936 | |||
937 | int c_dorgqr(KDMAT(a), KDVEC(tau), DMAT(r)) { | ||
938 | integer m = ar; | ||
939 | integer n = MIN(ac,ar); | ||
940 | integer k = taun; | ||
941 | DEBUGMSG("c_dorgqr"); | ||
942 | integer lwork = 8*n; // FIXME | ||
943 | double *WORK = (double*)malloc(lwork*sizeof(double)); | ||
944 | CHECK(!WORK,MEM); | ||
945 | memcpy(rp,ap,m*k*sizeof(double)); | ||
946 | integer res; | ||
947 | dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res); | ||
948 | CHECK(res,res); | ||
949 | free(WORK); | ||
950 | OK | ||
951 | } | ||
952 | |||
953 | /* Subroutine */ int zungqr_(integer *m, integer *n, integer *k, | ||
954 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | ||
955 | work, integer *lwork, integer *info); | ||
956 | |||
957 | int c_zungqr(KCMAT(a), KCVEC(tau), CMAT(r)) { | ||
958 | integer m = ar; | ||
959 | integer n = MIN(ac,ar); | ||
960 | integer k = taun; | ||
961 | DEBUGMSG("z_ungqr"); | ||
962 | integer lwork = 8*n; // FIXME | ||
963 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
964 | CHECK(!WORK,MEM); | ||
965 | memcpy(rp,ap,m*k*sizeof(doublecomplex)); | ||
966 | integer res; | ||
967 | zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res); | ||
968 | CHECK(res,res); | ||
969 | free(WORK); | ||
970 | OK | ||
971 | } | ||
972 | |||
973 | |||
974 | //////////////////// Hessenberg factorization ///////////////////////// | ||
975 | |||
976 | /* Subroutine */ int dgehrd_(integer *n, integer *ilo, integer *ihi, | ||
977 | doublereal *a, integer *lda, doublereal *tau, doublereal *work, | ||
978 | integer *lwork, integer *info); | ||
979 | |||
980 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | ||
981 | integer m = ar; | ||
982 | integer n = ac; | ||
983 | integer mn = MIN(m,n); | ||
984 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | ||
985 | DEBUGMSG("hess_l_R"); | ||
986 | integer lwork = 5*n; // fixme | ||
987 | double *WORK = (double*)malloc(lwork*sizeof(double)); | ||
988 | CHECK(!WORK,MEM); | ||
989 | memcpy(rp,ap,m*n*sizeof(double)); | ||
990 | integer res; | ||
991 | integer one = 1; | ||
992 | dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); | ||
993 | CHECK(res,res); | ||
994 | free(WORK); | ||
995 | OK | ||
996 | } | ||
997 | |||
998 | |||
999 | /* Subroutine */ int zgehrd_(integer *n, integer *ilo, integer *ihi, | ||
1000 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | ||
1001 | work, integer *lwork, integer *info); | ||
1002 | |||
1003 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | ||
1004 | integer m = ar; | ||
1005 | integer n = ac; | ||
1006 | integer mn = MIN(m,n); | ||
1007 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | ||
1008 | DEBUGMSG("hess_l_C"); | ||
1009 | integer lwork = 5*n; // fixme | ||
1010 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
1011 | CHECK(!WORK,MEM); | ||
1012 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
1013 | integer res; | ||
1014 | integer one = 1; | ||
1015 | zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); | ||
1016 | CHECK(res,res); | ||
1017 | free(WORK); | ||
1018 | OK | ||
1019 | } | ||
1020 | |||
1021 | //////////////////// Schur factorization ///////////////////////// | ||
1022 | |||
1023 | /* Subroutine */ int dgees_(char *jobvs, char *sort, L_fp select, integer *n, | ||
1024 | doublereal *a, integer *lda, integer *sdim, doublereal *wr, | ||
1025 | doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, | ||
1026 | integer *lwork, logical *bwork, integer *info); | ||
1027 | |||
1028 | int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) { | ||
1029 | integer m = ar; | ||
1030 | integer n = ac; | ||
1031 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | ||
1032 | DEBUGMSG("schur_l_R"); | ||
1033 | //int k; | ||
1034 | //printf("---------------------------\n"); | ||
1035 | //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n"); | ||
1036 | //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n"); | ||
1037 | //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n"); | ||
1038 | memcpy(sp,ap,n*n*sizeof(double)); | ||
1039 | integer lwork = 6*n; // fixme | ||
1040 | double *WORK = (double*)malloc(lwork*sizeof(double)); | ||
1041 | double *WR = (double*)malloc(n*sizeof(double)); | ||
1042 | double *WI = (double*)malloc(n*sizeof(double)); | ||
1043 | // WR and WI not really required in this call | ||
1044 | logical *BWORK = (logical*)malloc(n*sizeof(logical)); | ||
1045 | integer res; | ||
1046 | integer sdim; | ||
1047 | dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res); | ||
1048 | //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n"); | ||
1049 | //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n"); | ||
1050 | //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n"); | ||
1051 | if(res>0) { | ||
1052 | return NOCONVER; | ||
1053 | } | ||
1054 | CHECK(res,res); | ||
1055 | free(WR); | ||
1056 | free(WI); | ||
1057 | free(BWORK); | ||
1058 | free(WORK); | ||
1059 | OK | ||
1060 | } | ||
1061 | |||
1062 | |||
1063 | /* Subroutine */ int zgees_(char *jobvs, char *sort, L_fp select, integer *n, | ||
1064 | doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w, | ||
1065 | doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, | ||
1066 | doublereal *rwork, logical *bwork, integer *info); | ||
1067 | |||
1068 | int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { | ||
1069 | integer m = ar; | ||
1070 | integer n = ac; | ||
1071 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | ||
1072 | DEBUGMSG("schur_l_C"); | ||
1073 | memcpy(sp,ap,n*n*sizeof(doublecomplex)); | ||
1074 | integer lwork = 6*n; // fixme | ||
1075 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
1076 | doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); | ||
1077 | // W not really required in this call | ||
1078 | logical *BWORK = (logical*)malloc(n*sizeof(logical)); | ||
1079 | double *RWORK = (double*)malloc(n*sizeof(double)); | ||
1080 | integer res; | ||
1081 | integer sdim; | ||
1082 | zgees_ ("V","N",NULL,&n,sp,&n,&sdim,W, | ||
1083 | up,&n, | ||
1084 | WORK,&lwork,RWORK,BWORK,&res); | ||
1085 | if(res>0) { | ||
1086 | return NOCONVER; | ||
1087 | } | ||
1088 | CHECK(res,res); | ||
1089 | free(W); | ||
1090 | free(BWORK); | ||
1091 | free(WORK); | ||
1092 | OK | ||
1093 | } | ||
1094 | |||
1095 | //////////////////// LU factorization ///////////////////////// | ||
1096 | |||
1097 | /* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * | ||
1098 | lda, integer *ipiv, integer *info); | ||
1099 | |||
1100 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) { | ||
1101 | integer m = ar; | ||
1102 | integer n = ac; | ||
1103 | integer mn = MIN(m,n); | ||
1104 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | ||
1105 | DEBUGMSG("lu_l_R"); | ||
1106 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | ||
1107 | memcpy(rp,ap,m*n*sizeof(double)); | ||
1108 | integer res; | ||
1109 | dgetrf_ (&m,&n,rp,&m,auxipiv,&res); | ||
1110 | if(res>0) { | ||
1111 | res = 0; // fixme | ||
1112 | } | ||
1113 | CHECK(res,res); | ||
1114 | int k; | ||
1115 | for (k=0; k<mn; k++) { | ||
1116 | ipivp[k] = auxipiv[k]; | ||
1117 | } | ||
1118 | free(auxipiv); | ||
1119 | OK | ||
1120 | } | ||
1121 | |||
1122 | |||
1123 | /* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a, | ||
1124 | integer *lda, integer *ipiv, integer *info); | ||
1125 | |||
1126 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | ||
1127 | integer m = ar; | ||
1128 | integer n = ac; | ||
1129 | integer mn = MIN(m,n); | ||
1130 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | ||
1131 | DEBUGMSG("lu_l_C"); | ||
1132 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | ||
1133 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
1134 | integer res; | ||
1135 | zgetrf_ (&m,&n,rp,&m,auxipiv,&res); | ||
1136 | if(res>0) { | ||
1137 | res = 0; // fixme | ||
1138 | } | ||
1139 | CHECK(res,res); | ||
1140 | int k; | ||
1141 | for (k=0; k<mn; k++) { | ||
1142 | ipivp[k] = auxipiv[k]; | ||
1143 | } | ||
1144 | free(auxipiv); | ||
1145 | OK | ||
1146 | } | ||
1147 | |||
1148 | |||
1149 | //////////////////// LU substitution ///////////////////////// | ||
1150 | |||
1151 | /* Subroutine */ int dgetrs_(char *trans, integer *n, integer *nrhs, | ||
1152 | doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * | ||
1153 | ldb, integer *info); | ||
1154 | |||
1155 | int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { | ||
1156 | integer m = ar; | ||
1157 | integer n = ac; | ||
1158 | integer mrhs = br; | ||
1159 | integer nrhs = bc; | ||
1160 | |||
1161 | REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); | ||
1162 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1163 | int k; | ||
1164 | for (k=0; k<n; k++) { | ||
1165 | auxipiv[k] = (integer)ipivp[k]; | ||
1166 | } | ||
1167 | integer res; | ||
1168 | memcpy(xp,bp,mrhs*nrhs*sizeof(double)); | ||
1169 | dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res); | ||
1170 | CHECK(res,res); | ||
1171 | free(auxipiv); | ||
1172 | OK | ||
1173 | } | ||
1174 | |||
1175 | |||
1176 | /* Subroutine */ int zgetrs_(char *trans, integer *n, integer *nrhs, | ||
1177 | doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, | ||
1178 | integer *ldb, integer *info); | ||
1179 | |||
1180 | int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { | ||
1181 | integer m = ar; | ||
1182 | integer n = ac; | ||
1183 | integer mrhs = br; | ||
1184 | integer nrhs = bc; | ||
1185 | |||
1186 | REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); | ||
1187 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1188 | int k; | ||
1189 | for (k=0; k<n; k++) { | ||
1190 | auxipiv[k] = (integer)ipivp[k]; | ||
1191 | } | ||
1192 | integer res; | ||
1193 | memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex)); | ||
1194 | zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,xp,&mrhs,&res); | ||
1195 | CHECK(res,res); | ||
1196 | free(auxipiv); | ||
1197 | OK | ||
1198 | } | ||
1199 | |||
1200 | //////////////////// Matrix Product ///////////////////////// | ||
1201 | |||
1202 | void dgemm_(char *, char *, integer *, integer *, integer *, | ||
1203 | double *, const double *, integer *, const double *, | ||
1204 | integer *, double *, double *, integer *); | ||
1205 | |||
1206 | int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) { | ||
1207 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1208 | DEBUGMSG("dgemm_"); | ||
1209 | CHECKNANR(a,"NaN multR Input\n") | ||
1210 | CHECKNANR(b,"NaN multR Input\n") | ||
1211 | integer m = ta?ac:ar; | ||
1212 | integer n = tb?br:bc; | ||
1213 | integer k = ta?ar:ac; | ||
1214 | integer lda = ar; | ||
1215 | integer ldb = br; | ||
1216 | integer ldc = rr; | ||
1217 | double alpha = 1; | ||
1218 | double beta = 0; | ||
1219 | dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); | ||
1220 | CHECKNANR(r,"NaN multR Output\n") | ||
1221 | OK | ||
1222 | } | ||
1223 | |||
1224 | void zgemm_(char *, char *, integer *, integer *, integer *, | ||
1225 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, | ||
1226 | integer *, doublecomplex *, doublecomplex *, integer *); | ||
1227 | |||
1228 | int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) { | ||
1229 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1230 | DEBUGMSG("zgemm_"); | ||
1231 | CHECKNANC(a,"NaN multC Input\n") | ||
1232 | CHECKNANC(b,"NaN multC Input\n") | ||
1233 | integer m = ta?ac:ar; | ||
1234 | integer n = tb?br:bc; | ||
1235 | integer k = ta?ar:ac; | ||
1236 | integer lda = ar; | ||
1237 | integer ldb = br; | ||
1238 | integer ldc = rr; | ||
1239 | doublecomplex alpha = {1,0}; | ||
1240 | doublecomplex beta = {0,0}; | ||
1241 | zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, | ||
1242 | ap,&lda, | ||
1243 | bp,&ldb,&beta, | ||
1244 | rp,&ldc); | ||
1245 | CHECKNANC(r,"NaN multC Output\n") | ||
1246 | OK | ||
1247 | } | ||
1248 | |||
1249 | void sgemm_(char *, char *, integer *, integer *, integer *, | ||
1250 | float *, const float *, integer *, const float *, | ||
1251 | integer *, float *, float *, integer *); | ||
1252 | |||
1253 | int multiplyF(int ta, int tb, KFMAT(a),KFMAT(b),FMAT(r)) { | ||
1254 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1255 | DEBUGMSG("sgemm_"); | ||
1256 | integer m = ta?ac:ar; | ||
1257 | integer n = tb?br:bc; | ||
1258 | integer k = ta?ar:ac; | ||
1259 | integer lda = ar; | ||
1260 | integer ldb = br; | ||
1261 | integer ldc = rr; | ||
1262 | float alpha = 1; | ||
1263 | float beta = 0; | ||
1264 | sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); | ||
1265 | OK | ||
1266 | } | ||
1267 | |||
1268 | void cgemm_(char *, char *, integer *, integer *, integer *, | ||
1269 | complex *, const complex *, integer *, const complex *, | ||
1270 | integer *, complex *, complex *, integer *); | ||
1271 | |||
1272 | int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | ||
1273 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1274 | DEBUGMSG("cgemm_"); | ||
1275 | integer m = ta?ac:ar; | ||
1276 | integer n = tb?br:bc; | ||
1277 | integer k = ta?ar:ac; | ||
1278 | integer lda = ar; | ||
1279 | integer ldb = br; | ||
1280 | integer ldc = rr; | ||
1281 | complex alpha = {1,0}; | ||
1282 | complex beta = {0,0}; | ||
1283 | cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, | ||
1284 | ap,&lda, | ||
1285 | bp,&ldb,&beta, | ||
1286 | rp,&ldc); | ||
1287 | OK | ||
1288 | } | ||
1289 | |||
1290 | //////////////////// transpose ///////////////////////// | ||
1291 | |||
1292 | int transF(KFMAT(x),FMAT(t)) { | ||
1293 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
1294 | DEBUGMSG("transF"); | ||
1295 | int i,j; | ||
1296 | for (i=0; i<tr; i++) { | ||
1297 | for (j=0; j<tc; j++) { | ||
1298 | tp[i*tc+j] = xp[j*xc+i]; | ||
1299 | } | ||
1300 | } | ||
1301 | OK | ||
1302 | } | ||
1303 | |||
1304 | int transR(KDMAT(x),DMAT(t)) { | ||
1305 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
1306 | DEBUGMSG("transR"); | ||
1307 | int i,j; | ||
1308 | for (i=0; i<tr; i++) { | ||
1309 | for (j=0; j<tc; j++) { | ||
1310 | tp[i*tc+j] = xp[j*xc+i]; | ||
1311 | } | ||
1312 | } | ||
1313 | OK | ||
1314 | } | ||
1315 | |||
1316 | int transQ(KQMAT(x),QMAT(t)) { | ||
1317 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
1318 | DEBUGMSG("transQ"); | ||
1319 | int i,j; | ||
1320 | for (i=0; i<tr; i++) { | ||
1321 | for (j=0; j<tc; j++) { | ||
1322 | tp[i*tc+j] = xp[j*xc+i]; | ||
1323 | } | ||
1324 | } | ||
1325 | OK | ||
1326 | } | ||
1327 | |||
1328 | int transC(KCMAT(x),CMAT(t)) { | ||
1329 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
1330 | DEBUGMSG("transC"); | ||
1331 | int i,j; | ||
1332 | for (i=0; i<tr; i++) { | ||
1333 | for (j=0; j<tc; j++) { | ||
1334 | tp[i*tc+j] = xp[j*xc+i]; | ||
1335 | } | ||
1336 | } | ||
1337 | OK | ||
1338 | } | ||
1339 | |||
1340 | int transP(KPMAT(x), PMAT(t)) { | ||
1341 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | ||
1342 | REQUIRES(xs==ts,NOCONVER); | ||
1343 | DEBUGMSG("transP"); | ||
1344 | int i,j; | ||
1345 | for (i=0; i<tr; i++) { | ||
1346 | for (j=0; j<tc; j++) { | ||
1347 | memcpy(tp+(i*tc+j)*xs,xp +(j*xc+i)*xs,xs); | ||
1348 | } | ||
1349 | } | ||
1350 | OK | ||
1351 | } | ||
1352 | |||
1353 | //////////////////// constant ///////////////////////// | ||
1354 | |||
1355 | int constantF(float * pval, FVEC(r)) { | ||
1356 | DEBUGMSG("constantF") | ||
1357 | int k; | ||
1358 | double val = *pval; | ||
1359 | for(k=0;k<rn;k++) { | ||
1360 | rp[k]=val; | ||
1361 | } | ||
1362 | OK | ||
1363 | } | ||
1364 | |||
1365 | int constantR(double * pval, DVEC(r)) { | ||
1366 | DEBUGMSG("constantR") | ||
1367 | int k; | ||
1368 | double val = *pval; | ||
1369 | for(k=0;k<rn;k++) { | ||
1370 | rp[k]=val; | ||
1371 | } | ||
1372 | OK | ||
1373 | } | ||
1374 | |||
1375 | int constantQ(complex* pval, QVEC(r)) { | ||
1376 | DEBUGMSG("constantQ") | ||
1377 | int k; | ||
1378 | complex val = *pval; | ||
1379 | for(k=0;k<rn;k++) { | ||
1380 | rp[k]=val; | ||
1381 | } | ||
1382 | OK | ||
1383 | } | ||
1384 | |||
1385 | int constantC(doublecomplex* pval, CVEC(r)) { | ||
1386 | DEBUGMSG("constantC") | ||
1387 | int k; | ||
1388 | doublecomplex val = *pval; | ||
1389 | for(k=0;k<rn;k++) { | ||
1390 | rp[k]=val; | ||
1391 | } | ||
1392 | OK | ||
1393 | } | ||
1394 | |||
1395 | int constantP(void* pval, PVEC(r)) { | ||
1396 | DEBUGMSG("constantP") | ||
1397 | int k; | ||
1398 | for(k=0;k<rn;k++) { | ||
1399 | memcpy(rp+k*rs,pval,rs); | ||
1400 | } | ||
1401 | OK | ||
1402 | } | ||
1403 | |||
1404 | //////////////////// float-double conversion ///////////////////////// | ||
1405 | |||
1406 | int float2double(FVEC(x),DVEC(y)) { | ||
1407 | DEBUGMSG("float2double") | ||
1408 | int k; | ||
1409 | for(k=0;k<xn;k++) { | ||
1410 | yp[k]=xp[k]; | ||
1411 | } | ||
1412 | OK | ||
1413 | } | ||
1414 | |||
1415 | int double2float(DVEC(x),FVEC(y)) { | ||
1416 | DEBUGMSG("double2float") | ||
1417 | int k; | ||
1418 | for(k=0;k<xn;k++) { | ||
1419 | yp[k]=xp[k]; | ||
1420 | } | ||
1421 | OK | ||
1422 | } | ||
1423 | |||
1424 | //////////////////// conjugate ///////////////////////// | ||
1425 | |||
1426 | int conjugateQ(KQVEC(x),QVEC(t)) { | ||
1427 | REQUIRES(xn==tn,BAD_SIZE); | ||
1428 | DEBUGMSG("conjugateQ"); | ||
1429 | int k; | ||
1430 | for(k=0;k<xn;k++) { | ||
1431 | tp[k].r = xp[k].r; | ||
1432 | tp[k].i = -xp[k].i; | ||
1433 | } | ||
1434 | OK | ||
1435 | } | ||
1436 | |||
1437 | int conjugateC(KCVEC(x),CVEC(t)) { | ||
1438 | REQUIRES(xn==tn,BAD_SIZE); | ||
1439 | DEBUGMSG("conjugateC"); | ||
1440 | int k; | ||
1441 | for(k=0;k<xn;k++) { | ||
1442 | tp[k].r = xp[k].r; | ||
1443 | tp[k].i = -xp[k].i; | ||
1444 | } | ||
1445 | OK | ||
1446 | } | ||
1447 | |||
1448 | //////////////////// step ///////////////////////// | ||
1449 | |||
1450 | int stepF(FVEC(x),FVEC(y)) { | ||
1451 | DEBUGMSG("stepF") | ||
1452 | int k; | ||
1453 | for(k=0;k<xn;k++) { | ||
1454 | yp[k]=xp[k]>0; | ||
1455 | } | ||
1456 | OK | ||
1457 | } | ||
1458 | |||
1459 | int stepD(DVEC(x),DVEC(y)) { | ||
1460 | DEBUGMSG("stepD") | ||
1461 | int k; | ||
1462 | for(k=0;k<xn;k++) { | ||
1463 | yp[k]=xp[k]>0; | ||
1464 | } | ||
1465 | OK | ||
1466 | } | ||
1467 | |||
1468 | //////////////////// cond ///////////////////////// | ||
1469 | |||
1470 | int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)) { | ||
1471 | REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); | ||
1472 | DEBUGMSG("condF") | ||
1473 | int k; | ||
1474 | for(k=0;k<xn;k++) { | ||
1475 | rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]); | ||
1476 | } | ||
1477 | OK | ||
1478 | } | ||
1479 | |||
1480 | int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)) { | ||
1481 | REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); | ||
1482 | DEBUGMSG("condD") | ||
1483 | int k; | ||
1484 | for(k=0;k<xn;k++) { | ||
1485 | rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]); | ||
1486 | } | ||
1487 | OK | ||
1488 | } | ||
1489 | |||
diff --git a/packages/base/src/C/lapack-aux.h b/packages/base/src/C/lapack-aux.h new file mode 100644 index 0000000..a3f1899 --- /dev/null +++ b/packages/base/src/C/lapack-aux.h | |||
@@ -0,0 +1,60 @@ | |||
/*
 * We have copied the definitions in f2c.h required
 * to compile clapack.h, modified to support both
 * 32 and 64 bit

 http://opengrok.creo.hu/dragonfly/xref/src/contrib/gcc-3.4/libf2c/readme.netlib
 http://www.ibm.com/developerworks/library/l-port64.html
*/

/* Include guard: the typedefs below must not be seen twice by one
   translation unit. */
#ifndef LAPACK_AUX_H
#define LAPACK_AUX_H

/* Fortran INTEGER / LOGICAL: `int` when _LP64 is defined, `long int`
   otherwise. */
#ifdef _LP64
typedef int integer;
typedef unsigned int uinteger;
typedef int logical;
typedef long longint; /* system-dependent */
typedef unsigned long ulongint; /* system-dependent */
#else
typedef long int integer;
typedef unsigned long int uinteger;
typedef long int logical;
typedef long long longint; /* system-dependent */
typedef unsigned long long ulongint; /* system-dependent */
#endif

typedef char *address;
typedef short int shortint;
typedef float real;
typedef double doublereal;
typedef struct { real r, i; } complex;
typedef struct { doublereal r, i; } doublecomplex;
typedef short int shortlogical;
typedef char logical1;
typedef char integer1;

typedef logical (*L_fp)();
typedef short ftnlen;

/********************************************************/

/*
 * Parameter-list helpers.  NAME(A) expands to the parameters describing
 * one array argument called A:
 *   xVEC -> length A##n and data pointer A##p
 *   xMAT -> rows A##r, columns A##c and data pointer A##p
 * The first letter selects the element type: F float, D double,
 * Q complex (single), C doublecomplex, P untyped (void*) with an
 * additional element size A##s.  A leading K makes the data pointer
 * const.
 */
#define FVEC(A) int A##n, float*A##p
#define DVEC(A) int A##n, double*A##p
#define QVEC(A) int A##n, complex*A##p
#define CVEC(A) int A##n, doublecomplex*A##p
#define PVEC(A) int A##n, void* A##p, int A##s
#define FMAT(A) int A##r, int A##c, float* A##p
#define DMAT(A) int A##r, int A##c, double* A##p
#define QMAT(A) int A##r, int A##c, complex* A##p
#define CMAT(A) int A##r, int A##c, doublecomplex* A##p
#define PMAT(A) int A##r, int A##c, void* A##p, int A##s

#define KFVEC(A) int A##n, const float*A##p
#define KDVEC(A) int A##n, const double*A##p
#define KQVEC(A) int A##n, const complex*A##p
#define KCVEC(A) int A##n, const doublecomplex*A##p
#define KPVEC(A) int A##n, const void* A##p, int A##s
#define KFMAT(A) int A##r, int A##c, const float* A##p
#define KDMAT(A) int A##r, int A##c, const double* A##p
#define KQMAT(A) int A##r, int A##c, const complex* A##p
#define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p
#define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s

#endif /* LAPACK_AUX_H */
60 | |||
diff --git a/packages/base/src/Data/Packed.hs b/packages/base/src/Data/Packed.hs new file mode 100644 index 0000000..c66718a --- /dev/null +++ b/packages/base/src/Data/Packed.hs | |||
@@ -0,0 +1,25 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | {- | | ||
3 | Module : Data.Packed | ||
4 | Copyright : (c) Alberto Ruiz 2006-2014 | ||
5 | License : BSD3 | ||
6 | Maintainer : Alberto Ruiz | ||
7 | Stability : provisional | ||
8 | |||
9 | Types for dense 'Vector' and 'Matrix' of 'Storable' elements. | ||
10 | |||
11 | -} | ||
12 | ----------------------------------------------------------------------------- | ||
13 | |||
14 | module Data.Packed ( | ||
15 | -- * Vector | ||
16 | -- | ||
17 | -- | Vectors are @Data.Vector.Storable.Vector@ from the \"vector\" package. | ||
18 | module Data.Packed.Vector, | ||
19 | -- * Matrix | ||
20 | module Data.Packed.Matrix, | ||
21 | ) where | ||
22 | |||
23 | import Data.Packed.Vector | ||
24 | import Data.Packed.Matrix | ||
25 | |||
diff --git a/packages/base/src/Data/Packed/Development.hs b/packages/base/src/Data/Packed/Development.hs new file mode 100644 index 0000000..777b6c5 --- /dev/null +++ b/packages/base/src/Data/Packed/Development.hs | |||
@@ -0,0 +1,31 @@ | |||
1 | |||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Development | ||
5 | -- Copyright : (c) Alberto Ruiz 2009 | ||
6 | -- License : BSD3 | ||
7 | -- Maintainer : Alberto Ruiz | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable | ||
10 | -- | ||
11 | -- The library can be easily extended with additional foreign functions | ||
12 | -- using the tools in this module. Illustrative usage examples can be found | ||
13 | -- in the @examples\/devel@ folder included in the package. | ||
14 | -- | ||
15 | ----------------------------------------------------------------------------- | ||
16 | |||
17 | module Data.Packed.Development ( | ||
18 | createVector, createMatrix, | ||
19 | vec, mat, | ||
20 | app1, app2, app3, app4, | ||
21 | app5, app6, app7, app8, app9, app10, | ||
22 | MatrixOrder(..), orderOf, cmat, fmat, | ||
23 | matrixFromVector, | ||
24 | unsafeFromForeignPtr, | ||
25 | unsafeToForeignPtr, | ||
26 | check, (//), | ||
27 | at', atM' | ||
28 | ) where | ||
29 | |||
30 | import Data.Packed.Internal | ||
31 | |||
diff --git a/packages/base/src/Data/Packed/Foreign.hs b/packages/base/src/Data/Packed/Foreign.hs new file mode 100644 index 0000000..efa51ca --- /dev/null +++ b/packages/base/src/Data/Packed/Foreign.hs | |||
@@ -0,0 +1,99 @@ | |||
1 | {-# LANGUAGE MagicHash, UnboxedTuples #-} | ||
2 | -- | FFI and hmatrix helpers. | ||
3 | -- | ||
4 | -- Sample usage, to upload a perspective matrix to a shader. | ||
5 | -- | ||
6 | -- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3) | ||
7 | -- @ | ||
8 | -- | ||
9 | module Data.Packed.Foreign | ||
10 | ( app | ||
11 | , appVector, appVectorLen | ||
12 | , appMatrix, appMatrixLen, appMatrixRaw, appMatrixRawLen | ||
13 | , unsafeMatrixToVector, unsafeMatrixToForeignPtr | ||
14 | ) where | ||
15 | import Data.Packed.Internal | ||
16 | import qualified Data.Vector.Storable as S | ||
17 | import Foreign (Ptr, ForeignPtr, Storable) | ||
18 | import Foreign.C.Types (CInt) | ||
19 | import GHC.Base (IO(..), realWorld#) | ||
20 | |||
{-# INLINE unsafeInlinePerformIO #-}
-- | Run an 'IO' action as a pure value by applying it directly to
-- 'realWorld#'.
--
-- If we use @unsafePerformIO@, it may not get inlined, so in a function that
-- returns 'IO' (which are all safe uses of @app*@ in this module) there would
-- be unnecessary calls to @unsafePerformIO@ or its internals.
unsafeInlinePerformIO :: IO a -> a
unsafeInlinePerformIO (IO action) =
    case action realWorld# of
        (# _, result #) -> result
27 | |||
{-# INLINE app #-}
-- | Plain function application. It is only useful because it is declared
-- left associative with a precedence of 1, unlike 'Prelude.$', which is
-- right associative.
-- e.g.
--
-- @
-- someFunction
--     \`appMatrixLen\` m
--     \`appVectorLen\` v
--     \`app\` other
--     \`app\` arguments
--     \`app\` go here
-- @
--
-- One could also write:
--
-- @
-- (someFunction
--     \`appMatrixLen\` m
--     \`appVectorLen\` v)
--     other
--     arguments
--     (go here)
-- @
--
app :: (a -> b) -> a -> b
app = id
54 | |||
{-# INLINE appVector #-}
-- | Apply a function to the data pointer of a vector.
--
-- NOTE(review): the result @f ptr@ is handed back lazily, so it may be
-- evaluated after 'S.unsafeWith' has returned; callers presumably must keep
-- the vector alive until the result has been used -- confirm.
appVector :: Storable a => (Ptr a -> b) -> Vector a -> b
appVector f v =
    unsafeInlinePerformIO $ S.unsafeWith v $ \ptr -> return (f ptr)
58 | |||
{-# INLINE appVectorLen #-}
-- | Apply a function to the length (as a 'CInt') and the data pointer of a
-- vector.
appVectorLen :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b
appVectorLen f v =
    unsafeInlinePerformIO $
        S.unsafeWith v $ \ptr -> return (f (fromIntegral (S.length v)) ptr)
62 | |||
{-# INLINE appMatrix #-}
-- | Apply a function to the data pointer of the 'flatten'ed matrix.
appMatrix :: Element a => (Ptr a -> b) -> Matrix a -> b
appMatrix f m =
    unsafeInlinePerformIO $ S.unsafeWith (flatten m) $ \ptr -> return (f ptr)
66 | |||
{-# INLINE appMatrixLen #-}
-- | Apply a function to the number of rows, the number of columns (both as
-- 'CInt') and the data pointer of the 'flatten'ed matrix.
appMatrixLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
appMatrixLen f m =
    unsafeInlinePerformIO $
        S.unsafeWith (flatten m) $ \ptr ->
            return (f (fromIntegral (rows m)) (fromIntegral (cols m)) ptr)
73 | |||
{-# INLINE appMatrixRaw #-}
-- | Apply a function to the data pointer of the matrix's underlying storage
-- ('xdat'), without going through 'flatten'.
appMatrixRaw :: Storable a => (Ptr a -> b) -> Matrix a -> b
appMatrixRaw f m =
    unsafeInlinePerformIO $ S.unsafeWith (xdat m) $ \ptr -> return (f ptr)
77 | |||
{-# INLINE appMatrixRawLen #-}
-- | Apply a function to the number of rows, the number of columns (both as
-- 'CInt') and the data pointer of the matrix's underlying storage ('xdat'),
-- without going through 'flatten'.
appMatrixRawLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
appMatrixRawLen f m =
    unsafeInlinePerformIO $
        S.unsafeWith (xdat m) $ \ptr ->
            return (f (fromIntegral (rows m)) (fromIntegral (cols m)) ptr)
84 | |||
-- Every application helper shares the fixity of 'app' so that they can be
-- chained freely in any order. The @*Len@ variants previously had no
-- declaration and therefore defaulted to @infixl 9@, which made an
-- expression such as @f \`appVector\` v \`appMatrixLen\` m@ parse as
-- @f \`appVector\` (v \`appMatrixLen\` m)@.
infixl 1 `app`
infixl 1 `appVector`
infixl 1 `appVectorLen`
infixl 1 `appMatrix`
infixl 1 `appMatrixLen`
infixl 1 `appMatrixRaw`
infixl 1 `appMatrixRawLen`
89 | |||
{-# INLINE unsafeMatrixToVector #-}
-- | Return the underlying storage vector of the matrix as-is, disregarding
-- the order of the matrix.
-- If the order of the matrix is RowMajor, this function is identical to 'flatten'.
unsafeMatrixToVector :: Matrix a -> Vector a
unsafeMatrixToVector m = xdat m
95 | |||
{-# INLINE unsafeMatrixToForeignPtr #-}
-- | Expose the underlying storage of the matrix ('xdat') as a 'ForeignPtr'
-- together with its length in elements, disregarding the order of the matrix.
unsafeMatrixToForeignPtr :: Storable a => Matrix a -> (ForeignPtr a, Int)
unsafeMatrixToForeignPtr = S.unsafeToForeignPtr0 . xdat
99 | |||
diff --git a/packages/base/src/Data/Packed/Internal.hs b/packages/base/src/Data/Packed/Internal.hs new file mode 100644 index 0000000..537e51e --- /dev/null +++ b/packages/base/src/Data/Packed/Internal.hs | |||
@@ -0,0 +1,26 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Data.Packed.Internal | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable | ||
10 | -- | ||
11 | -- Reexports all internal modules | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | -- #hide | ||
15 | |||
16 | module Data.Packed.Internal ( | ||
17 | module Data.Packed.Internal.Common, | ||
18 | module Data.Packed.Internal.Signatures, | ||
19 | module Data.Packed.Internal.Vector, | ||
20 | module Data.Packed.Internal.Matrix, | ||
21 | ) where | ||
22 | |||
23 | import Data.Packed.Internal.Common | ||
24 | import Data.Packed.Internal.Signatures | ||
25 | import Data.Packed.Internal.Vector | ||
26 | import Data.Packed.Internal.Matrix | ||
diff --git a/packages/base/src/Data/Packed/Internal/Common.hs b/packages/base/src/Data/Packed/Internal/Common.hs new file mode 100644 index 0000000..615bbdf --- /dev/null +++ b/packages/base/src/Data/Packed/Internal/Common.hs | |||
@@ -0,0 +1,160 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | -- | | ||
3 | -- Module : Data.Packed.Internal.Common | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : BSD3 | ||
6 | -- Maintainer : Alberto Ruiz | ||
7 | -- Stability : provisional | ||
8 | -- | ||
9 | -- | ||
10 | -- Development utilities. | ||
11 | -- | ||
12 | |||
13 | |||
14 | module Data.Packed.Internal.Common( | ||
15 | Adapt, | ||
16 | app1, app2, app3, app4, | ||
17 | app5, app6, app7, app8, app9, app10, | ||
18 | (//), check, mbCatch, | ||
19 | splitEvery, common, compatdim, | ||
20 | fi, | ||
21 | table, | ||
22 | finit | ||
23 | ) where | ||
24 | |||
25 | import Control.Monad(when) | ||
26 | import Foreign.C.Types | ||
27 | import Foreign.Storable.Complex() | ||
28 | import Data.List(transpose,intersperse) | ||
29 | import Control.Exception as E | ||
30 | |||
-- | Splits a list into consecutive chunks of @k@ elements; the last chunk
--   may be shorter.
--
--   @splitEvery 3 [1..9] == [[1,2,3],[4,5,6],[7,8,9]]@
--
--   An empty input yields @[]@ for any @k@ (callers rely on this for
--   matrices with zero columns).  A non-positive @k@ with a non-empty
--   input is rejected with an 'error': it could never make progress and
--   previously produced an endless list of empty chunks.
splitEvery :: Int -> [a] -> [[a]]
splitEvery _ [] = []
splitEvery k l
    | k <= 0    = error ("splitEvery: chunk size must be positive, given: " ++ show k)
    | otherwise = chunk : splitEvery k rest
  where
    (chunk, rest) = splitAt k l
35 | |||
-- | Applies @f@ to every element and returns the resulting value when all
--   adjacent results are equal; 'Nothing' for an empty list or as soon as
--   two neighbours differ.
common :: (Eq a) => (b->a) -> [b] -> Maybe a
common f = agree . map f
  where
    agree [] = Nothing
    agree vs@(_ : rest)
        -- compare each value with its successor, exactly like a pairwise walk
        | and (zipWith (==) vs rest) = Just (last vs)
        | otherwise                  = Nothing
43 | |||
-- | Common value of a list of dimensions, where a dimension of 1 is
--   \"adaptable\" and agrees with any other value.  Returns 'Nothing' for
--   an empty list or when two dimensions other than 1 disagree.
compatdim :: [Int] -> Maybe Int
compatdim []       = Nothing
compatdim (d : ds) = go d ds
  where
    -- 'acc' is the dimension agreed on so far
    go acc []       = Just acc
    go acc (n : ns)
        | acc == n || acc == 1 = go n ns
        | n == 1               = go acc ns
        | otherwise            = Nothing
53 | |||
-- | Formatting tool: renders a list of rows as text, one row per line.
--   Every cell is left-padded with spaces to the width of the widest cell
--   in its column, and the cells of a row are joined with @sep@.
table :: String -> [[String]] -> String
table sep as = unlines (map joinCells (transpose padded))
  where
    columns             = transpose as
    widths              = [ maximum (map length col) | col <- columns ]
    padded              = zipWith (\w col -> map (rightAlign w) col) widths columns
    rightAlign w cell   = replicate (w - length cell) ' ' ++ cell
    joinCells           = concat . intersperse sep
62 | |||
-- | Postfix function application: @x \/\/ f@ is @f x@.  Left-associative
--   with the lowest precedence, so calls can be chained left to right.
infixl 0 //
(//) :: x -> (x -> y) -> y
x // f = f x
67 | |||
-- | 'fromIntegral' specialised to the 'Int' to 'CInt' conversion used when
--   passing sizes to C.
fi :: Int -> CInt
fi n = fromIntegral n
71 | |||
-- Nesting combinators for adapters.  @wwN@ takes N (adapter, object) pairs
-- and a continuation @f@; it runs the first adapter on its object and, in
-- the callback it receives, nests the remaining adapters around @f@.
-- They are the building blocks of 'app2' .. 'app10'.
ww2 w1 o1 w2 o2 f =
    w1 o1 (\a -> w2 o2 (f a))
ww3 w1 o1 w2 o2 w3 o3 f =
    w1 o1 (\a -> ww2 w2 o2 w3 o3 (f a))
ww4 w1 o1 w2 o2 w3 o3 w4 o4 f =
    w1 o1 (\a -> ww3 w2 o2 w3 o3 w4 o4 (f a))
ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 f =
    w1 o1 (\a -> ww4 w2 o2 w3 o3 w4 o4 w5 o5 (f a))
ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 f =
    w1 o1 (\a -> ww5 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 (f a))
ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 f =
    w1 o1 (\a -> ww6 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 (f a))
ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 f =
    w1 o1 (\a -> ww7 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 (f a))
ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 f =
    w1 o1 (\a -> ww8 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 (f a))
ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 f =
    w1 o1 (\a -> ww9 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 (f a))
82 | |||
-- | An adapter for one argument: given an object of type @t@ and a callback,
--   it runs the callback with a function that turns an @f@ into an @r@.
--   ('Data.Packed.Internal.Matrix.mat' has this shape: it supplies the
--   rows, columns and data pointer of a matrix to the function it is given.)
type Adapt f t r = t -> ((f -> r) -> IO()) -> IO()

-- @AdaptN@ is the type of 'appN' after the foreign function has been
-- supplied: N (adapter, object) pairs followed by a 'String' label, giving
-- an @IO ()@ action.  The last adapter must produce an @IO CInt@.
type Adapt1 f t1 = Adapt f t1 (IO CInt) -> t1 -> String -> IO()
type Adapt2 f t1 r1 t2 = Adapt f t1 r1 -> t1 -> Adapt1 r1 t2
type Adapt3 f t1 r1 t2 r2 t3 = Adapt f t1 r1 -> t1 -> Adapt2 r1 t2 r2 t3
type Adapt4 f t1 r1 t2 r2 t3 r3 t4 = Adapt f t1 r1 -> t1 -> Adapt3 r1 t2 r2 t3 r3 t4
type Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 = Adapt f t1 r1 -> t1 -> Adapt4 r1 t2 r2 t3 r3 t4 r4 t5
type Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 = Adapt f t1 r1 -> t1 -> Adapt5 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
type Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 = Adapt f t1 r1 -> t1 -> Adapt6 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
type Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 = Adapt f t1 r1 -> t1 -> Adapt7 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
type Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 = Adapt f t1 r1 -> t1 -> Adapt8 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
type Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 = Adapt f t1 r1 -> t1 -> Adapt9 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10
95 | |||
-- @appN f w1 o1 .. wN oN s@ feeds the N objects through their adapters to
-- the function @f@ (first object first), runs the resulting @IO CInt@ and
-- hands the returned code to 'check', which uses @s@ as the error label.
app1 :: f -> Adapt1 f t1
app2 :: f -> Adapt2 f t1 r1 t2
app3 :: f -> Adapt3 f t1 r1 t2 r2 t3
app4 :: f -> Adapt4 f t1 r1 t2 r2 t3 r3 t4
app5 :: f -> Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5
app6 :: f -> Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
app7 :: f -> Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
app8 :: f -> Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
app9 :: f -> Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
app10 :: f -> Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10

app1 f w1 o1 s =
    w1 o1 $ \a1 -> (check s . a1) f
app2 f w1 o1 w2 o2 s =
    ww2 w1 o1 w2 o2 $ \a1 a2 ->
        (check s . a2 . a1) f
app3 f w1 o1 w2 o2 w3 o3 s =
    ww3 w1 o1 w2 o2 w3 o3 $ \a1 a2 a3 ->
        (check s . a3 . a2 . a1) f
app4 f w1 o1 w2 o2 w3 o3 w4 o4 s =
    ww4 w1 o1 w2 o2 w3 o3 w4 o4 $ \a1 a2 a3 a4 ->
        (check s . a4 . a3 . a2 . a1) f
app5 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 s =
    ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 $ \a1 a2 a3 a4 a5 ->
        (check s . a5 . a4 . a3 . a2 . a1) f
app6 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 s =
    ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 $ \a1 a2 a3 a4 a5 a6 ->
        (check s . a6 . a5 . a4 . a3 . a2 . a1) f
app7 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 s =
    ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 $ \a1 a2 a3 a4 a5 a6 a7 ->
        (check s . a7 . a6 . a5 . a4 . a3 . a2 . a1) f
app8 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 s =
    ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 $ \a1 a2 a3 a4 a5 a6 a7 a8 ->
        (check s . a8 . a7 . a6 . a5 . a4 . a3 . a2 . a1) f
app9 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 s =
    ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 $ \a1 a2 a3 a4 a5 a6 a7 a8 a9 ->
        (check s . a9 . a8 . a7 . a6 . a5 . a4 . a3 . a2 . a1) f
app10 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 s =
    ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 $ \a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 ->
        (check s . a10 . a9 . a8 . a7 . a6 . a5 . a4 . a3 . a2 . a1) f
125 | |||
126 | |||
127 | |||
-- GSL error codes are <= 1024
-- | Human-readable text for the error codes returned by the auxiliary C
--   functions used by the wrappers.  Codes without a dedicated message are
--   rendered as @\"code N\"@.
errorCode :: CInt -> String
errorCode code = case code of
    2000 -> "bad size"
    2001 -> "bad function code"
    2002 -> "memory problem"
    2003 -> "bad file"
    2004 -> "singular"
    2005 -> "didn't converge"
    2006 -> "the input matrix is not positive definite"
    2007 -> "not yet supported in this OS"
    n    -> "code " ++ show n
140 | |||
141 | |||
-- | clear the fpu
--
--   Binds the C symbol @asm_finit@.  Within this module it is only called
--   from 'check', and only when the @FINIT@ CPP macro is set.
foreign import ccall unsafe "asm_finit" finit :: IO ()
144 | |||
-- | Runs an action that returns a C status code and raises an 'error'
--   built from @msg@ and 'errorCode' when that code is not zero.
--   When the @FINIT@ CPP macro is set, 'finit' is called first.
check :: String -> IO CInt -> IO ()
check msg f = do
#if FINIT
    finit
#endif
    status <- f
    when (status /= 0) . error $ msg ++ ": " ++ errorCode status
154 | |||
-- | Runs an action and converts failure into 'Nothing': the result is
--   wrapped in 'Just', and any exception ('SomeException') raised while
--   running the action is caught and replaced by 'Nothing'.
mbCatch :: IO x -> IO (Maybe x)
mbCatch act = E.handle swallow (fmap Just act)
  where
    swallow :: SomeException -> IO (Maybe x)
    swallow _ = return Nothing
160 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs new file mode 100644 index 0000000..9b831cc --- /dev/null +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -0,0 +1,422 @@ | |||
1 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE BangPatterns #-} | ||
5 | |||
6 | -- | | ||
7 | -- Module : Data.Packed.Internal.Matrix | ||
8 | -- Copyright : (c) Alberto Ruiz 2007 | ||
9 | -- License : BSD3 | ||
10 | -- Maintainer : Alberto Ruiz | ||
11 | -- Stability : provisional | ||
12 | -- | ||
13 | -- Internal matrix representation | ||
14 | -- | ||
15 | |||
16 | module Data.Packed.Internal.Matrix( | ||
17 | Matrix(..), rows, cols, cdat, fdat, | ||
18 | MatrixOrder(..), orderOf, | ||
19 | createMatrix, mat, | ||
20 | cmat, fmat, | ||
21 | toLists, flatten, reshape, | ||
22 | Element(..), | ||
23 | trans, | ||
24 | fromRows, toRows, fromColumns, toColumns, | ||
25 | matrixFromVector, | ||
26 | subMatrix, | ||
27 | liftMatrix, liftMatrix2, | ||
28 | (@@>), atM', | ||
29 | singleton, | ||
30 | emptyM, | ||
31 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | ||
32 | ) where | ||
33 | |||
34 | import Data.Packed.Internal.Common | ||
35 | import Data.Packed.Internal.Signatures | ||
36 | import Data.Packed.Internal.Vector | ||
37 | |||
38 | import Foreign.Marshal.Alloc(alloca, free) | ||
39 | import Foreign.Marshal.Array(newArray) | ||
40 | import Foreign.Ptr(Ptr, castPtr) | ||
41 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf) | ||
42 | import Data.Complex(Complex) | ||
43 | import Foreign.C.Types | ||
44 | import System.IO.Unsafe(unsafePerformIO) | ||
45 | import Control.DeepSeq | ||
46 | |||
47 | ----------------------------------------------------------------- | ||
48 | |||
49 | {- Design considerations for the Matrix Type | ||
50 | ----------------------------------------- | ||
51 | |||
52 | - we must easily handle both row major and column major order, | ||
53 | for bindings to LAPACK and GSL/C | ||
54 | |||
55 | - we'd like to simplify redundant matrix transposes: | ||
56 | - Some of them arise from the order requirements of some functions | ||
57 | - some functions (matrix product) admit transposed arguments | ||
58 | |||
59 | - maybe we don't really need this kind of simplification: | ||
60 | - more complex code | ||
61 | - some computational overhead | ||
62 | - only appreciable gain in code with a lot of redundant transpositions | ||
63 | and cheap matrix computations | ||
64 | |||
65 | - we could carry both the matrix and its (lazily computed) transpose. | ||
66 | This may save some transpositions, but it is necessary to keep track of the | ||
67 | data which is actually computed to be used by functions like the matrix product | ||
68 | which admit both orders. | ||
69 | |||
70 | - but if we need the transposed data and it is not in the structure, we must make | ||
71 | sure that we touch the same foreignptr that is used in the computation. | ||
72 | |||
73 | - a reasonable solution is using two constructors for a matrix. Transposition just | ||
74 | "flips" the constructor. Actual data transposition is not done if followed by a | ||
75 | matrix product or another transpose. | ||
76 | |||
77 | -} | ||
78 | |||
-- | Storage order of the elements of a matrix.
data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)

-- | The opposite storage order.
transOrder :: MatrixOrder -> MatrixOrder
transOrder o = case o of
    RowMajor    -> ColumnMajor
    ColumnMajor -> RowMajor
{- | Matrix representation suitable for GSL and LAPACK computations.

The elements are stored in a contiguous memory array ('xdat'); 'order'
records whether that array is laid out row by row or column by column.

-}

data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int        -- ^ number of rows
                       , icols :: {-# UNPACK #-} !Int        -- ^ number of columns
                       , xdat :: {-# UNPACK #-} !(Vector t)  -- ^ the elements, laid out according to 'order'
                       , order :: !MatrixOrder }             -- ^ storage order of 'xdat'
-- RowMajor: preferred by C, fdat may require a transposition
-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
-- NOTE(review): in this module 'cdat' and 'fdat' are plain aliases of 'xdat'
-- and never transpose; 'cmat' / 'fmat' perform the conversion.
95 | |||
-- | Alias of 'xdat': returns the storage vector without any reordering.
cdat :: Matrix t -> Vector t
cdat m = xdat m

-- | Alias of 'xdat': returns the storage vector without any reordering.
fdat :: Matrix t -> Vector t
fdat m = xdat m
98 | |||
-- | Number of rows of a matrix.
rows :: Matrix t -> Int
rows m = irows m
101 | |||
-- | Number of columns of a matrix.
cols :: Matrix t -> Int
cols m = icols m
104 | |||
-- | Storage order of a matrix.
orderOf :: Matrix t -> MatrixOrder
orderOf m = order m
107 | |||
108 | |||
-- | Matrix transpose.  No data is moved: the row and column counts are
--   swapped and the storage order is flipped, while the element vector is
--   shared with the argument.
trans :: Matrix t -> Matrix t
trans m = m { irows = icols m
            , icols = irows m
            , order = transOrder (order m)
            }
112 | |||
-- | Returns the matrix in 'RowMajor' storage.  A matrix that is already
--   row-major is returned unchanged; otherwise its data is reordered with
--   'transdata'.
cmat :: (Element t) => Matrix t -> Matrix t
cmat m = case order m of
    RowMajor    -> m
    ColumnMajor -> m { xdat  = transdata (irows m) (xdat m) (icols m)
                     , order = RowMajor
                     }
116 | |||
-- | Returns the matrix in 'ColumnMajor' storage.  A matrix that is already
--   column-major is returned unchanged; otherwise its data is reordered
--   with 'transdata'.
fmat :: (Element t) => Matrix t -> Matrix t
fmat m = case order m of
    ColumnMajor -> m
    RowMajor    -> m { xdat  = transdata (icols m) (xdat m) (irows m)
                     , order = ColumnMajor
                     }
120 | |||
-- C-Haskell matrix adapter
-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r

-- | Runs @f@ with a helper that applies a C-style function to the number
--   of rows, the number of columns and the data pointer of the matrix.
--   The pointer comes from 'unsafeWith' on the storage vector and is passed
--   as stored: no conversion of the storage order takes place here.
mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
mat a f = unsafeWith (xdat a) withPtr
  where
    withPtr p = f (\g -> g (fi (rows a)) (fi (cols a)) p)
130 | |||
{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.

>>> flatten (ident 3)
fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]

-}
flatten :: Element t => Matrix t -> Vector t
flatten m = xdat (cmat m)
139 | |||
140 | {- | ||
141 | type Mt t s = Int -> Int -> Ptr t -> s | ||
142 | |||
143 | infixr 6 ::> | ||
144 | type t ::> s = Mt t s | ||
145 | -} | ||
146 | |||
-- | the inverse of 'Data.Packed.Matrix.fromLists': the rows of the matrix
--   as a list of lists.
toLists :: (Element t) => Matrix t -> [[t]]
toLists m = splitEvery (cols m) elems
  where
    elems = toList (flatten m)
150 | |||
-- | Create a matrix from a list of vectors, used as rows.
--   All vectors must have the same dimension, or dimension 1, in which
--   case they are automatically expanded to the common dimension.
--
--   An empty list gives the 0x0 matrix, and a list of empty vectors gives
--   a matrix with that many rows and no columns.  Vectors of incompatible
--   sizes raise an 'error'.
fromRows :: Element t => [Vector t] -> Matrix t
fromRows [] = emptyM 0 0
fromRows vs = case compatdim (map dim vs) of
    Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
    Just 0  -> emptyM r 0
    Just c  -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
  where
    r = length vs
    -- 'adapt' is only reached with c /= 0 (the 'Just 0' case is handled
    -- above), so no special case for empty rows is needed here.
    adapt c v
        | dim v == c = v
        | otherwise  = constantD (v@>0) c
166 | |||
-- | extracts the rows of a matrix as a list of vectors.  Each row is a
--   'subVector' of the flattened (row-major) data; a matrix without
--   columns yields one empty vector per row.
toRows :: Element t => Matrix t -> [Vector t]
toRows m
    | nc == 0   = replicate nr (fromList [])
    | otherwise = [ subVector start nc flat | start <- [0, nc .. nr * nc - 1] ]
  where
    flat = flatten m
    nr   = rows m
    nc   = cols m
178 | |||
-- | Creates a matrix from a list of vectors, as columns: the vectors are
--   assembled as rows and the result is transposed.
fromColumns :: Element t => [Vector t] -> Matrix t
fromColumns = trans . fromRows
182 | |||
-- | Creates a list of vectors from the columns of a matrix: the rows of
--   its transpose.
toColumns :: Element t => Matrix t -> [Vector t]
toColumns = toRows . trans
186 | |||
-- | Reads a matrix position.  When the 'safe' flag is set the indices are
--   range-checked first and an 'error' is raised if they fall outside the
--   matrix; otherwise the element is read without any check.
(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
infixl 9 @@>
m @@> (i,j)
    | safe && outside = error "matrix indexing out of range"
    | otherwise       = atM' m i j
  where
    outside = i < 0 || i >= irows m || j < 0 || j >= icols m
{-# INLINE (@@>) #-}
196 | |||
-- Unsafe matrix access without range checking.
-- The flat offset of element (i,j) depends on the storage order:
-- i*cols+j for row-major data, j*rows+i for column-major data.
atM' m i j = case order m of
    RowMajor    -> xdat m `at'` (i * icols m + j)
    ColumnMajor -> xdat m `at'` (j * irows m + i)
{-# INLINE atM' #-}
201 | |||
202 | ------------------------------------------------------------------ | ||
203 | |||
-- Wraps a vector as an r x c matrix with the given storage order.
-- The vector is used as-is; an 'error' is raised unless r*c equals its
-- dimension.
matrixFromVector o r c v =
    if dim v == r * c
        then wrapped
        else error $ "can't reshape vector dim = " ++ show (dim v) ++ " to matrix " ++ shSize wrapped
  where
    wrapped = Matrix r c v o
209 | |||
-- allocates memory for a new matrix: a vector of r*c elements obtained
-- from 'createVector', wrapped with the requested order and shape
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix ord r c = fmap (matrixFromVector ord r c) (createVector (r * c))
215 | |||
{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
where r is the desired number of rows.)

>>> reshape 4 (fromList [1..12])
(3><4)
 [ 1.0,  2.0,  3.0,  4.0
 , 5.0,  6.0,  7.0,  8.0
 , 9.0, 10.0, 11.0, 12.0 ]

-}
reshape :: Storable t => Int -> Vector t -> Matrix t
reshape c v = matrixFromVector RowMajor nrows c v
  where
    -- zero columns means a 0x0 matrix; this also avoids dividing by zero
    nrows | c == 0    = 0
          | otherwise = dim v `div` c
229 | |||
-- A 1x1 row-major matrix holding the given element.
singleton x = matrixFromVector RowMajor 1 1 (fromList [x])
231 | |||
-- | application of a vector function on the flattened matrix elements.
--   The function is applied to the storage vector as stored, and the
--   result is wrapped with the same shape and storage order.
liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
liftMatrix f m = matrixFromVector (order m) (irows m) (icols m) (f (xdat m))
235 | |||
-- | application of a vector function on the flattened matrices elements.
--   The second matrix is first brought to the storage order of the first,
--   and the result keeps the shape and order of the first.  Matrices with
--   different shapes raise an 'error'.
liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
liftMatrix2 f m1 m2
    | compat m1 m2 = matrixFromVector ord (rows m1) (cols m1) (f (xdat m1) (aligned m2))
    | otherwise    = error "nonconformant matrices in liftMatrix2"
  where
    ord = orderOf m1
    -- data of the argument in the storage order of m1
    aligned m = case ord of
        RowMajor    -> flatten m
        ColumnMajor -> xdat (fmat m)
243 | |||
244 | |||
-- | True when both matrices have the same number of rows and columns.
compat :: Matrix a -> Matrix b -> Bool
compat m1 m2 = (rows m1, cols m1) == (rows m2, cols m2)
247 | |||
248 | ------------------------------------------------------------------ | ||
249 | |||
{- | Supported matrix elements.

    This class provides optimized internal
    operations for selected element types.
    It provides unoptimised defaults for any 'Storable' type,
    so you can create instances simply as:
    @instance Element Foo@.
-}
class (Storable a) => Element a where
    -- Submatrix extraction; 'subMatrix' validates the bounds and then
    -- delegates to this method.  Defaults to the generic Haskell loop.
    subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
               -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
               -> Matrix a -> Matrix a
    subMatrixD = subMatrix'
    -- Reorders a storage vector between row-major and column-major layout
    -- (used by 'cmat' and 'fmat').  The default goes through the
    -- size-generic C routine bound as 'ctransP'.
    transdata :: Int -> Vector a -> Int -> Vector a
    transdata = transdataP -- transdata'
    -- Presumably a vector of the given length filled with the given value
    -- (this is how 'conformVTo' uses it); the default goes through the
    -- size-generic C routine bound as 'cconstantP'.
    constantD :: a -> Int -> Vector a
    constantD = constantP -- constant'
267 | |||
268 | |||
-- The instances below override 'transdata' and 'constantD' with the
-- type-specific C routines; 'subMatrixD' keeps its default in all of them.

-- Single-precision real elements.
instance Element Float where
    transdata = transdataAux ctransF
    constantD = constantAux cconstantF

-- Double-precision real elements.
instance Element Double where
    transdata = transdataAux ctransR
    constantD = constantAux cconstantR

-- Single-precision complex elements.
instance Element (Complex Float) where
    transdata = transdataAux ctransQ
    constantD = constantAux cconstantQ

-- Double-precision complex elements.
instance Element (Complex Double) where
    transdata = transdataAux ctransC
    constantD = constantAux cconstantC
284 | |||
285 | ------------------------------------------------------------------- | ||
286 | |||
-- Layout conversion through a type-specific C routine @fun@.
-- The input is returned unchanged when it is empty or when either of the
-- derived row count and the given column count c1 is 1; otherwise a fresh
-- vector of the same size is allocated and filled by @fun@, whose status
-- code goes through 'check'.
transdataAux fun c1 d c2
    | skip      = d
    | otherwise = unsafePerformIO $ do
        out <- createVector (dim d)
        unsafeWith d $ \src ->
            unsafeWith out $ \dst ->
                check "transdataAux" (fun (fi r1) (fi c1) src (fi r2) (fi c2) dst)
        return out
  where
    r1   = dim d `div` c1
    r2   = dim d `div` c2
    skip = dim d == 0 || r1 == 1 || c1 == 1
299 | |||
-- | Layout conversion for any 'Storable' element type through the
--   size-generic C routine 'ctransP', which is also given the element size
--   in bytes.  The input is returned unchanged when it is empty or when
--   either of the derived row count and the column count @c1@ is 1.
transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
transdataP c1 d c2
    | skip      = d
    | otherwise = unsafePerformIO $ do
        out <- createVector (dim d)
        unsafeWith d $ \src ->
            unsafeWith out $ \dst ->
                check "transdataP"
                      (ctransP (fi r1) (fi c1) (castPtr src) (fi elemSize)
                               (fi r2) (fi c2) (castPtr dst) (fi elemSize))
        return out
  where
    r1       = dim d `div` c1
    r2       = dim d `div` c2
    -- only evaluated on the non-empty path, so element 0 exists
    elemSize = sizeOf (d @> 0)
    skip     = dim d == 0 || r1 == 1 || c1 == 1
314 | |||
-- C routines used for layout conversion, one per optimized element type
-- (see the 'Element' instances), plus a generic one.  Each takes a source
-- and a destination matrix as (rows, cols, pointer) and returns a status
-- code that callers pass to 'check'.
foreign import ccall unsafe "transF" ctransF :: TFMFM
foreign import ccall unsafe "transR" ctransR :: TMM
foreign import ccall unsafe "transQ" ctransQ :: TQMQM
foreign import ccall unsafe "transC" ctransC :: TCMCM
-- Generic variant on untyped pointers; each matrix argument is followed by
-- an extra CInt, for which 'transdataP' passes the element size in bytes.
foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
320 | |||
321 | ---------------------------------------------------------------------- | ||
322 | |||
-- Fills a freshly created vector of dimension n through the type-specific
-- C routine @fun@, which receives a pointer to the value x followed by the
-- vector (via the 'vec' adapter); its status code is checked by 'app1'.
--
-- The value is passed through a temporary 'alloca' buffer, as in
-- 'constantP'.  Unlike a 'newArray' / 'free' pair, the buffer is released
-- even when the checked call raises an error.
constantAux fun x n = unsafePerformIO $ do
    v <- createVector n
    alloca $ \px -> do
        poke px x
        app1 (fun px) vec v "constantAux"
    return v
329 | |||
-- C routines used by 'constantAux', one per optimized element type.  Each
-- takes a pointer to the value followed by a vector of that element type
-- and returns a status code.
foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF

foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV

foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV

foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
337 | |||
-- | Generic counterpart of 'constantAux' for any 'Storable' type: creates
--   a vector of dimension @n@ and lets the C routine 'cconstantP' fill it,
--   passing a pointer to the value, the dimension, the destination pointer
--   and the element size in bytes.  The status code goes through 'check'.
constantP :: Storable a => a -> Int -> Vector a
constantP a n = unsafePerformIO $ do
    v <- createVector n
    unsafeWith v $ \dst ->
        alloca $ \src -> do
            poke src a
            check "constantP"
                  (cconstantP (castPtr src) (fi n) (castPtr dst) (fi (sizeOf a)))
    return v

foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
348 | |||
349 | ---------------------------------------------------------------------- | ||
350 | |||
-- | Extracts a submatrix from a matrix.  The requested block must lie
--   entirely inside the matrix (non-negative start and size, not reaching
--   past the last row or column); otherwise an 'error' is raised.
subMatrix :: Element a
          => (Int,Int) -- ^ (r0,c0) starting position
          -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
          -> Matrix a -- ^ input matrix
          -> Matrix a -- ^ result
subMatrix (r0,c0) (rt,ct) m
    | rowsOk && colsOk = subMatrixD (r0,c0) (rt,ct) m
    | otherwise        = error $ "wrong subMatrix " ++
                                 show ((r0,c0),(rt,ct)) ++ " of " ++ show (rows m) ++ "x" ++ show (cols m)
  where
    rowsOk = 0 <= r0 && 0 <= rt && r0 + rt <= rows m
    colsOk = 0 <= c0 && 0 <= ct && c0 + ct <= cols m
362 | |||
-- Copies an rt x ct block starting at (r0,c0) out of row-major data @v@
-- whose rows are @c@ elements long, into a new vector of rt*ct elements.
-- The copy walks the block backwards, from its last element to its first.
subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
    w <- createVector (rt*ct)
    unsafeWith v $ \src ->
        unsafeWith w $ \dst -> do
            let copy !i !j
                    | i == (-1) = return ()               -- all rows done
                    | j == (-1) = copy (i-1) (ct-1)       -- row finished, move up
                    | otherwise = do
                        x <- peekElemOff src ((i+r0)*c + j + c0)
                        pokeElemOff dst (i*ct + j) x
                        copy i (j-1)
            copy (rt-1) (ct-1)
    return w
374 | |||
-- Generic submatrix extraction (default for 'subMatrixD').  Row-major data
-- is copied directly; a column-major matrix is handled by extracting the
-- mirrored block from its transpose and transposing the result back.
subMatrix' (r0,c0) (rt,ct) m = case order m of
    RowMajor    -> Matrix { irows = rt
                          , icols = ct
                          , xdat  = subMatrix'' (r0,c0) (rt,ct) (icols m) (xdat m)
                          , order = RowMajor
                          }
    ColumnMajor -> trans (subMatrix' (c0,r0) (ct,rt) (trans m))
377 | |||
378 | -------------------------------------------------------------------------- | ||
379 | |||
-- | Largest value of a list of sizes, except that the result is 0 as soon
--   as the smallest value is 0 (an empty dimension wins over any other).
--   An empty list also gives 0 instead of failing inside 'minimum'.
maxZ :: (Num a, Ord a) => [a] -> a
maxZ [] = 0
maxZ xs
    | minimum xs == 0 = 0
    | otherwise       = maximum xs
381 | |||
-- Expands every matrix of the list to a common shape: the 'maxZ' of the
-- row counts by the 'maxZ' of the column counts (see 'conformMTo').
conformMs ms = [ conformMTo target m | m <- ms ]
  where
    target = (maxZ (map rows ms), maxZ (map cols ms))
386 | |||
387 | |||
-- Expands every vector of the list to a common dimension: the 'maxZ' of
-- their dimensions (see 'conformVTo').
conformVs vs = [ conformVTo target v | v <- vs ]
  where
    target = maxZ (map dim vs)
391 | |||
-- Expands a matrix to shape (r,c).  A matrix of that shape is returned
-- as-is, a 1x1 matrix is replicated everywhere, a single column is
-- repeated c times and a single row r times.  Any other shape is an error.
conformMTo (r,c) m
    | shape == (r,c) = m
    | shape == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
    | shape == (r,1) = repCols c m
    | shape == (1,c) = repRows r m
    | otherwise      = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><" ++ show c ++ ")"
  where
    shape = size m
398 | |||
-- Expands a vector to dimension n.  A vector of that dimension is returned
-- as-is and a singleton is turned into a constant vector; any other
-- dimension is an error.
conformVTo n v = case dim v of
    d | d == n    -> v
      | d == 1    -> constantD (v@>0) n
      | otherwise -> error $ "vector of dim=" ++ show d ++ " cannot be expanded to dim=" ++ show n
403 | |||
-- Matrix with n rows, each being the flattened argument.
repRows n = fromRows . replicate n . flatten

-- Matrix with n columns, each being the flattened argument.
repCols n = fromColumns . replicate n . flatten
406 | |||
-- Shape of a matrix as a (rows, columns) pair.
size m = (irows m, icols m)
408 | |||
-- Shape of a matrix rendered as text, e.g. "(3><4)".
shSize m = concat ["(", show (rows m), "><", show (cols m), ")"]
410 | |||
-- Row-major r x c matrix over an empty vector; 'matrixFromVector' rejects
-- the call unless r*c is 0.
emptyM r c = matrixFromVector RowMajor r c noElems
  where
    noElems = fromList []
412 | |||
413 | ---------------------------------------------------------------------- | ||
414 | |||
-- Forcing a matrix evaluates the first element of its storage vector
-- (when there is one); an empty matrix needs no further work.
instance (Storable t, NFData t) => NFData (Matrix t) where
    rnf m = if dim elems > 0
                then rnf (elems @> 0)
                else ()
      where
        elems = xdat m
422 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs new file mode 100644 index 0000000..acc3070 --- /dev/null +++ b/packages/base/src/Data/Packed/Internal/Signatures.hs | |||
@@ -0,0 +1,70 @@ | |||
1 | -- | | ||
2 | -- Module : Data.Packed.Internal.Signatures | ||
3 | -- Copyright : (c) Alberto Ruiz 2009 | ||
4 | -- License : BSD3 | ||
5 | -- Maintainer : Alberto Ruiz | ||
6 | -- Stability : provisional | ||
7 | -- | ||
8 | -- Signatures of the C functions. | ||
9 | -- | ||
10 | |||
11 | |||
12 | module Data.Packed.Internal.Signatures where | ||
13 | |||
14 | import Foreign.Ptr(Ptr) | ||
15 | import Data.Complex(Complex) | ||
16 | import Foreign.C.Types(CInt) | ||
17 | |||
-- Naming scheme, as can be read off the definitions below:
--
--   * @P?@ is a pointer to elements: F = Float, D = Double,
--     Q = Complex Float, C = Complex Double.
--   * @T...@ is the type of a C function returning @IO CInt@.  The letters
--     after T list its arguments in order:
--       F / V / QV / CV : a vector, passed as one CInt followed by a
--                         pointer (Float / Double / Complex Float /
--                         Complex Double elements);
--       FM / M / QM / CM: a matrix, passed as two CInts followed by a
--                         pointer (same element types).
--     For example @TMM@ takes two Double matrices.
type PF = Ptr Float --
type PD = Ptr Double --
type PQ = Ptr (Complex Float) --
type PC = Ptr (Complex Double) --
type TF = CInt -> PF -> IO CInt --
type TFF = CInt -> PF -> TF --
type TFV = CInt -> PF -> TV --
type TVF = CInt -> PD -> TF --
type TFFF = CInt -> PF -> TFF --
type TV = CInt -> PD -> IO CInt --
type TVV = CInt -> PD -> TV --
type TVVV = CInt -> PD -> TVV --
type TFM = CInt -> CInt -> PF -> IO CInt --
type TFMFM = CInt -> CInt -> PF -> TFM --
type TFMFMFM = CInt -> CInt -> PF -> TFMFM --
type TM = CInt -> CInt -> PD -> IO CInt --
type TMM = CInt -> CInt -> PD -> TM --
type TVMM = CInt -> PD -> TMM --
type TMVMM = CInt -> CInt -> PD -> TVMM --
type TMMM = CInt -> CInt -> PD -> TMM --
type TVM = CInt -> PD -> TM --
type TVVM = CInt -> PD -> TVM --
type TMV = CInt -> CInt -> PD -> TV --
type TMMV = CInt -> CInt -> PD -> TMV --
type TMVM = CInt -> CInt -> PD -> TVM --
type TMMVM = CInt -> CInt -> PD -> TMVM --
type TCM = CInt -> CInt -> PC -> IO CInt --
type TCVCM = CInt -> PC -> TCM --
type TCMCVCM = CInt -> CInt -> PC -> TCVCM --
type TMCMCVCM = CInt -> CInt -> PD -> TCMCVCM --
type TCMCMCVCM = CInt -> CInt -> PC -> TCMCVCM --
type TCMCM = CInt -> CInt -> PC -> TCM --
type TVCM = CInt -> PD -> TCM --
type TCMVCM = CInt -> CInt -> PC -> TVCM --
type TCMCMVCM = CInt -> CInt -> PC -> TCMVCM --
type TCMCMCM = CInt -> CInt -> PC -> TCMCM --
type TCV = CInt -> PC -> IO CInt --
type TCVCV = CInt -> PC -> TCV --
type TCVCVCV = CInt -> PC -> TCVCV --
type TCVV = CInt -> PC -> TV --
type TQV = CInt -> PQ -> IO CInt --
type TQVQV = CInt -> PQ -> TQV --
type TQVQVQV = CInt -> PQ -> TQVQV --
type TQVF = CInt -> PQ -> TF --
type TQM = CInt -> CInt -> PQ -> IO CInt --
type TQMQM = CInt -> CInt -> PQ -> TQM --
type TQMQMQM = CInt -> CInt -> PQ -> TQMQM --
type TCMCV = CInt -> CInt -> PC -> TCV --
type TVCV = CInt -> PD -> TCV --
type TCVM = CInt -> PC -> TM --
type TMCVM = CInt -> CInt -> PD -> TCVM --
type TMMCVM = CInt -> CInt -> PD -> TMCVM --
70 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Vector.hs b/packages/base/src/Data/Packed/Internal/Vector.hs new file mode 100644 index 0000000..d0bc143 --- /dev/null +++ b/packages/base/src/Data/Packed/Internal/Vector.hs | |||
@@ -0,0 +1,471 @@ | |||
1 | {-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-} | ||
2 | -- | | ||
3 | -- Module : Data.Packed.Internal.Vector | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : BSD3 | ||
6 | -- Maintainer : Alberto Ruiz | ||
7 | -- Stability : provisional | ||
8 | -- | ||
9 | -- Vector implementation | ||
10 | -- | ||
11 | -------------------------------------------------------------------------------- | ||
12 | |||
13 | module Data.Packed.Internal.Vector ( | ||
14 | Vector, dim, | ||
15 | fromList, toList, (|>), | ||
16 | vjoin, (@>), safe, at, at', subVector, takesV, | ||
17 | mapVector, mapVectorWithIndex, zipVectorWith, unzipVectorWith, | ||
18 | mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, | ||
19 | foldVector, foldVectorG, foldLoop, foldVectorWithIndex, | ||
20 | createVector, vec, | ||
21 | asComplex, asReal, float2DoubleV, double2FloatV, | ||
22 | stepF, stepD, condF, condD, | ||
23 | conjugateQ, conjugateC, | ||
24 | cloneVector, | ||
25 | unsafeToForeignPtr, | ||
26 | unsafeFromForeignPtr, | ||
27 | unsafeWith | ||
28 | ) where | ||
29 | |||
30 | import Data.Packed.Internal.Common | ||
31 | import Data.Packed.Internal.Signatures | ||
32 | import Foreign.Marshal.Array(peekArray, copyArray, advancePtr) | ||
33 | import Foreign.ForeignPtr(ForeignPtr, castForeignPtr) | ||
34 | import Foreign.Ptr(Ptr) | ||
35 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff, sizeOf) | ||
36 | import Foreign.C.Types | ||
37 | import Data.Complex | ||
38 | import Control.Monad(when) | ||
39 | import System.IO.Unsafe(unsafePerformIO) | ||
40 | |||
41 | #if __GLASGOW_HASKELL__ >= 605 | ||
42 | import GHC.ForeignPtr (mallocPlainForeignPtrBytes) | ||
43 | #else | ||
44 | import Foreign.ForeignPtr (mallocForeignPtrBytes) | ||
45 | #endif | ||
46 | |||
47 | import GHC.Base | ||
48 | #if __GLASGOW_HASKELL__ < 612 | ||
49 | import GHC.IOBase hiding (liftIO) | ||
50 | #endif | ||
51 | |||
52 | import qualified Data.Vector.Storable as Vector | ||
53 | import Data.Vector.Storable(Vector, | ||
54 | fromList, | ||
55 | unsafeToForeignPtr, | ||
56 | unsafeFromForeignPtr, | ||
57 | unsafeWith) | ||
58 | |||
59 | |||
-- | Length of a vector, i.e. how many elements it holds.
dim :: (Storable t) => Vector t -> Int
dim v = Vector.length v
63 | |||
64 | |||
-- C-Haskell vector adapter: exposes the vector to a continuation as a
-- "size then pointer" argument pair, for the duration of 'unsafeWith'.
-- vec :: Adapt (CInt -> Ptr t -> r) (Vector t) r
vec :: (Storable t) => Vector t -> (((CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
vec x f = unsafeWith x withPtr
  where
    -- feed the element count (converted with 'fi') and the raw pointer to
    -- whatever function the continuation hands us
    withPtr p = f (\g -> g (fi (dim x)) p)
{-# INLINE vec #-}
73 | |||
74 | |||
-- Allocates memory for a new vector of @n@ elements.
-- Calls 'error' when @n@ is negative.  Nothing in this function writes to
-- the freshly allocated buffer, so callers are expected to fill it.
createVector :: Storable a => Int -> IO (Vector a)
createVector n = do
    when (n < 0) $ error ("trying to createVector of negative dim: "++show n)
    -- 'undefined' is only a type proxy: doMalloc passes it to 'sizeOf' and
    -- nothing else.
    fp <- doMalloc undefined
    return $ unsafeFromForeignPtr fp 0 n
  where
    --
    -- Use the much cheaper Haskell heap allocated storage
    -- for foreign pointer space we control
    --
    doMalloc :: Storable b => b -> IO (ForeignPtr b)
    doMalloc dummy = do
#if __GLASGOW_HASKELL__ >= 605
        mallocPlainForeignPtrBytes (n * sizeOf dummy)
#else
        mallocForeignPtrBytes (n * sizeOf dummy)
#endif
93 | |||
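{- Note on 'fromList' (imported from "Data.Vector.Storable" and re-exported
   by this module): it creates a Vector from a list:

   > fromList [2,3,5,7]
   4 |> [2.0,3.0,5.0,7.0]

   (This is a plain comment rather than a Haddock block, because the
   definition that follows it is not 'fromList'.)
-}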
100 | |||
-- Runs a pointer-reading action on the vector's storage (via 'unsafeWith')
-- and returns its result as a pure value through 'inlinePerformIO'.
safeRead v act = inlinePerformIO (unsafeWith v act)
{-# INLINE safeRead #-}
103 | |||
-- Runs an IO action as a pure computation by applying its underlying state
-- function directly to 'realWorld#' and discarding the resulting state token.
-- Needs MagicHash and UnboxedTuples.  There is no safety net here: it is only
-- appropriate for actions whose result does not depend on when (or how many
-- times) they run.
inlinePerformIO :: IO a -> a
inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r
{-# INLINE inlinePerformIO #-}
107 | |||
{- | extracts the Vector elements to a list

>>> toList (linspace 5 (1,10))
[1.0,3.25,5.5,7.75,10.0]

-}
toList :: Storable a => Vector a -> [a]
toList v = safeRead v (\p -> peekArray n p)
  where
    n = dim v
116 | |||
{- | Create a vector from a list of elements and explicit dimension. Only the
     first @n@ elements of the list are inspected, so it may safely be used,
     for instance, with infinite lists. A list with fewer than @n@ elements
     is an error.

>>> 5 |> [1..]
fromList [1.0,2.0,3.0,4.0,5.0]

-}
(|>) :: (Storable a) => Int -> [a] -> Vector a
infixl 9 |>
n |> l
    | length prefix == n = fromList prefix
    | otherwise          = error "list too short for |>"
  where
    prefix = take n l
131 | |||
132 | |||
-- | access to Vector elements without range checking
at' :: Storable a => Vector a -> Int -> a
at' v n = safeRead v (\p -> peekElemOff p n)
{-# INLINE at' #-}
137 | |||
--
-- turn off bounds checking with -funsafe at configure time.
-- ghc will optimise away the always true case at compile time.
--
-- 'safe' is True unless the UNSAFE macro is defined.  The type signature is
-- kept outside the conditional so that both configurations declare it.
safe :: Bool
#if defined(UNSAFE)
safe = False
#else
safe = True
#endif
148 | |||
-- | access to Vector elements with range checking.
--   The check is skipped entirely when 'safe' is False.
at :: Storable a => Vector a -> Int -> a
at v n
    | not safe || inRange = at' v n
    | otherwise           = error "vector index out of range"
  where
    inRange = n >= 0 && n < dim v
{-# INLINE at #-}
157 | |||
{- | takes a number of consecutive elements from a Vector

>>> subVector 2 3 (fromList [1..10])
fromList [3.0,4.0,5.0]

-}
subVector :: Storable t => Int       -- ^ index of the starting element
                        -> Int       -- ^ number of elements to extract
                        -> Vector t  -- ^ source
                        -> Vector t  -- ^ result
subVector start count src = Vector.slice start count src
169 | |||
170 | |||
{- | Reads a vector position (same as 'at'):

>>> fromList [0..9] @> 7
7.0

-}
(@>) :: Storable t => Vector t -> Int -> t
infixl 9 @>
v @> i = at v i
180 | |||
181 | |||
{- | concatenate a list of vectors

>>> vjoin [fromList [1..5::Double], konst 1 3]
fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0]

-}
vjoin :: Storable t => [Vector t] -> Vector t
vjoin [] = fromList []
vjoin [v] = v                     -- a single vector is returned as is
vjoin as = unsafePerformIO $ do
    out <- createVector (sum (map dim as))
    unsafeWith out (copyAll as)
    return out
  where
    -- copy each piece to the current write position, then advance the
    -- position by the number of elements just written
    copyAll [] _ = return ()
    copyAll (piece:rest) dst = do
        let len = dim piece
        unsafeWith piece $ \src -> copyArray dst src len
        copyAll rest (advancePtr dst len)
202 | |||
203 | |||
{- | Extract consecutive subvectors of the given sizes.
     It is an error to request more elements in total than the vector has.

>>> takesV [3,4] (linspace 10 (1,10::Double))
[fromList [1.0,2.0,3.0],fromList [4.0,5.0,6.0,7.0]]

-}
takesV :: Storable t => [Int] -> Vector t -> [Vector t]
takesV ms w
    | sum ms > dim w = error $ "takesV " ++ show ms ++ " on dim = " ++ (show $ dim w)
    | otherwise      = chop ms w
  where
    -- peel one piece off the front, then continue on what is left
    chop sizes rest = case sizes of
        []     -> []
        (n:ns) -> let piece     = subVector 0 n rest
                      remainder = subVector n (dim rest - n) rest
                  in  piece : chop ns remainder
216 | |||
217 | --------------------------------------------------------------- | ||
218 | |||
-- | transforms a complex vector into a real vector with alternating real and imaginary parts
--   (the storage is reinterpreted, with offset and length doubled).
asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a
asReal v =
    let (fp, off, len) = unsafeToForeignPtr v
    in  unsafeFromForeignPtr (castForeignPtr fp) (2 * off) (2 * len)
223 | |||
-- | transforms a real vector into a complex vector with alternating real and imaginary parts
--   (the storage is reinterpreted, with offset and length halved).
--
--   NOTE(review): both halvings use integer division, so a source vector
--   with an odd offset or odd length looks like it would be shifted or
--   truncated -- confirm callers only pass even-offset, even-length vectors.
asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a)
asComplex v =
    let (fp, off, len) = unsafeToForeignPtr v
    in  unsafeFromForeignPtr (castForeignPtr fp) (off `div` 2) (len `div` 2)
228 | |||
229 | --------------------------------------------------------------- | ||
230 | |||
-- Converts a vector of Float to a vector of Double of the same length by
-- calling the foreign function @float2double@ on a freshly created result.
float2DoubleV :: Vector Float -> Vector Double
float2DoubleV v = unsafePerformIO $
    createVector (dim v) >>= \out ->
        app2 c_float2double vec v vec out "float2double" >> return out
236 | |||
-- Converts a vector of Double to a vector of Float of the same length by
-- calling the foreign function @double2float@ on a freshly created result.
-- The label handed to 'app2' now matches the foreign function's name (it
-- used to read "double2float2").
-- NOTE(review): the label is presumably what 'app2' reports when the call
-- fails -- confirm in Data.Packed.Internal.Common.
double2FloatV :: Vector Double -> Vector Float
double2FloatV v = unsafePerformIO $ do
    r <- createVector (dim v)
    app2 c_double2float vec v vec r "double2float"
    return r
242 | |||
243 | |||
-- Bindings to the C symbols @float2double@ and @double2float@, used by
-- 'float2DoubleV' and 'double2FloatV'.  Their types are the aliases TFV and
-- TVF from Data.Packed.Internal.Signatures.
foreign import ccall unsafe "float2double" c_float2double:: TFV
foreign import ccall unsafe "double2float" c_double2float:: TVF
246 | |||
247 | --------------------------------------------------------------- | ||
248 | |||
-- Applies the foreign function @stepF@ to a Float vector, producing a new
-- vector of the same length.
-- NOTE(review): presumably an element-wise step function -- confirm in the
-- C sources.
stepF :: Vector Float -> Vector Float
stepF v = unsafePerformIO (fill =<< createVector (dim v))
  where
    fill out = app2 c_stepF vec v vec out "stepF" >> return out
254 | |||
-- Applies the foreign function @stepD@ to a Double vector, producing a new
-- vector of the same length.
-- NOTE(review): presumably an element-wise step function -- confirm in the
-- C sources.
stepD :: Vector Double -> Vector Double
stepD v = unsafePerformIO (fill =<< createVector (dim v))
  where
    fill out = app2 c_stepD vec v vec out "stepD" >> return out
260 | |||
-- Bindings to the C symbols @stepF@ and @stepD@, used by 'stepF' and 'stepD'.
foreign import ccall unsafe "stepF" c_stepF :: TFF
foreign import ccall unsafe "stepD" c_stepD :: TVV
263 | |||
264 | --------------------------------------------------------------- | ||
265 | |||
-- Calls the foreign function @condF@ on five Float vectors and returns a
-- new vector whose length is that of the first argument.
-- NOTE(review): presumably selects element-wise from @l@, @e@ or @g@
-- according to how @x@ compares with @y@ -- confirm in the C sources; no
-- length check on the other four arguments happens here.
condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float
condF x y l e g = unsafePerformIO $
    createVector (dim x) >>= \out ->
        app6 c_condF vec x vec y vec l vec e vec g vec out "condF" >> return out
271 | |||
-- Calls the foreign function @condD@ on five Double vectors and returns a
-- new vector whose length is that of the first argument.
-- NOTE(review): presumably selects element-wise from @l@, @e@ or @g@
-- according to how @x@ compares with @y@ -- confirm in the C sources; no
-- length check on the other four arguments happens here.
condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double
condD x y l e g = unsafePerformIO $
    createVector (dim x) >>= \out ->
        app6 c_condD vec x vec y vec l vec e vec g vec out "condD" >> return out
277 | |||
-- Bindings to the C symbols @condF@ and @condD@.  Each takes six
-- (size, pointer) pairs: three are spelled out here and three more come from
-- the TFFF / TVVV aliases.
foreign import ccall unsafe "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF
foreign import ccall unsafe "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV
280 | |||
281 | -------------------------------------------------------------------------------- | ||
282 | |||
-- Shared driver for 'conjugateQ' and 'conjugateC': creates a result vector
-- with the same length as the input and lets the given foreign function
-- fill it.
conjugateAux fun x = unsafePerformIO $
    createVector (dim x) >>= \out ->
        app2 fun vec x vec out "conjugateAux" >> return out
287 | |||
-- Runs the foreign function @conjugateQ@ over a vector of Complex Float.
conjugateQ :: Vector (Complex Float) -> Vector (Complex Float)
conjugateQ v = conjugateAux c_conjugateQ v
foreign import ccall unsafe "conjugateQ" c_conjugateQ :: TQVQV
291 | |||
-- Runs the foreign function @conjugateC@ over a vector of Complex Double.
conjugateC :: Vector (Complex Double) -> Vector (Complex Double)
conjugateC v = conjugateAux c_conjugateC v
foreign import ccall unsafe "conjugateC" c_conjugateC :: TCVCV
295 | |||
296 | -------------------------------------------------------------------------------- | ||
297 | |||
-- Returns a freshly created vector holding a copy of the elements of the
-- argument.
cloneVector :: Storable t => Vector t -> IO (Vector t)
cloneVector v = do
    r <- createVector n
    app2 copyInto vec v vec r "cloneVector"
    return r
  where
    n = dim v
    -- shaped like a foreign call (size, source, size, destination); the
    -- sizes are ignored and 0 is returned as the result code
    copyInto _ src _ dst = copyArray dst src n >> return 0
305 | |||
306 | ------------------------------------------------------------------ | ||
307 | |||
-- | map on Vectors
--   (elements are processed from the last index down to the first)
mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b
mapVector f v = unsafePerformIO $ do
    w <- createVector (dim v)
    unsafeWith v $ \src ->
        unsafeWith w $ \dst ->
            let loop k
                    | k < 0     = return ()
                    | otherwise = do
                        x <- peekElemOff src k
                        pokeElemOff dst k (f x)
                        loop (k - 1)
            in  loop (dim v - 1)
    return w
{-# INLINE mapVector #-}
321 | |||
-- | zipWith for Vectors
--   (the result has the length of the shorter argument)
zipVectorWith :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c
zipVectorWith f u v = unsafePerformIO $ do
    let len = min (dim u) (dim v)
    w <- createVector len
    unsafeWith u $ \pu ->
        unsafeWith v $ \pv ->
            unsafeWith w $ \pw ->
                let loop k
                        | k < 0     = return ()
                        | otherwise = do
                            x <- peekElemOff pu k
                            y <- peekElemOff pv k
                            pokeElemOff pw k (f x y)
                            loop (k - 1)
                in  loop (len - 1)
    return w
{-# INLINE zipVectorWith #-}
338 | |||
-- | unzipWith for Vectors
--   (applies the function to every pair and stores the two components in
--   two separate result vectors)
unzipVectorWith :: (Storable (a,b), Storable c, Storable d)
                   => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d)
unzipVectorWith f u = unsafePerformIO $ do
    let len = dim u
    v <- createVector len
    w <- createVector len
    unsafeWith u $ \pu ->
        unsafeWith v $ \pv ->
            unsafeWith w $ \pw ->
                let loop k
                        | k < 0     = return ()
                        | otherwise = do
                            z <- peekElemOff pu k
                            let (x, y) = f z
                            pokeElemOff pv k x
                            pokeElemOff pw k y
                            loop (k - 1)
                in  loop (len - 1)
    return (v,w)
{-# INLINE unzipVectorWith #-}
358 | |||
-- Folds over the elements of a vector, visiting them from the last index
-- down to index 0.  The accumulator is forced before each element is read.
foldVector :: Storable a => (a -> b -> b) -> b -> Vector a -> b
foldVector f x v = unsafePerformIO $
    unsafeWith v $ \p ->
        let loop k acc
                | k == (-1) = return acc
                | otherwise = acc `seq` do
                    y <- peekElemOff p k
                    loop (k - 1 :: Int) (f y acc)
        in  loop (dim v - 1) x
{-# INLINE foldVector #-}
367 | |||
-- the zero-indexed index is passed to the folding function
-- (elements are visited from the last index down to index 0, and the
-- accumulator is forced before each element is read)
foldVectorWithIndex :: Storable a => (Int -> a -> b -> b) -> b -> Vector a -> b
foldVectorWithIndex f x v = unsafePerformIO $
    unsafeWith v $ \p ->
        let loop k acc
                | k == (-1) = return acc
                | otherwise = acc `seq` do
                    y <- peekElemOff p k
                    loop (k - 1 :: Int) (f k y acc)
        in  loop (dim v - 1) x
{-# INLINE foldVectorWithIndex #-}
377 | |||
-- | Folds a function over the indices @d-1, d-2 .. 0@ (in that order),
--   threading an accumulator that starts at @s0@.
--
--   A count of zero (or less) returns @s0@ untouched.  Without this guard
--   the loop would start at a negative index, never meet the @0@ base case,
--   and keep calling @f@ with ever more negative indices.
--
--   The accumulator is forced before every step except the final one at
--   index 0.
foldLoop :: (Int -> t -> t) -> t -> Int -> t
foldLoop f s0 d
    | d <= 0    = s0
    | otherwise = go (d - 1) s0
  where
    go 0 s = f (0::Int) s
    go j s = s `seq` go (j - 1) (f j s)
382 | |||
-- Generic fold driven by 'foldLoop': the folding function receives the
-- current index, an unchecked accessor for the vector ('at'' applied to @v@)
-- and the accumulator.  Index and accumulator are forced at every step.
-- NOTE(review): 'foldLoop' is called with @dim v@, so behaviour on an empty
-- vector is whatever 'foldLoop' does for a count of 0 -- verify.
foldVectorG f s0 v = foldLoop g s0 (dim v)
    where g !k !s = f k (at' v) s
          {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479)
{-# INLINE foldVectorG #-}
387 | |||
388 | ------------------------------------------------------------------- | ||
389 | |||
-- | monadic map over Vectors
--    the monad @m@ must be strict
--
--   Elements are visited from index 0 upwards; each result is written into
--   a freshly created vector of the same length.  An empty input yields an
--   empty result without running the function.
mapVectorM :: (Storable a, Storable b, Monad m) => (a -> m b) -> Vector a -> m (Vector b)
mapVectorM f v = do
    w <- return $! unsafePerformIO $! createVector (dim v)
    mapVectorM' w 0 (dim v -1)
    return w
  where
    mapVectorM' w' !k !t
        -- Empty vector: t is -1 and k starts at 0, so without this guard
        -- k == t never holds and the loop reads past the end forever.
        | k > t = return ()
        | k == t = do
            -- last element: read, transform, write, stop
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            y <- f x
            return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
        | otherwise = do
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            y <- f x
            _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
            mapVectorM' w' (k+1) t
{-# INLINE mapVectorM #-}
408 | |||
-- | monadic map over Vectors
--
--   Runs the action on every element, from index 0 upwards, discarding the
--   results.  An empty input runs nothing.
mapVectorM_ :: (Storable a, Monad m) => (a -> m ()) -> Vector a -> m ()
mapVectorM_ f v = do
    mapVectorM' 0 (dim v -1)
  where
    mapVectorM' !k !t
        -- Empty vector: t is -1 and k starts at 0, so without this guard
        -- k == t never holds and the loop reads past the end forever.
        | k > t = return ()
        | k == t = do
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            f x
        | otherwise = do
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            _ <- f x
            mapVectorM' (k+1) t
{-# INLINE mapVectorM_ #-}
422 | |||
-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
--    the monad @m@ must be strict
--
--   Elements are visited from index 0 upwards; each result is written into
--   a freshly created vector of the same length.  An empty input yields an
--   empty result without running the function.
mapVectorWithIndexM :: (Storable a, Storable b, Monad m) => (Int -> a -> m b) -> Vector a -> m (Vector b)
mapVectorWithIndexM f v = do
    w <- return $! unsafePerformIO $! createVector (dim v)
    mapVectorM' w 0 (dim v -1)
    return w
  where
    mapVectorM' w' !k !t
        -- Empty vector: t is -1 and k starts at 0, so without this guard
        -- k == t never holds and the loop reads past the end forever.
        | k > t = return ()
        | k == t = do
            -- last element: read, transform, write, stop
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            y <- f k x
            return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
        | otherwise = do
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            y <- f k x
            _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
            mapVectorM' w' (k+1) t
{-# INLINE mapVectorWithIndexM #-}
441 | |||
-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
--
--   Runs the action on every (index, element) pair, from index 0 upwards,
--   discarding the results.  An empty input runs nothing.
mapVectorWithIndexM_ :: (Storable a, Monad m) => (Int -> a -> m ()) -> Vector a -> m ()
mapVectorWithIndexM_ f v = do
    mapVectorM' 0 (dim v -1)
  where
    mapVectorM' !k !t
        -- Empty vector: t is -1 and k starts at 0, so without this guard
        -- k == t never holds and the loop reads past the end forever.
        | k > t = return ()
        | k == t = do
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            f k x
        | otherwise = do
            x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
            _ <- f k x
            mapVectorM' (k+1) t
{-# INLINE mapVectorWithIndexM_ #-}
455 | |||
456 | |||
-- Map over a vector where the function also receives the zero-based index
-- of each element (elements are processed from the last index down to 0).
mapVectorWithIndex :: (Storable a, Storable b) => (Int -> a -> b) -> Vector a -> Vector b
--mapVectorWithIndex g = head . mapVectorWithIndexM (\a b -> [g a b])
mapVectorWithIndex f v = unsafePerformIO $ do
    w <- createVector (dim v)
    unsafeWith v $ \src ->
        unsafeWith w $ \dst ->
            let loop k
                    | k < 0     = return ()
                    | otherwise = do
                        x <- peekElemOff src k
                        pokeElemOff dst k (f k x)
                        loop (k - 1)
            in  loop (dim v - 1)
    return w
{-# INLINE mapVectorWithIndex #-}
470 | |||
471 | |||
diff --git a/packages/base/src/Data/Packed/Matrix.hs b/packages/base/src/Data/Packed/Matrix.hs new file mode 100644 index 0000000..d94d167 --- /dev/null +++ b/packages/base/src/Data/Packed/Matrix.hs | |||
@@ -0,0 +1,490 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE CPP #-} | ||
6 | |||
7 | ----------------------------------------------------------------------------- | ||
8 | -- | | ||
9 | -- Module : Data.Packed.Matrix | ||
10 | -- Copyright : (c) Alberto Ruiz 2007-10 | ||
11 | -- License : GPL | ||
12 | -- | ||
13 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
14 | -- Stability : provisional | ||
15 | -- | ||
16 | -- A Matrix representation suitable for numerical computations using LAPACK and GSL. | ||
17 | -- | ||
18 | -- This module provides basic functions for manipulation of structure. | ||
19 | |||
20 | ----------------------------------------------------------------------------- | ||
21 | {-# OPTIONS_HADDOCK hide #-} | ||
22 | |||
23 | module Data.Packed.Matrix ( | ||
24 | Matrix, | ||
25 | Element, | ||
26 | rows,cols, | ||
27 | (><), | ||
28 | trans, | ||
29 | reshape, flatten, | ||
30 | fromLists, toLists, buildMatrix, | ||
31 | (@@>), | ||
32 | asRow, asColumn, | ||
33 | fromRows, toRows, fromColumns, toColumns, | ||
34 | fromBlocks, diagBlock, toBlocks, toBlocksEvery, | ||
35 | repmat, | ||
36 | flipud, fliprl, | ||
37 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, | ||
38 | extractRows, extractColumns, | ||
39 | diagRect, takeDiag, | ||
40 | mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, | ||
41 | liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D | ||
42 | ) where | ||
43 | |||
44 | import Data.Packed.Internal | ||
45 | import qualified Data.Packed.ST as ST | ||
46 | import Data.Array | ||
47 | |||
48 | import Data.List(transpose,intersperse) | ||
49 | import Foreign.Storable(Storable) | ||
50 | import Control.Monad(liftM) | ||
51 | |||
52 | ------------------------------------------------------------------- | ||
53 | |||
#ifdef BINARY

import Data.Binary
import Control.Monad(replicateM)

-- Serialisation of matrices; only compiled when the BINARY macro is defined.
instance (Binary a, Element a, Storable a) => Binary (Matrix a) where
    -- Writes the number of rows, then the number of columns, then every
    -- element in row-major order (outer loop over rows, inner over columns).
    put m = do
        let r = rows m
        let c = cols m
        put r
        put c
        mapM_ (\i -> mapM_ (\j -> put $ m @@> (i,j)) [0..(c-1)]) [0..(r-1)]
    -- Reads the two dimensions back, then r lists of c elements each, and
    -- rebuilds the matrix with 'fromLists'.
    get = do
        r <- get
        c <- get
        xs <- replicateM r $ replicateM c get
        return $ fromLists xs

#endif
73 | |||
74 | ------------------------------------------------------------------- | ||
75 | |||
-- Matrices are shown as a size header (see 'sizes') followed by the
-- elements laid out by 'dsp'; a matrix with no rows or no columns is shown
-- as the header followed by an empty list.
instance (Show a, Element a) => (Show (Matrix a)) where
    show m
        | rows m == 0 || cols m == 0 = sizes m ++" []"
        | otherwise = sizes m ++ dsp (map (map show) (toLists m))
79 | |||
-- Size header used by the Show instance: "(rows><cols)" plus a newline.
sizes m = concat ["(", show (rows m), "><", show (cols m), ")\n"]
81 | |||
-- Lays out a table of already-rendered cells (a list of rows): every column
-- is right-aligned to its widest cell, cells are separated by ", ", rows
-- after the first start with " , ", and the whole is wrapped in " [" and " ]".
-- The input must contain at least one row.
dsp as = " [" ++ init (drop 2 (unlines rowLines)) ++ " ]"
  where
    columns     = transpose as
    widths      = map (maximum . map length) columns
    padded      = zipWith (\w col -> map (padLeft w) col) widths columns
    padLeft w s = replicate (w - length s) ' ' ++ s
    rowLines    = [ " , " ++ concat (intersperse ", " cells) | cells <- transpose padded ]
89 | |||
90 | ------------------------------------------------------------------ | ||
91 | |||
-- Parses the format produced by the Show instance: "(r><c)" followed by a
-- bracketed list of elements.  The precedence argument is ignored and exactly
-- one parse is returned.  Both the dimensions and the element list go
-- through 'read', so malformed input raises an error instead of yielding an
-- empty parse list.
instance (Element a, Read a) => Read (Matrix a) where
    readsPrec _ s = [((rs><cs) . read $ listnums, rest)]
        where -- everything up to and including the first ']' is the matrix text
              (thing,rest) = breakAt ']' s
              -- split "(r><c)" from the element list at the first ')'
              (dims,listnums) = breakAt ')' thing
              -- columns: the text between '<' and ')'
              cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims
              -- rows: the text between '(' and '>'
              rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims
98 | |||
99 | |||
-- Splits a list at the first occurrence of @c@: the first component is
-- everything before it followed by @c@ itself, the second is everything
-- after it.  If @c@ does not occur, the second component fails when forced
-- ('tail' of an empty list).
breakAt c l =
    let (before, fromC) = break (== c) l
    in  (before ++ [c], tail fromC)
102 | |||
103 | ------------------------------------------------------------------ | ||
104 | |||
-- | creates a matrix from a vertical list of matrices
--   (all of them must have the same number of columns)
joinVert :: Element t => [Matrix t] -> Matrix t
joinVert [] = emptyM 0 0
joinVert ms = maybe mismatch stack (common cols ms)
  where
    mismatch  = error "(impossible) joinVert on matrices with different number of columns"
    totalRows = sum (map rows ms)
    stack c   = matrixFromVector RowMajor totalRows c (vjoin (map flatten ms))
111 | |||
-- | creates a matrix from a horizontal list of matrices
--   (implemented as a vertical join of the transposes, transposed back)
joinHoriz :: Element t => [Matrix t] -> Matrix t
joinHoriz ms = trans (joinVert (map trans ms))
115 | |||
{- | Create a matrix from blocks given as a list of lists of matrices.

Single row-column components are automatically expanded to match the
corresponding common row and column:

@
disp = putStr . dispf 2
@

>>> disp $ fromBlocks [[ident 5, 7, row[10,20]], [3, diagl[1,2,3], 0]]
8x10
1  0  0  0  0  7  7  7  10  20
0  1  0  0  0  7  7  7  10  20
0  0  1  0  0  7  7  7  10  20
0  0  0  1  0  7  7  7  10  20
0  0  0  0  1  7  7  7  10  20
3  3  3  3  3  1  0  0   0   0
3  3  3  3  3  0  2  0   0   0
3  3  3  3  3  0  0  3   0   0

-}
fromBlocks :: Element t => [[Matrix t]] -> Matrix t
fromBlocks blocks = fromBlocksRaw (adaptBlocks blocks)
139 | |||
-- Joins each row of blocks horizontally, then stacks the resulting strips
-- vertically.  No size adaptation is done here.
fromBlocksRaw mms = joinVert (map joinHoriz mms)
141 | |||
-- Expands the blocks of a rectangular list of lists of matrices so that
-- every block has the height chosen for its block-row and the width chosen
-- for its block-column.  Used by 'fromBlocks' before the actual joining.
adaptBlocks ms = ms' where
    -- number of blocks per block-row; all block-rows must agree
    bc = case common length ms of
          Just c -> c
          Nothing -> error "fromBlocks requires rectangular [[Matrix]]"
    -- one target height per block-row and one target width per block-column;
    -- 'compatdim' yields a Maybe (see the patterns in 'g')
    rs = map (compatdim . map rows) ms
    cs = map (compatdim . map cols) (transpose ms)
    -- all [height,width] combinations, rows varying slowest, i.e. in the
    -- same order as the blocks in 'concat ms'
    szs = sequence [rs,cs]
    ms' = splitEvery bc $ zipWith g szs (concat ms)

    -- adapt one block to its target size (nr,nc)
    g [Just nr,Just nc] m
        | nr == r && nc == c = m                       -- already the right size
        | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc))  -- 1x1: fill with its single value
        | r == 1 = fromRows (replicate nr (flatten m))     -- single row: repeat it nr times
        | otherwise = fromColumns (replicate nc (flatten m))  -- otherwise repeat as a column nc times
      where
        r = rows m
        c = cols m
        x = m@@>(0,0)
    -- no target size could be determined for this block
    g _ _ = error "inconsistent dimensions in fromBlocks"
161 | |||
162 | |||
163 | -------------------------------------------------------------------------------- | ||
164 | |||
{- | create a block diagonal matrix

>>> disp 2 $ diagBlock [konst 1 (2,2), konst 2 (3,5), col [5,7]]
7x8
1  1  0  0  0  0  0  0
1  1  0  0  0  0  0  0
0  0  2  2  2  2  2  0
0  0  2  2  2  2  2  0
0  0  2  2  2  2  2  0
0  0  0  0  0  0  0  5
0  0  0  0  0  0  0  7

>>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double
(2><7)
 [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0
 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ]

-}
diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t
diagBlock ms = fromBlocks [ blockRow k m | (k, m) <- zip [0 ..] ms ]
  where
    n = length ms
    -- a 1x1 zero block; 'fromBlocks' expands it to whatever size is needed
    zero = (1><1) [0]
    -- block-row k: the k-th matrix on the diagonal, zero blocks elsewhere
    blockRow k m = [ if j == k then m else zero | j <- [0 .. n - 1] ]
189 | |||
190 | -------------------------------------------------------------------------------- | ||
191 | |||
192 | |||
-- | Reverse rows
flipud :: Element t => Matrix t -> Matrix t
flipud m = extractRows (reverse [0 .. rows m - 1]) m
198 | |||
-- | Reverse columns
fliprl :: Element t => Matrix t -> Matrix t
fliprl m = extractColumns (reverse [0 .. cols m - 1]) m
204 | |||
205 | ------------------------------------------------------------ | ||
206 | |||
{- | creates a rectangular diagonal matrix: every off-diagonal position holds
     the given filler value, and the diagonal holds the vector elements (as
     many as fit).

>>> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double
(4><5)
 [ 10.0,  7.0,  7.0, 7.0, 7.0
 ,  7.0, 20.0,  7.0, 7.0, 7.0
 ,  7.0,  7.0, 30.0, 7.0, 7.0
 ,  7.0,  7.0,  7.0, 7.0, 7.0 ]

-}
diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t
diagRect z v r c = ST.runSTMatrix $ do
    m <- ST.newMatrix z r c
    -- kept local to the ST action: it closes over the mutable matrix
    let setDiag k = ST.writeMatrix m k k (v@>k)
    mapM_ setDiag [0 .. diagLen - 1]
    return m
  where
    diagLen = minimum [r, c, dim v]
223 | |||
-- | extracts the diagonal from a rectangular matrix
takeDiag :: (Element t) => Matrix t -> Vector t
takeDiag m = fromList (map diagElem [0 .. n - 1])
  where
    flat = flatten m
    nc   = cols m
    n    = min (rows m) nc
    -- element (k,k) lives at position k*cols+k of the flattened matrix
    diagElem k = flat `at` (k * nc + k)
227 | |||
228 | ------------------------------------------------------------ | ||
229 | |||
{- | An easy way to create a matrix:

>>> (2><3)[2,4,7,-3,11,0]
(2><3)
 [  2.0,  4.0, 7.0
 , -3.0, 11.0, 0.0 ]

This is the format produced by the instances of Show (Matrix a), which
can also be used for input.

The input list is explicitly truncated, so that it can
safely be used with lists that are too long (like infinite lists).
A list that is too short is an error.

>>> (2><3)[1..]
(2><3)
 [ 1.0, 2.0, 3.0
 , 4.0, 5.0, 6.0 ]

-}
(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a
(><) r c l
    | dim v == wanted = matrixFromVector RowMajor r c v
    | otherwise       = error $ "inconsistent list size = "
                                ++ show (dim v) ++ " in (" ++ show r ++ "><" ++ show c ++ ")"
  where
    wanted = r * c
    v = fromList (take wanted l)
256 | |||
257 | ---------------------------------------------------------------- | ||
258 | |||
-- | Creates a matrix with the first n rows of another matrix
takeRows :: Element t => Int -> Matrix t -> Matrix t
takeRows n mt = subMatrix topLeft (n, cols mt) mt
  where topLeft = (0, 0)

-- | Creates a copy of a matrix without the first n rows
dropRows :: Element t => Int -> Matrix t -> Matrix t
dropRows n mt = subMatrix (n, 0) (remaining, cols mt) mt
  where remaining = rows mt - n

-- | Creates a matrix with the first n columns of another matrix
takeColumns :: Element t => Int -> Matrix t -> Matrix t
takeColumns n mt = subMatrix topLeft (rows mt, n) mt
  where topLeft = (0, 0)

-- | Creates a copy of a matrix without the first n columns
dropColumns :: Element t => Int -> Matrix t -> Matrix t
dropColumns n mt = subMatrix (0, n) (rows mt, remaining) mt
  where remaining = cols mt - n
271 | |||
272 | ---------------------------------------------------------------- | ||
273 | |||
{- | Creates a 'Matrix' from a list of lists (considered as rows).

>>> fromLists [[1,2],[3,4],[5,6]]
(3><2)
 [ 1.0, 2.0
 , 3.0, 4.0
 , 5.0, 6.0 ]

-}
fromLists :: Element t => [[t]] -> Matrix t
fromLists rowLists = fromRows (map fromList rowLists)
285 | |||
-- | creates a 1-row matrix from a vector
--
-- >>> asRow (fromList [1..5])
--  (1><5)
--   [ 1.0, 2.0, 3.0, 4.0, 5.0 ]
--
asRow :: Storable a => Vector a -> Matrix a
asRow v = reshape width v
  where
    width = dim v
294 | |||
-- | creates a 1-column matrix from a vector
--
-- >>> asColumn (fromList [1..5])
-- (5><1)
--  [ 1.0
--  , 2.0
--  , 3.0
--  , 4.0
--  , 5.0 ]
--
asColumn :: Storable a => Vector a -> Matrix a
asColumn v = trans (asRow v)
307 | |||
308 | |||
309 | |||
{- | creates a Matrix of the specified size using the supplied function
     to map the row\/column position to the value at that row\/column position.

@> buildMatrix 3 4 (\\(r,c) -> fromIntegral r * fromIntegral c)
(3><4)
 [ 0.0, 0.0, 0.0, 0.0
 , 0.0, 1.0, 2.0, 3.0
 , 0.0, 2.0, 4.0, 6.0 ]@

Hilbert matrix of order N:

@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@

-}
buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a
buildMatrix rc cc f =
    fromLists [ [ f (ri, ci) | ci <- [0 .. cc - 1] ] | ri <- [0 .. rc - 1] ]
328 | |||
329 | ----------------------------------------------------- | ||
330 | |||
-- Builds a matrix from a two-dimensional 'Array': the size is taken from the
-- array bounds and the elements are used in the order returned by 'elems'.
fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e
fromArray2D m = (nRows >< nCols) (elems m)
  where
    ((rLo, cLo), (rHi, cHi)) = bounds m
    nRows = rHi - rLo + 1
    nCols = cHi - cLo + 1
336 | |||
337 | |||
-- | rearranges the rows of a matrix according to the order given in a list of integers.
--   An index outside the matrix is an error; an empty list gives a matrix
--   with no rows and the original number of columns.
extractRows :: Element t => [Int] -> Matrix t -> Matrix t
extractRows [] m = emptyM 0 (cols m)
extractRows l m = fromRows (map pick l)
  where
    allRows = toRows m
    pick k = allRows !! verify k
    verify k
        | k >= 0 && k < rows m = k
        | otherwise = error $ "can't extract row "
                           ++ show k ++ " in list " ++ show l ++ " from matrix " ++ shSize m
348 | |||
-- | rearranges the columns of a matrix according to the order given in a list of integers.
--   Implemented by selecting rows of the transpose; an index outside the
--   matrix is an error.
extractColumns :: Element t => [Int] -> Matrix t -> Matrix t
extractColumns l m = trans (extractRows checked (trans m))
  where
    checked = map verify l
    verify k
        | k >= 0 && k < cols m = k
        | otherwise = error $ "can't extract column "
                           ++ show k ++ " in list " ++ show l ++ " from matrix " ++ shSize m
357 | |||
358 | |||
359 | |||
{- | Tiles a matrix: the result is made of @r@ block rows and @c@ block
columns, each block being a copy of the argument.

>>> repmat (ident 2) 2 3
(4><6)
 [ 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
 , 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ]

-}
repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t
repmat m r c =
    if r == 0 || c == 0
        then emptyM (r*rows m) (c*cols m)   -- nothing to tile
        else fromBlocks (replicate r (replicate c m))
374 | |||
-- | A version of 'liftMatrix2' which automatically adapts matrices with a
-- single row or column to match the dimensions of the other matrix.
--
-- The arguments are combined directly when 'compat'' accepts them (same size,
-- or one of them is 1x1). Otherwise both are expanded with 'conformMTo' to the
-- larger row count and the larger column count, which is only attempted when,
-- independently for rows and for columns, the two sizes are equal or the
-- smaller one is 1. Any other combination raises a
-- \"nonconformable matrices\" error.
liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
liftMatrix2Auto f m1 m2
    | compat' m1 m2 = lM f m1  m2
    | ok            = lM f m1' m2'
    | otherwise     = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2
  where
    (r1,c1) = size m1
    (r2,c2) = size m2
    r  = max r1 r2
    c  = max c1 c2
    r0 = min r1 r2
    c0 = min c1 c2
    -- '&&' binds tighter than '||', so each per-dimension test must be
    -- parenthesised; without the parentheses a match in a single dimension
    -- was enough to accept the pair.
    ok = (r0 == 1 || r1 == r2) && (c0 == 1 || c1 == c2)
    m1' = conformMTo (r,c) m1
    m2' = conformMTo (r,c) m2
391 | |||
-- Applies a binary vector function to the flattened arguments and rebuilds a
-- row-major matrix whose size is the larger row count by the larger column count.
-- FIXME do not flatten if equal order
lM f m1 m2 = matrixFromVector RowMajor nRows nCols combined
  where
    nRows    = max (rows m1) (rows m2)
    nCols    = max (cols m1) (cols m2)
    combined = f (flatten m1) (flatten m2)
398 | |||
-- Two matrices are compatible when they have the same size or when at least
-- one of them is 1x1.
compat' :: Matrix a -> Matrix b -> Bool
compat' m1 m2 = or [sz1 == unit, sz2 == unit, sz1 == sz2]
  where
    unit = (1,1)
    sz1  = size m1
    sz2  = size m2
404 | |||
405 | ------------------------------------------------------------ | ||
406 | |||
-- Splits a matrix into horizontal bands with the given numbers of rows.
-- A single band covering every row returns the matrix itself.
toBlockRows rs m
    | [r] <- rs, r == rows m = [m]
    | otherwise = map (reshape nc) (takesV bandSizes (flatten m))
  where
    nc        = cols m
    bandSizes = map (* nc) rs   -- number of elements in each band
410 | |||
-- Splits a matrix into vertical bands with the given numbers of columns, by
-- splitting the transpose into row bands and transposing each band back.
toBlockCols cs m
    | [c] <- cs, c == cols m = [m]
    | otherwise = map trans (toBlockRows cs (trans m))
413 | |||
-- | Partition a matrix into blocks with the given numbers of rows and columns.
-- The remaining rows and columns are discarded.
toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]]
toBlocks rs cs m = [ toBlockCols cs band | band <- toBlockRows rs m ]
418 | |||
-- | Fully partition a matrix into blocks of the same size. If the dimensions are not
-- a multiple of the given size the last blocks will be smaller.
-- Block sizes smaller than 1 raise an error.
toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]]
toBlocksEvery r c m
    | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c
    | otherwise      = toBlocks (pieces (rows m) r) (pieces (cols m) c) m
  where
    -- as many full-size pieces as fit, followed by the remainder if there is one
    pieces total step = replicate q step ++ [ rest | rest > 0 ]
      where
        (q, rest) = total `divMod` step
430 | |||
431 | ------------------------------------------------------------------- | ||
432 | |||
-- Adapts a function over (row, column) matrix indexes to a function over
-- positions of the flattened matrix, given the number of columns:
-- position k corresponds to (k `div` c, k `mod` c).
mk :: Int -> ((Int, Int) -> t) -> (Int -> t)
mk c g = g . (`divMod` c)
438 | |||
{- | Runs a monadic action on every element of a matrix together with its
(row, column) position, discarding the results.

>>> mapMatrixWithIndexM_ (\(i,j) v -> printf "m[%d,%d] = %.f\n" i j v :: IO()) ((2><3)[1 :: Double ..])
m[0,0] = 1
m[0,1] = 2
m[0,2] = 3
m[1,0] = 4
m[1,1] = 5
m[1,2] = 6

-}
mapMatrixWithIndexM_
  :: (Element a, Num a, Monad m) =>
      ((Int, Int) -> a -> m ()) -> Matrix a -> m ()
mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk (cols m) g) (flatten m)
456 | |||
{- | Monadic map over the elements of a matrix together with their
(row, column) positions; the result has the same number of columns.

>>> mapMatrixWithIndexM (\(i,j) v -> Just $ 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
Just (3><3)
 [ 100.0,   1.0,   2.0
 ,  10.0, 111.0,  12.0
 ,  20.0,  21.0, 122.0 ]

-}
mapMatrixWithIndexM
  :: (Element a, Storable b, Monad m) =>
      ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
mapMatrixWithIndexM g m =
    liftM (reshape nc) (mapVectorWithIndexM (mk nc g) (flatten m))
  where
    nc = cols m
472 | |||
{- | Maps a function over the elements of a matrix together with their
(row, column) positions; the result has the same number of columns.

>>> mapMatrixWithIndex (\(i,j) v -> 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
(3><3)
 [ 100.0,   1.0,   2.0
 ,  10.0, 111.0,  12.0
 ,  20.0,  21.0, 122.0 ]

-}
mapMatrixWithIndex
  :: (Element a, Storable b) =>
      ((Int, Int) -> a -> b) -> Matrix a -> Matrix b
mapMatrixWithIndex g m =
    reshape nc (mapVectorWithIndex (mk nc g) (flatten m))
  where
    nc = cols m
488 | |||
-- | Applies a function to every element of a matrix, by lifting 'mapVector'
-- with 'liftMatrix'.
mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b
mapMatrix f m = liftMatrix (mapVector f) m
diff --git a/packages/base/src/Data/Packed/ST.hs b/packages/base/src/Data/Packed/ST.hs new file mode 100644 index 0000000..dae457c --- /dev/null +++ b/packages/base/src/Data/Packed/ST.hs | |||
@@ -0,0 +1,178 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# LANGUAGE TypeOperators #-} | ||
3 | {-# LANGUAGE Rank2Types #-} | ||
4 | {-# LANGUAGE BangPatterns #-} | ||
5 | ----------------------------------------------------------------------------- | ||
6 | -- | | ||
7 | -- Module : Data.Packed.ST | ||
8 | -- Copyright : (c) Alberto Ruiz 2008 | ||
9 | -- License : BSD3 | ||
10 | -- Maintainer : Alberto Ruiz | ||
11 | -- Stability : provisional | ||
12 | -- Portability : portable | ||
13 | -- | ||
14 | -- In-place manipulation inside the ST monad. | ||
15 | -- See examples/inplace.hs in the distribution. | ||
16 | -- | ||
17 | ----------------------------------------------------------------------------- | ||
18 | |||
19 | module Data.Packed.ST ( | ||
20 | -- * Mutable Vectors | ||
21 | STVector, newVector, thawVector, freezeVector, runSTVector, | ||
22 | readVector, writeVector, modifyVector, liftSTVector, | ||
23 | -- * Mutable Matrices | ||
24 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | ||
25 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | ||
26 | -- * Unsafe functions | ||
27 | newUndefinedVector, | ||
28 | unsafeReadVector, unsafeWriteVector, | ||
29 | unsafeThawVector, unsafeFreezeVector, | ||
30 | newUndefinedMatrix, | ||
31 | unsafeReadMatrix, unsafeWriteMatrix, | ||
32 | unsafeThawMatrix, unsafeFreezeMatrix | ||
33 | ) where | ||
34 | |||
35 | import Data.Packed.Internal | ||
36 | |||
37 | import Control.Monad.ST(ST, runST) | ||
38 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) | ||
39 | |||
40 | #if MIN_VERSION_base(4,4,0) | ||
41 | import Control.Monad.ST.Unsafe(unsafeIOToST) | ||
42 | #else | ||
43 | import Control.Monad.ST(unsafeIOToST) | ||
44 | #endif | ||
45 | |||
{-# INLINE ioReadV #-}
-- Reads the element at offset k of the vector's buffer; k is not checked.
ioReadV :: Storable t => Vector t -> Int -> IO t
ioReadV v k = unsafeWith v (`peekElemOff` k)
49 | |||
{-# INLINE ioWriteV #-}
-- Overwrites the element at offset k of the vector's buffer; k is not checked.
ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
ioWriteV v k x = unsafeWith v (\buf -> pokeElemOff buf k x)
53 | |||
-- | A vector that can be modified in place inside 'ST'. It wraps an ordinary
-- 'Vector'; the type parameter @s@ is the state-thread tag.
newtype STVector s t = STVector (Vector t)

-- | Makes a mutable vector from a copy (made with 'cloneVector') of the
-- argument, which is left untouched.
thawVector :: Storable t => Vector t -> ST s (STVector s t)
thawVector v = unsafeIOToST (fmap STVector (cloneVector v))
58 | |||
-- | Wraps a vector as a mutable vector without copying it: later writes go
-- to the buffer of the original vector.
unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
unsafeThawVector v = unsafeIOToST (return (STVector v))
61 | |||
-- | Runs an 'ST' computation that produces a mutable vector and returns that
-- vector as an immutable one (frozen with 'unsafeFreezeVector', no copy).
runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
runSTVector st = runST (unsafeFreezeVector =<< st)
64 | |||
{-# INLINE unsafeReadVector #-}
-- | Reads a position of a mutable vector without checking the index.
unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
unsafeReadVector (STVector buf) = \k -> unsafeIOToST (ioReadV buf k)
68 | |||
{-# INLINE unsafeWriteVector #-}
-- | Writes a position of a mutable vector without checking the index.
unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
unsafeWriteVector (STVector buf) k = \val -> unsafeIOToST (ioWriteV buf k val)
72 | |||
{-# INLINE modifyVector #-}
-- | Replaces the element at the given position by the result of applying the
-- function to it. The index is checked by the initial 'readVector'.
modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
modifyVector x k f = do
    old <- readVector x k
    unsafeWriteVector x k (f old)
76 | |||
-- | Applies a pure function to a copy (made with 'cloneVector') of the
-- current contents of a mutable vector.
liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
liftSTVector f (STVector x) = unsafeIOToST (fmap f (cloneVector x))
79 | |||
-- | Returns an immutable copy of the current contents of a mutable vector.
freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
freezeVector = liftSTVector id
82 | |||
-- | Returns the vector wrapped by a mutable vector, without copying it.
unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
unsafeFreezeVector (STVector x) = unsafeIOToST (return x)
85 | |||
{-# INLINE safeIndexV #-}
-- Adds a bounds check to an unchecked vector accessor: the accessor is only
-- called when 0 <= k < dim, otherwise an "out of range" error is raised.
-- The mutable vector is passed on unchanged, so the checked accessor keeps
-- the state-thread tag of the accessor it wraps.
safeIndexV :: Storable t => (STVector s t -> Int -> r) -> STVector s t -> Int -> r
safeIndexV f sv@(STVector v) k
    | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
                                   ++show (dim v)++", pos="++show k++")"
    | otherwise = f sv k
91 | |||
{-# INLINE readVector #-}
-- | Reads a position of a mutable vector, after checking the index.
readVector :: Storable t => STVector s t -> Int -> ST s t
readVector v k = safeIndexV unsafeReadVector v k
95 | |||
{-# INLINE writeVector #-}
-- | Writes a position of a mutable vector, after checking the index.
writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
writeVector v k = safeIndexV unsafeWriteVector v k
99 | |||
-- | Allocates a mutable vector of the given size with 'createVector'; the
-- elements are not initialized here.
newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
newUndefinedVector n = unsafeIOToST (fmap STVector (createVector n))
102 | |||
{-# INLINE newVector #-}
-- | Creates a mutable vector of the given size with every element set to
-- the given value.
newVector :: Storable t => t -> Int -> ST s (STVector s t)
newVector x n = do
    v <- newUndefinedVector n
    -- Fill from the last position down to 0. Stopping on any negative index
    -- (rather than on exactly -1) guarantees termination whatever the
    -- allocation step does with a negative size.
    let fill !k
            | k < 0     = return v
            | otherwise = unsafeWriteVector v k x >> fill (k-1 :: Int)
    fill (n-1)
110 | |||
111 | ------------------------------------------------------------------------- | ||
112 | |||
{-# INLINE ioReadM #-}
-- Reads element (r,c) of a matrix from its underlying vector; the offset
-- depends on the storage order. Indexes are not checked.
ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
ioReadM (Matrix nr nc dat ord) r c = ioReadV dat pos
  where
    pos = case ord of
        RowMajor    -> r*nc+c
        ColumnMajor -> c*nr+r
117 | |||
{-# INLINE ioWriteM #-}
-- Overwrites element (r,c) of a matrix in its underlying vector; the offset
-- depends on the storage order. Indexes are not checked.
ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
ioWriteM (Matrix nr nc dat ord) r c val = ioWriteV dat pos val
  where
    pos = case ord of
        RowMajor    -> r*nc+c
        ColumnMajor -> c*nr+r
122 | |||
-- | A matrix that can be modified in place inside 'ST'. It wraps an ordinary
-- 'Matrix'; the type parameter @s@ is the state-thread tag.
newtype STMatrix s t = STMatrix (Matrix t)

-- | Makes a mutable matrix from a copy (made with 'cloneMatrix') of the
-- argument, which is left untouched.
thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
thawMatrix m = unsafeIOToST (fmap STMatrix (cloneMatrix m))
127 | |||
-- | Wraps a matrix as a mutable matrix without copying it: later writes go
-- to the buffer of the original matrix.
unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
unsafeThawMatrix m = unsafeIOToST (return (STMatrix m))
130 | |||
-- | Runs an 'ST' computation that produces a mutable matrix and returns that
-- matrix as an immutable one (frozen with 'unsafeFreezeMatrix', no copy).
runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
runSTMatrix st = runST (unsafeFreezeMatrix =<< st)
133 | |||
{-# INLINE unsafeReadMatrix #-}
-- | Reads element (row, column) of a mutable matrix without checking the indexes.
unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
unsafeReadMatrix (STMatrix x) r = \c -> unsafeIOToST (ioReadM x r c)
137 | |||
{-# INLINE unsafeWriteMatrix #-}
-- | Writes element (row, column) of a mutable matrix without checking the indexes.
unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
unsafeWriteMatrix (STMatrix x) r c = \val -> unsafeIOToST (ioWriteM x r c val)
141 | |||
{-# INLINE modifyMatrix #-}
-- | Replaces element (row, column) by the result of applying the function to
-- it. The indexes are checked by the initial 'readMatrix'.
modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
modifyMatrix x r c f = do
    old <- readMatrix x r c
    unsafeWriteMatrix x r c (f old)
145 | |||
-- | Applies a pure function to a copy (made with 'cloneMatrix') of the
-- current contents of a mutable matrix.
liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
liftSTMatrix f (STMatrix x) = unsafeIOToST (fmap f (cloneMatrix x))
148 | |||
-- | Returns the matrix wrapped by a mutable matrix, without copying it.
unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
unsafeFreezeMatrix (STMatrix x) = unsafeIOToST (return x)
151 | |||
-- | Returns an immutable copy of the current contents of a mutable matrix.
freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
freezeMatrix = liftSTMatrix id
154 | |||
-- Copies a matrix: the underlying vector is duplicated with 'cloneVector'
-- while the dimensions and the storage order are kept as they are.
cloneMatrix :: Storable t => Matrix t -> IO (Matrix t)
cloneMatrix (Matrix r c d o) = do
    d' <- cloneVector d
    return (Matrix r c d' o)
156 | |||
{-# INLINE safeIndexM #-}
-- Adds a bounds check to an unchecked matrix accessor: the accessor is only
-- called when 0 <= r < rows and 0 <= c < cols, otherwise an "out of range"
-- error is raised. The mutable matrix is passed on unchanged, so the checked
-- accessor keeps the state-thread tag of the accessor it wraps.
safeIndexM :: Storable t => (STMatrix s t -> Int -> Int -> r) -> STMatrix s t -> Int -> Int -> r
safeIndexM f sm@(STMatrix m) r c
    | r<0 || r>=rows m ||
      c<0 || c>=cols m = error $ "out of range error in matrix (size="
                                 ++show (rows m,cols m)++", pos="++show (r,c)++")"
    | otherwise = f sm r c
163 | |||
{-# INLINE readMatrix #-}
-- | Reads element (row, column) of a mutable matrix, after checking the indexes.
readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
readMatrix m r c = safeIndexM unsafeReadMatrix m r c
167 | |||
{-# INLINE writeMatrix #-}
-- | Writes element (row, column) of a mutable matrix, after checking the indexes.
writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
writeMatrix m r c = safeIndexM unsafeWriteMatrix m r c
171 | |||
-- | Allocates a mutable matrix with the given storage order, rows and columns
-- using 'createMatrix'; the elements are not initialized here.
newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
newUndefinedMatrix ord r c = unsafeIOToST (STMatrix `fmap` createMatrix ord r c)
174 | |||
{-# NOINLINE newMatrix #-}
-- | Creates a mutable row-major matrix with the given number of rows and
-- columns and every element set to the given value.
--
-- The buffer is allocated and filled by 'newVector' in the current state
-- thread and tagged with the requested dimensions directly, so the row count
-- never has to be recovered from the buffer length and the column count
-- (which is not possible when the column count is 0).
newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
newMatrix v r c = fmap wrap (newVector v (r*c))
  where
    -- same positional order as used by the other 'Matrix' patterns in this
    -- module: rows, columns, data, storage order
    wrap (STVector buf) = STMatrix (Matrix r c buf RowMajor)
178 | |||
diff --git a/packages/base/src/Data/Packed/Vector.hs b/packages/base/src/Data/Packed/Vector.hs new file mode 100644 index 0000000..b5a4318 --- /dev/null +++ b/packages/base/src/Data/Packed/Vector.hs | |||
@@ -0,0 +1,96 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | {-# LANGUAGE CPP #-} | ||
3 | ----------------------------------------------------------------------------- | ||
4 | -- | | ||
5 | -- Module : Data.Packed.Vector | ||
6 | -- Copyright : (c) Alberto Ruiz 2007-10 | ||
7 | -- License : GPL | ||
8 | -- | ||
9 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
10 | -- Stability : provisional | ||
11 | -- | ||
12 | -- 1D arrays suitable for numeric computations using external libraries. | ||
13 | -- | ||
14 | -- This module provides basic functions for manipulation of structure. | ||
15 | -- | ||
16 | ----------------------------------------------------------------------------- | ||
17 | {-# OPTIONS_HADDOCK hide #-} | ||
18 | |||
19 | module Data.Packed.Vector ( | ||
20 | Vector, | ||
21 | fromList, (|>), toList, buildVector, | ||
22 | dim, (@>), | ||
23 | subVector, takesV, vjoin, join, | ||
24 | mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith, | ||
25 | mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, | ||
26 | foldLoop, foldVector, foldVectorG, foldVectorWithIndex | ||
27 | ) where | ||
28 | |||
29 | import Data.Packed.Internal.Vector | ||
30 | import Foreign.Storable | ||
31 | |||
32 | ------------------------------------------------------------------- | ||
33 | |||
34 | #ifdef BINARY | ||
35 | |||
36 | import Data.Binary | ||
37 | import Control.Monad(replicateM) | ||
38 | |||
-- a 64K cache, with a Double taking 13 bytes in Bytestring,
-- implies a chunk size of 5041
chunk :: Int
chunk = 5000

-- Splits a length into pieces of size 'chunk', followed by one smaller
-- piece holding the remainder when the length is not a multiple of 'chunk'.
chunks :: Int -> [Int]
chunks d = replicate full chunk ++ [ rest | rest /= 0 ]
  where
    (full, rest) = d `divMod` chunk
48 | |||
-- Serializes the elements of a vector one after another, without the length.
putVector v = mapM_ (put . (v @>)) [0 .. dim v - 1]
52 | |||
-- Reads the given number of elements and packs them into a vector, which is
-- forced (to weak head normal form) before it is returned.
getVector d = replicateM d get >>= \xs -> return $! fromList xs
56 | |||
-- Format: the length, followed by the elements. The elements are written and
-- read in pieces whose sizes are given by 'chunks'.
instance (Binary a, Storable a) => Binary (Vector a) where
    put v = put n >> (mapM_ putVector $! takesV (chunks n) v)
      where
        n = dim v
    get = do
        n <- get
        pieces <- mapM getVector (chunks n)
        return $! vjoin pieces
66 | |||
67 | #endif | ||
68 | |||
69 | ------------------------------------------------------------------- | ||
70 | |||
{- | Creates a vector of the specified length, using the supplied function to
map each zero-based index to the value stored at that index.

@> buildVector 4 fromIntegral
4 |> [0.0,1.0,2.0,3.0]@

-}
buildVector :: Storable a => Int -> (Int -> a) -> Vector a
buildVector len f = fromList [ f i | i <- [0 .. len - 1] ]
81 | |||
82 | |||
-- | zip for Vectors: pairs up the elements of two vectors with 'zipVectorWith'.
zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b)
zipVector xs ys = zipVectorWith (,) xs ys
86 | |||
-- | unzip for Vectors: splits a vector of pairs into a pair of vectors with
-- 'unzipVectorWith'.
unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b)
unzipVector pairs = unzipVectorWith id pairs
90 | |||
91 | ------------------------------------------------------------------- | ||
92 | |||
{-# DEPRECATED join "use vjoin or Data.Vector.concat" #-}
-- | Deprecated synonym of 'vjoin'.
join :: Storable t => [Vector t] -> Vector t
join vs = vjoin vs
96 | |||