summaryrefslogtreecommitdiff
path: root/packages/base
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-08 12:16:42 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-08 12:16:42 +0200
commit5992d92357cfd911c8f2e9f5faaa4fd8a323fd9a (patch)
treec6cdc40bc0121fd87a5137d8b7cb4d9f064cab47 /packages/base
parentd558f5165e7e7f4daffadae1197e53618727971d (diff)
Data.Packed -> base (I)
Diffstat (limited to 'packages/base')
-rw-r--r--packages/base/hmatrix-base.cabal30
-rw-r--r--packages/base/src/C/lapack-aux.c1489
-rw-r--r--packages/base/src/C/lapack-aux.h60
-rw-r--r--packages/base/src/Data/Packed.hs25
-rw-r--r--packages/base/src/Data/Packed/Development.hs31
-rw-r--r--packages/base/src/Data/Packed/Foreign.hs99
-rw-r--r--packages/base/src/Data/Packed/Internal.hs26
-rw-r--r--packages/base/src/Data/Packed/Internal/Common.hs160
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs422
-rw-r--r--packages/base/src/Data/Packed/Internal/Signatures.hs70
-rw-r--r--packages/base/src/Data/Packed/Internal/Vector.hs471
-rw-r--r--packages/base/src/Data/Packed/Matrix.hs490
-rw-r--r--packages/base/src/Data/Packed/ST.hs178
-rw-r--r--packages/base/src/Data/Packed/Vector.hs96
14 files changed, 3642 insertions, 5 deletions
diff --git a/packages/base/hmatrix-base.cabal b/packages/base/hmatrix-base.cabal
index 148c8f3..96f90e2 100644
--- a/packages/base/hmatrix-base.cabal
+++ b/packages/base/hmatrix-base.cabal
@@ -15,22 +15,42 @@ cabal-version: >=1.8
15 15
16build-type: Simple 16build-type: Simple
17 17
18extra-source-files: 18extra-source-files: src/C/lapack-aux.h
19 19
20library 20library
21 21
22 Build-Depends: base >= 4 && < 5 22 Build-Depends: base,
23 binary,
24 array,
25 deepseq,
26 storable-complex,
27 vector >= 0.8
23 28
24 hs-source-dirs: src 29 hs-source-dirs: src
25 30
26 exposed-modules: 31 exposed-modules: Data.Packed,
27 32 Data.Packed.Vector,
28 other-modules: 33 Data.Packed.Matrix,
34 Data.Packed.Foreign,
35 Data.Packed.ST,
36 Data.Packed.Development
37
38 other-modules: Data.Packed.Internal,
39 Data.Packed.Internal.Common,
40 Data.Packed.Internal.Signatures,
41 Data.Packed.Internal.Vector,
42 Data.Packed.Internal.Matrix
43
44 C-sources: src/C/lapack-aux.c
45
29 46
30 extensions: ForeignFunctionInterface, 47 extensions: ForeignFunctionInterface,
31 CPP 48 CPP
32 49
33 ghc-options: -Wall 50 ghc-options: -Wall
51 -fno-warn-missing-signatures
52 -fno-warn-orphans
53-- -fno-warn-unused-binds
34 54
35 cc-options: -O4 -msse2 -Wall 55 cc-options: -O4 -msse2 -Wall
36 56
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c
new file mode 100644
index 0000000..e5e45ef
--- /dev/null
+++ b/packages/base/src/C/lapack-aux.c
@@ -0,0 +1,1489 @@
1#include <stdio.h>
2#include <stdlib.h>
3#include <string.h>
4#include <math.h>
5#include <time.h>
6#include "lapack-aux.h"
7
8#define MACRO(B) do {B} while (0)
9#define ERROR(CODE) MACRO(return CODE;)
10#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})
11
12#define MIN(A,B) ((A)<(B)?(A):(B))
13#define MAX(A,B) ((A)>(B)?(A):(B))
14
15// #define DBGL
16
17#ifdef DBGL
18#define DEBUGMSG(M) printf("\nLAPACK "M"\n");
19#else
20#define DEBUGMSG(M)
21#endif
22
23#define OK return 0;
24
25// #ifdef DBGL
26// #define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL);
27// #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
28// #else
29// #define DEBUGMSG(M)
30// #define OK return 0;
31// #endif
32
33#define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \
34 for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");}
35
36#define CHECK(RES,CODE) MACRO(if(RES) return CODE;)
37
38#define BAD_SIZE 2000
39#define BAD_CODE 2001
40#define MEM 2002
41#define BAD_FILE 2003
42#define SINGULAR 2004
43#define NOCONVER 2005
44#define NODEFPOS 2006
45#define NOSPRTD 2007
46
47//---------------------------------------
48void asm_finit() {
49#ifdef i386
50
51// asm("finit");
52
53 static unsigned char buf[108];
54 asm("FSAVE %0":"=m" (buf));
55
56 #if FPUDEBUG
57 if(buf[8]!=255 || buf[9]!=255) { // print warning in red
58 printf("%c[;31mWarning: FPU TAG = %x %x\%c[0m\n",0x1B,buf[8],buf[9],0x1B);
59 }
60 #endif
61
62 #if NANDEBUG
63 asm("FRSTOR %0":"=m" (buf));
64 #endif
65
66#endif
67}
68
69//---------------------------------------
70
71#if NANDEBUG
72
73#define CHECKNANR(M,msg) \
74{ int k; \
75for(k=0; k<(M##r * M##c); k++) { \
76 if(M##p[k] != M##p[k]) { \
77 printf(msg); \
78 TRACEMAT(M) \
79 /*exit(1);*/ \
80 } \
81} \
82}
83
84#define CHECKNANC(M,msg) \
85{ int k; \
86for(k=0; k<(M##r * M##c); k++) { \
87 if( M##p[k].r != M##p[k].r \
88 || M##p[k].i != M##p[k].i) { \
89 printf(msg); \
90 /*exit(1);*/ \
91 } \
92} \
93}
94
95#else
96#define CHECKNANC(M,msg)
97#define CHECKNANR(M,msg)
98#endif
99
100//---------------------------------------
101
102//////////////////// real svd ////////////////////////////////////
103
104/* Subroutine */ int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n,
105 doublereal *a, integer *lda, doublereal *s, doublereal *u, integer *
106 ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork,
107 integer *info);
108
109int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) {
110 integer m = ar;
111 integer n = ac;
112 integer q = MIN(m,n);
113 REQUIRES(sn==q,BAD_SIZE);
114 REQUIRES(up==NULL || (ur==m && (uc==m || uc==q)),BAD_SIZE);
115 char* jobu = "A";
116 if (up==NULL) {
117 jobu = "N";
118 } else {
119 if (uc==q) {
120 jobu = "S";
121 }
122 }
123 REQUIRES(vp==NULL || (vc==n && (vr==n || vr==q)),BAD_SIZE);
124 char* jobvt = "A";
125 integer ldvt = n;
126 if (vp==NULL) {
127 jobvt = "N";
128 } else {
129 if (vr==q) {
130 jobvt = "S";
131 ldvt = q;
132 }
133 }
134 DEBUGMSG("svd_l_R");
135 double *B = (double*)malloc(m*n*sizeof(double));
136 CHECK(!B,MEM);
137 memcpy(B,ap,m*n*sizeof(double));
138 integer lwork = -1;
139 integer res;
140 // ask for optimal lwork
141 double ans;
142 dgesvd_ (jobu,jobvt,
143 &m,&n,B,&m,
144 sp,
145 up,&m,
146 vp,&ldvt,
147 &ans, &lwork,
148 &res);
149 lwork = ceil(ans);
150 double * work = (double*)malloc(lwork*sizeof(double));
151 CHECK(!work,MEM);
152 dgesvd_ (jobu,jobvt,
153 &m,&n,B,&m,
154 sp,
155 up,&m,
156 vp,&ldvt,
157 work, &lwork,
158 &res);
159 CHECK(res,res);
160 free(work);
161 free(B);
162 OK
163}
164
165// (alternative version)
166
167/* Subroutine */ int dgesdd_(char *jobz, integer *m, integer *n, doublereal *
168 a, integer *lda, doublereal *s, doublereal *u, integer *ldu,
169 doublereal *vt, integer *ldvt, doublereal *work, integer *lwork,
170 integer *iwork, integer *info);
171
172int svd_l_Rdd(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) {
173 integer m = ar;
174 integer n = ac;
175 integer q = MIN(m,n);
176 REQUIRES(sn==q,BAD_SIZE);
177 REQUIRES((up == NULL && vp == NULL)
178 || (ur==m && vc==n
179 && ((uc == q && vr == q)
180 || (uc == m && vc==n))),BAD_SIZE);
181 char* jobz = "A";
182 integer ldvt = n;
183 if (up==NULL) {
184 jobz = "N";
185 } else {
186 if (uc==q && vr == q) {
187 jobz = "S";
188 ldvt = q;
189 }
190 }
191 DEBUGMSG("svd_l_Rdd");
192 double *B = (double*)malloc(m*n*sizeof(double));
193 CHECK(!B,MEM);
194 memcpy(B,ap,m*n*sizeof(double));
195 integer* iwk = (integer*) malloc(8*q*sizeof(integer));
196 CHECK(!iwk,MEM);
197 integer lwk = -1;
198 integer res;
199 // ask for optimal lwk
200 double ans;
201 dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,iwk,&res);
202 lwk = ans;
203 double * workv = (double*)malloc(lwk*sizeof(double));
204 CHECK(!workv,MEM);
205 dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,iwk,&res);
206 CHECK(res,res);
207 free(iwk);
208 free(workv);
209 free(B);
210 OK
211}
212
213//////////////////// complex svd ////////////////////////////////////
214
215// not in clapack.h
216
217int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n,
218 doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u,
219 integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work,
220 integer *lwork, doublereal *rwork, integer *info);
221
222int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) {
223 integer m = ar;
224 integer n = ac;
225 integer q = MIN(m,n);
226 REQUIRES(sn==q,BAD_SIZE);
227 REQUIRES(up==NULL || (ur==m && (uc==m || uc==q)),BAD_SIZE);
228 char* jobu = "A";
229 if (up==NULL) {
230 jobu = "N";
231 } else {
232 if (uc==q) {
233 jobu = "S";
234 }
235 }
236 REQUIRES(vp==NULL || (vc==n && (vr==n || vr==q)),BAD_SIZE);
237 char* jobvt = "A";
238 integer ldvt = n;
239 if (vp==NULL) {
240 jobvt = "N";
241 } else {
242 if (vr==q) {
243 jobvt = "S";
244 ldvt = q;
245 }
246 }DEBUGMSG("svd_l_C");
247 doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
248 CHECK(!B,MEM);
249 memcpy(B,ap,m*n*sizeof(doublecomplex));
250
251 double *rwork = (double*) malloc(5*q*sizeof(double));
252 CHECK(!rwork,MEM);
253 integer lwork = -1;
254 integer res;
255 // ask for optimal lwork
256 doublecomplex ans;
257 zgesvd_ (jobu,jobvt,
258 &m,&n,B,&m,
259 sp,
260 up,&m,
261 vp,&ldvt,
262 &ans, &lwork,
263 rwork,
264 &res);
265 lwork = ceil(ans.r);
266 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
267 CHECK(!work,MEM);
268 zgesvd_ (jobu,jobvt,
269 &m,&n,B,&m,
270 sp,
271 up,&m,
272 vp,&ldvt,
273 work, &lwork,
274 rwork,
275 &res);
276 CHECK(res,res);
277 free(work);
278 free(rwork);
279 free(B);
280 OK
281}
282
283int zgesdd_ (char *jobz, integer *m, integer *n,
284 doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u,
285 integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work,
286 integer *lwork, doublereal *rwork, integer* iwork, integer *info);
287
288int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) {
289 //printf("entro\n");
290 integer m = ar;
291 integer n = ac;
292 integer q = MIN(m,n);
293 REQUIRES(sn==q,BAD_SIZE);
294 REQUIRES((up == NULL && vp == NULL)
295 || (ur==m && vc==n
296 && ((uc == q && vr == q)
297 || (uc == m && vc==n))),BAD_SIZE);
298 char* jobz = "A";
299 integer ldvt = n;
300 if (up==NULL) {
301 jobz = "N";
302 } else {
303 if (uc==q && vr == q) {
304 jobz = "S";
305 ldvt = q;
306 }
307 }
308 DEBUGMSG("svd_l_Cdd");
309 doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
310 CHECK(!B,MEM);
311 memcpy(B,ap,m*n*sizeof(doublecomplex));
312 integer* iwk = (integer*) malloc(8*q*sizeof(integer));
313 CHECK(!iwk,MEM);
314 int lrwk;
315 if (0 && *jobz == 'N') {
316 lrwk = 5*q; // does not work, crash at free below
317 } else {
318 lrwk = 5*q*q + 7*q;
319 }
320 double *rwk = (double*)malloc(lrwk*sizeof(double));;
321 CHECK(!rwk,MEM);
322 //printf("%s %ld %d\n",jobz,q,lrwk);
323 integer lwk = -1;
324 integer res;
325 // ask for optimal lwk
326 doublecomplex ans;
327 zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,rwk,iwk,&res);
328 lwk = ans.r;
329 //printf("lwk = %ld\n",lwk);
330 doublecomplex * workv = (doublecomplex*)malloc(lwk*sizeof(doublecomplex));
331 CHECK(!workv,MEM);
332 zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,rwk,iwk,&res);
333 //printf("res = %ld\n",res);
334 CHECK(res,res);
335 free(workv); // printf("freed workv\n");
336 free(rwk); // printf("freed rwk\n");
337 free(iwk); // printf("freed iwk\n");
338 free(B); // printf("freed B, salgo\n");
339 OK
340}
341
342//////////////////// general complex eigensystem ////////////
343
344/* Subroutine */ int zgeev_(char *jobvl, char *jobvr, integer *n,
345 doublecomplex *a, integer *lda, doublecomplex *w, doublecomplex *vl,
346 integer *ldvl, doublecomplex *vr, integer *ldvr, doublecomplex *work,
347 integer *lwork, doublereal *rwork, integer *info);
348
349int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) {
350 integer n = ar;
351 REQUIRES(ac==n && sn==n, BAD_SIZE);
352 REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE);
353 char jobvl = up==NULL?'N':'V';
354 REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE);
355 char jobvr = vp==NULL?'N':'V';
356 DEBUGMSG("eig_l_C");
357 doublecomplex *B = (doublecomplex*)malloc(n*n*sizeof(doublecomplex));
358 CHECK(!B,MEM);
359 memcpy(B,ap,n*n*sizeof(doublecomplex));
360 double *rwork = (double*) malloc(2*n*sizeof(double));
361 CHECK(!rwork,MEM);
362 integer lwork = -1;
363 integer res;
364 // ask for optimal lwork
365 doublecomplex ans;
366 //printf("ask zgeev\n");
367 zgeev_ (&jobvl,&jobvr,
368 &n,B,&n,
369 sp,
370 up,&n,
371 vp,&n,
372 &ans, &lwork,
373 rwork,
374 &res);
375 lwork = ceil(ans.r);
376 //printf("ans = %d\n",lwork);
377 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
378 CHECK(!work,MEM);
379 //printf("zgeev\n");
380 zgeev_ (&jobvl,&jobvr,
381 &n,B,&n,
382 sp,
383 up,&n,
384 vp,&n,
385 work, &lwork,
386 rwork,
387 &res);
388 CHECK(res,res);
389 free(work);
390 free(rwork);
391 free(B);
392 OK
393}
394
395
396
397//////////////////// general real eigensystem ////////////
398
399/* Subroutine */ int dgeev_(char *jobvl, char *jobvr, integer *n, doublereal *
400 a, integer *lda, doublereal *wr, doublereal *wi, doublereal *vl,
401 integer *ldvl, doublereal *vr, integer *ldvr, doublereal *work,
402 integer *lwork, integer *info);
403
404int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) {
405 integer n = ar;
406 REQUIRES(ac==n && sn==n, BAD_SIZE);
407 REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE);
408 char jobvl = up==NULL?'N':'V';
409 REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE);
410 char jobvr = vp==NULL?'N':'V';
411 DEBUGMSG("eig_l_R");
412 double *B = (double*)malloc(n*n*sizeof(double));
413 CHECK(!B,MEM);
414 memcpy(B,ap,n*n*sizeof(double));
415 integer lwork = -1;
416 integer res;
417 // ask for optimal lwork
418 double ans;
419 //printf("ask dgeev\n");
420 dgeev_ (&jobvl,&jobvr,
421 &n,B,&n,
422 (double*)sp, (double*)sp+n,
423 up,&n,
424 vp,&n,
425 &ans, &lwork,
426 &res);
427 lwork = ceil(ans);
428 //printf("ans = %d\n",lwork);
429 double * work = (double*)malloc(lwork*sizeof(double));
430 CHECK(!work,MEM);
431 //printf("dgeev\n");
432 dgeev_ (&jobvl,&jobvr,
433 &n,B,&n,
434 (double*)sp, (double*)sp+n,
435 up,&n,
436 vp,&n,
437 work, &lwork,
438 &res);
439 CHECK(res,res);
440 free(work);
441 free(B);
442 OK
443}
444
445
446//////////////////// symmetric real eigensystem ////////////
447
448/* Subroutine */ int dsyev_(char *jobz, char *uplo, integer *n, doublereal *a,
449 integer *lda, doublereal *w, doublereal *work, integer *lwork,
450 integer *info);
451
452int eig_l_S(int wantV,KDMAT(a),DVEC(s),DMAT(v)) {
453 integer n = ar;
454 REQUIRES(ac==n && sn==n, BAD_SIZE);
455 REQUIRES(vr==n && vc==n, BAD_SIZE);
456 char jobz = wantV?'V':'N';
457 DEBUGMSG("eig_l_S");
458 memcpy(vp,ap,n*n*sizeof(double));
459 integer lwork = -1;
460 char uplo = 'U';
461 integer res;
462 // ask for optimal lwork
463 double ans;
464 //printf("ask dsyev\n");
465 dsyev_ (&jobz,&uplo,
466 &n,vp,&n,
467 sp,
468 &ans, &lwork,
469 &res);
470 lwork = ceil(ans);
471 //printf("ans = %d\n",lwork);
472 double * work = (double*)malloc(lwork*sizeof(double));
473 CHECK(!work,MEM);
474 dsyev_ (&jobz,&uplo,
475 &n,vp,&n,
476 sp,
477 work, &lwork,
478 &res);
479 CHECK(res,res);
480 free(work);
481 OK
482}
483
484//////////////////// hermitian complex eigensystem ////////////
485
486/* Subroutine */ int zheev_(char *jobz, char *uplo, integer *n, doublecomplex
487 *a, integer *lda, doublereal *w, doublecomplex *work, integer *lwork,
488 doublereal *rwork, integer *info);
489
490int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) {
491 integer n = ar;
492 REQUIRES(ac==n && sn==n, BAD_SIZE);
493 REQUIRES(vr==n && vc==n, BAD_SIZE);
494 char jobz = wantV?'V':'N';
495 DEBUGMSG("eig_l_H");
496 memcpy(vp,ap,2*n*n*sizeof(double));
497 double *rwork = (double*) malloc((3*n-2)*sizeof(double));
498 CHECK(!rwork,MEM);
499 integer lwork = -1;
500 char uplo = 'U';
501 integer res;
502 // ask for optimal lwork
503 doublecomplex ans;
504 //printf("ask zheev\n");
505 zheev_ (&jobz,&uplo,
506 &n,vp,&n,
507 sp,
508 &ans, &lwork,
509 rwork,
510 &res);
511 lwork = ceil(ans.r);
512 //printf("ans = %d\n",lwork);
513 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
514 CHECK(!work,MEM);
515 zheev_ (&jobz,&uplo,
516 &n,vp,&n,
517 sp,
518 work, &lwork,
519 rwork,
520 &res);
521 CHECK(res,res);
522 free(work);
523 free(rwork);
524 OK
525}
526
527//////////////////// general real linear system ////////////
528
529/* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
530 *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
531
532int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) {
533 integer n = ar;
534 integer nhrs = bc;
535 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
536 DEBUGMSG("linearSolveR_l");
537 double*AC = (double*)malloc(n*n*sizeof(double));
538 memcpy(AC,ap,n*n*sizeof(double));
539 memcpy(xp,bp,n*nhrs*sizeof(double));
540 integer * ipiv = (integer*)malloc(n*sizeof(integer));
541 integer res;
542 dgesv_ (&n,&nhrs,
543 AC, &n,
544 ipiv,
545 xp, &n,
546 &res);
547 if(res>0) {
548 return SINGULAR;
549 }
550 CHECK(res,res);
551 free(ipiv);
552 free(AC);
553 OK
554}
555
556//////////////////// general complex linear system ////////////
557
558/* Subroutine */ int zgesv_(integer *n, integer *nrhs, doublecomplex *a,
559 integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer *
560 info);
561
562int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) {
563 integer n = ar;
564 integer nhrs = bc;
565 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
566 DEBUGMSG("linearSolveC_l");
567 doublecomplex*AC = (doublecomplex*)malloc(n*n*sizeof(doublecomplex));
568 memcpy(AC,ap,n*n*sizeof(doublecomplex));
569 memcpy(xp,bp,n*nhrs*sizeof(doublecomplex));
570 integer * ipiv = (integer*)malloc(n*sizeof(integer));
571 integer res;
572 zgesv_ (&n,&nhrs,
573 AC, &n,
574 ipiv,
575 xp, &n,
576 &res);
577 if(res>0) {
578 return SINGULAR;
579 }
580 CHECK(res,res);
581 free(ipiv);
582 free(AC);
583 OK
584}
585
586//////// symmetric positive definite real linear system using Cholesky ////////////
587
588/* Subroutine */ int dpotrs_(char *uplo, integer *n, integer *nrhs,
589 doublereal *a, integer *lda, doublereal *b, integer *ldb, integer *
590 info);
591
592int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) {
593 integer n = ar;
594 integer nhrs = bc;
595 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
596 DEBUGMSG("cholSolveR_l");
597 memcpy(xp,bp,n*nhrs*sizeof(double));
598 integer res;
599 dpotrs_ ("U",
600 &n,&nhrs,
601 (double*)ap, &n,
602 xp, &n,
603 &res);
604 CHECK(res,res);
605 OK
606}
607
608//////// Hermitian positive definite real linear system using Cholesky ////////////
609
610/* Subroutine */ int zpotrs_(char *uplo, integer *n, integer *nrhs,
611 doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb,
612 integer *info);
613
614int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) {
615 integer n = ar;
616 integer nhrs = bc;
617 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
618 DEBUGMSG("cholSolveC_l");
619 memcpy(xp,bp,n*nhrs*sizeof(doublecomplex));
620 integer res;
621 zpotrs_ ("U",
622 &n,&nhrs,
623 (doublecomplex*)ap, &n,
624 xp, &n,
625 &res);
626 CHECK(res,res);
627 OK
628}
629
630//////////////////// least squares real linear system ////////////
631
632/* Subroutine */ int dgels_(char *trans, integer *m, integer *n, integer *
633 nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb,
634 doublereal *work, integer *lwork, integer *info);
635
636int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) {
637 integer m = ar;
638 integer n = ac;
639 integer nrhs = bc;
640 integer ldb = xr;
641 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
642 DEBUGMSG("linearSolveLSR_l");
643 double*AC = (double*)malloc(m*n*sizeof(double));
644 memcpy(AC,ap,m*n*sizeof(double));
645 if (m>=n) {
646 memcpy(xp,bp,m*nrhs*sizeof(double));
647 } else {
648 int k;
649 for(k = 0; k<nrhs; k++) {
650 memcpy(xp+ldb*k,bp+m*k,m*sizeof(double));
651 }
652 }
653 integer res;
654 integer lwork = -1;
655 double ans;
656 dgels_ ("N",&m,&n,&nrhs,
657 AC,&m,
658 xp,&ldb,
659 &ans,&lwork,
660 &res);
661 lwork = ceil(ans);
662 //printf("ans = %d\n",lwork);
663 double * work = (double*)malloc(lwork*sizeof(double));
664 dgels_ ("N",&m,&n,&nrhs,
665 AC,&m,
666 xp,&ldb,
667 work,&lwork,
668 &res);
669 if(res>0) {
670 return SINGULAR;
671 }
672 CHECK(res,res);
673 free(work);
674 free(AC);
675 OK
676}
677
678//////////////////// least squares complex linear system ////////////
679
680/* Subroutine */ int zgels_(char *trans, integer *m, integer *n, integer *
681 nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb,
682 doublecomplex *work, integer *lwork, integer *info);
683
684int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)) {
685 integer m = ar;
686 integer n = ac;
687 integer nrhs = bc;
688 integer ldb = xr;
689 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
690 DEBUGMSG("linearSolveLSC_l");
691 doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
692 memcpy(AC,ap,m*n*sizeof(doublecomplex));
693 if (m>=n) {
694 memcpy(xp,bp,m*nrhs*sizeof(doublecomplex));
695 } else {
696 int k;
697 for(k = 0; k<nrhs; k++) {
698 memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex));
699 }
700 }
701 integer res;
702 integer lwork = -1;
703 doublecomplex ans;
704 zgels_ ("N",&m,&n,&nrhs,
705 AC,&m,
706 xp,&ldb,
707 &ans,&lwork,
708 &res);
709 lwork = ceil(ans.r);
710 //printf("ans = %d\n",lwork);
711 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
712 zgels_ ("N",&m,&n,&nrhs,
713 AC,&m,
714 xp,&ldb,
715 work,&lwork,
716 &res);
717 if(res>0) {
718 return SINGULAR;
719 }
720 CHECK(res,res);
721 free(work);
722 free(AC);
723 OK
724}
725
726//////////////////// least squares real linear system using SVD ////////////
727
728/* Subroutine */ int dgelss_(integer *m, integer *n, integer *nrhs,
729 doublereal *a, integer *lda, doublereal *b, integer *ldb, doublereal *
730 s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork,
731 integer *info);
732
733int linearSolveSVDR_l(double rcond,KDMAT(a),KDMAT(b),DMAT(x)) {
734 integer m = ar;
735 integer n = ac;
736 integer nrhs = bc;
737 integer ldb = xr;
738 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
739 DEBUGMSG("linearSolveSVDR_l");
740 double*AC = (double*)malloc(m*n*sizeof(double));
741 double*S = (double*)malloc(MIN(m,n)*sizeof(double));
742 memcpy(AC,ap,m*n*sizeof(double));
743 if (m>=n) {
744 memcpy(xp,bp,m*nrhs*sizeof(double));
745 } else {
746 int k;
747 for(k = 0; k<nrhs; k++) {
748 memcpy(xp+ldb*k,bp+m*k,m*sizeof(double));
749 }
750 }
751 integer res;
752 integer lwork = -1;
753 integer rank;
754 double ans;
755 dgelss_ (&m,&n,&nrhs,
756 AC,&m,
757 xp,&ldb,
758 S,
759 &rcond,&rank,
760 &ans,&lwork,
761 &res);
762 lwork = ceil(ans);
763 //printf("ans = %d\n",lwork);
764 double * work = (double*)malloc(lwork*sizeof(double));
765 dgelss_ (&m,&n,&nrhs,
766 AC,&m,
767 xp,&ldb,
768 S,
769 &rcond,&rank,
770 work,&lwork,
771 &res);
772 if(res>0) {
773 return NOCONVER;
774 }
775 CHECK(res,res);
776 free(work);
777 free(S);
778 free(AC);
779 OK
780}
781
782//////////////////// least squares complex linear system using SVD ////////////
783
784// not in clapack.h
785
786int zgelss_(integer *m, integer *n, integer *nhrs,
787 doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, doublereal *s,
788 doublereal *rcond, integer* rank,
789 doublecomplex *work, integer* lwork, doublereal* rwork,
790 integer *info);
791
792int linearSolveSVDC_l(double rcond, KCMAT(a),KCMAT(b),CMAT(x)) {
793 integer m = ar;
794 integer n = ac;
795 integer nrhs = bc;
796 integer ldb = xr;
797 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
798 DEBUGMSG("linearSolveSVDC_l");
799 doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
800 double*S = (double*)malloc(MIN(m,n)*sizeof(double));
801 double*RWORK = (double*)malloc(5*MIN(m,n)*sizeof(double));
802 memcpy(AC,ap,m*n*sizeof(doublecomplex));
803 if (m>=n) {
804 memcpy(xp,bp,m*nrhs*sizeof(doublecomplex));
805 } else {
806 int k;
807 for(k = 0; k<nrhs; k++) {
808 memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex));
809 }
810 }
811 integer res;
812 integer lwork = -1;
813 integer rank;
814 doublecomplex ans;
815 zgelss_ (&m,&n,&nrhs,
816 AC,&m,
817 xp,&ldb,
818 S,
819 &rcond,&rank,
820 &ans,&lwork,
821 RWORK,
822 &res);
823 lwork = ceil(ans.r);
824 //printf("ans = %d\n",lwork);
825 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
826 zgelss_ (&m,&n,&nrhs,
827 AC,&m,
828 xp,&ldb,
829 S,
830 &rcond,&rank,
831 work,&lwork,
832 RWORK,
833 &res);
834 if(res>0) {
835 return NOCONVER;
836 }
837 CHECK(res,res);
838 free(work);
839 free(RWORK);
840 free(S);
841 free(AC);
842 OK
843}
844
845//////////////////// Cholesky factorization /////////////////////////
846
847/* Subroutine */ int zpotrf_(char *uplo, integer *n, doublecomplex *a,
848 integer *lda, integer *info);
849
850int chol_l_H(KCMAT(a),CMAT(l)) {
851 integer n = ar;
852 REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE);
853 DEBUGMSG("chol_l_H");
854 memcpy(lp,ap,n*n*sizeof(doublecomplex));
855 char uplo = 'U';
856 integer res;
857 zpotrf_ (&uplo,&n,lp,&n,&res);
858 CHECK(res>0,NODEFPOS);
859 CHECK(res,res);
860 doublecomplex zero = {0.,0.};
861 int r,c;
862 for (r=0; r<lr-1; r++) {
863 for(c=r+1; c<lc; c++) {
864 lp[r*lc+c] = zero;
865 }
866 }
867 OK
868}
869
870
871/* Subroutine */ int dpotrf_(char *uplo, integer *n, doublereal *a, integer *
872 lda, integer *info);
873
874int chol_l_S(KDMAT(a),DMAT(l)) {
875 integer n = ar;
876 REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE);
877 DEBUGMSG("chol_l_S");
878 memcpy(lp,ap,n*n*sizeof(double));
879 char uplo = 'U';
880 integer res;
881 dpotrf_ (&uplo,&n,lp,&n,&res);
882 CHECK(res>0,NODEFPOS);
883 CHECK(res,res);
884 int r,c;
885 for (r=0; r<lr-1; r++) {
886 for(c=r+1; c<lc; c++) {
887 lp[r*lc+c] = 0.;
888 }
889 }
890 OK
891}
892
893//////////////////// QR factorization /////////////////////////
894
895/* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer *
896 lda, doublereal *tau, doublereal *work, integer *info);
897
898int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)) {
899 integer m = ar;
900 integer n = ac;
901 integer mn = MIN(m,n);
902 REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE);
903 DEBUGMSG("qr_l_R");
904 double *WORK = (double*)malloc(n*sizeof(double));
905 CHECK(!WORK,MEM);
906 memcpy(rp,ap,m*n*sizeof(double));
907 integer res;
908 dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res);
909 CHECK(res,res);
910 free(WORK);
911 OK
912}
913
914/* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a,
915 integer *lda, doublecomplex *tau, doublecomplex *work, integer *info);
916
917int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) {
918 integer m = ar;
919 integer n = ac;
920 integer mn = MIN(m,n);
921 REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE);
922 DEBUGMSG("qr_l_C");
923 doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex));
924 CHECK(!WORK,MEM);
925 memcpy(rp,ap,m*n*sizeof(doublecomplex));
926 integer res;
927 zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res);
928 CHECK(res,res);
929 free(WORK);
930 OK
931}
932
933/* Subroutine */ int dorgqr_(integer *m, integer *n, integer *k, doublereal *
934 a, integer *lda, doublereal *tau, doublereal *work, integer *lwork,
935 integer *info);
936
937int c_dorgqr(KDMAT(a), KDVEC(tau), DMAT(r)) {
938 integer m = ar;
939 integer n = MIN(ac,ar);
940 integer k = taun;
941 DEBUGMSG("c_dorgqr");
942 integer lwork = 8*n; // FIXME
943 double *WORK = (double*)malloc(lwork*sizeof(double));
944 CHECK(!WORK,MEM);
945 memcpy(rp,ap,m*k*sizeof(double));
946 integer res;
947 dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res);
948 CHECK(res,res);
949 free(WORK);
950 OK
951}
952
953/* Subroutine */ int zungqr_(integer *m, integer *n, integer *k,
954 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex *
955 work, integer *lwork, integer *info);
956
957int c_zungqr(KCMAT(a), KCVEC(tau), CMAT(r)) {
958 integer m = ar;
959 integer n = MIN(ac,ar);
960 integer k = taun;
961 DEBUGMSG("z_ungqr");
962 integer lwork = 8*n; // FIXME
963 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
964 CHECK(!WORK,MEM);
965 memcpy(rp,ap,m*k*sizeof(doublecomplex));
966 integer res;
967 zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res);
968 CHECK(res,res);
969 free(WORK);
970 OK
971}
972
973
974//////////////////// Hessenberg factorization /////////////////////////
975
976/* Subroutine */ int dgehrd_(integer *n, integer *ilo, integer *ihi,
977 doublereal *a, integer *lda, doublereal *tau, doublereal *work,
978 integer *lwork, integer *info);
979
980int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) {
981 integer m = ar;
982 integer n = ac;
983 integer mn = MIN(m,n);
984 REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE);
985 DEBUGMSG("hess_l_R");
986 integer lwork = 5*n; // fixme
987 double *WORK = (double*)malloc(lwork*sizeof(double));
988 CHECK(!WORK,MEM);
989 memcpy(rp,ap,m*n*sizeof(double));
990 integer res;
991 integer one = 1;
992 dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res);
993 CHECK(res,res);
994 free(WORK);
995 OK
996}
997
998
999/* Subroutine */ int zgehrd_(integer *n, integer *ilo, integer *ihi,
1000 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex *
1001 work, integer *lwork, integer *info);
1002
1003int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) {
1004 integer m = ar;
1005 integer n = ac;
1006 integer mn = MIN(m,n);
1007 REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE);
1008 DEBUGMSG("hess_l_C");
1009 integer lwork = 5*n; // fixme
1010 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1011 CHECK(!WORK,MEM);
1012 memcpy(rp,ap,m*n*sizeof(doublecomplex));
1013 integer res;
1014 integer one = 1;
1015 zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res);
1016 CHECK(res,res);
1017 free(WORK);
1018 OK
1019}
1020
1021//////////////////// Schur factorization /////////////////////////
1022
1023/* Subroutine */ int dgees_(char *jobvs, char *sort, L_fp select, integer *n,
1024 doublereal *a, integer *lda, integer *sdim, doublereal *wr,
1025 doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work,
1026 integer *lwork, logical *bwork, integer *info);
1027
1028int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) {
1029 integer m = ar;
1030 integer n = ac;
1031 REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE);
1032 DEBUGMSG("schur_l_R");
1033 //int k;
1034 //printf("---------------------------\n");
1035 //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n");
1036 //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n");
1037 //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n");
1038 memcpy(sp,ap,n*n*sizeof(double));
1039 integer lwork = 6*n; // fixme
1040 double *WORK = (double*)malloc(lwork*sizeof(double));
1041 double *WR = (double*)malloc(n*sizeof(double));
1042 double *WI = (double*)malloc(n*sizeof(double));
1043 // WR and WI not really required in this call
1044 logical *BWORK = (logical*)malloc(n*sizeof(logical));
1045 integer res;
1046 integer sdim;
1047 dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res);
1048 //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n");
1049 //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n");
1050 //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n");
1051 if(res>0) {
1052 return NOCONVER;
1053 }
1054 CHECK(res,res);
1055 free(WR);
1056 free(WI);
1057 free(BWORK);
1058 free(WORK);
1059 OK
1060}
1061
1062
1063/* Subroutine */ int zgees_(char *jobvs, char *sort, L_fp select, integer *n,
1064 doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w,
1065 doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork,
1066 doublereal *rwork, logical *bwork, integer *info);
1067
1068int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) {
1069 integer m = ar;
1070 integer n = ac;
1071 REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE);
1072 DEBUGMSG("schur_l_C");
1073 memcpy(sp,ap,n*n*sizeof(doublecomplex));
1074 integer lwork = 6*n; // fixme
1075 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1076 doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex));
1077 // W not really required in this call
1078 logical *BWORK = (logical*)malloc(n*sizeof(logical));
1079 double *RWORK = (double*)malloc(n*sizeof(double));
1080 integer res;
1081 integer sdim;
1082 zgees_ ("V","N",NULL,&n,sp,&n,&sdim,W,
1083 up,&n,
1084 WORK,&lwork,RWORK,BWORK,&res);
1085 if(res>0) {
1086 return NOCONVER;
1087 }
1088 CHECK(res,res);
1089 free(W);
1090 free(BWORK);
1091 free(WORK);
1092 OK
1093}
1094
1095//////////////////// LU factorization /////////////////////////
1096
1097/* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer *
1098 lda, integer *ipiv, integer *info);
1099
1100int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) {
1101 integer m = ar;
1102 integer n = ac;
1103 integer mn = MIN(m,n);
1104 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
1105 DEBUGMSG("lu_l_R");
1106 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
1107 memcpy(rp,ap,m*n*sizeof(double));
1108 integer res;
1109 dgetrf_ (&m,&n,rp,&m,auxipiv,&res);
1110 if(res>0) {
1111 res = 0; // fixme
1112 }
1113 CHECK(res,res);
1114 int k;
1115 for (k=0; k<mn; k++) {
1116 ipivp[k] = auxipiv[k];
1117 }
1118 free(auxipiv);
1119 OK
1120}
1121
1122
1123/* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a,
1124 integer *lda, integer *ipiv, integer *info);
1125
1126int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) {
1127 integer m = ar;
1128 integer n = ac;
1129 integer mn = MIN(m,n);
1130 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
1131 DEBUGMSG("lu_l_C");
1132 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
1133 memcpy(rp,ap,m*n*sizeof(doublecomplex));
1134 integer res;
1135 zgetrf_ (&m,&n,rp,&m,auxipiv,&res);
1136 if(res>0) {
1137 res = 0; // fixme
1138 }
1139 CHECK(res,res);
1140 int k;
1141 for (k=0; k<mn; k++) {
1142 ipivp[k] = auxipiv[k];
1143 }
1144 free(auxipiv);
1145 OK
1146}
1147
1148
1149//////////////////// LU substitution /////////////////////////
1150
1151/* Subroutine */ int dgetrs_(char *trans, integer *n, integer *nrhs,
1152 doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer *
1153 ldb, integer *info);
1154
1155int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) {
1156 integer m = ar;
1157 integer n = ac;
1158 integer mrhs = br;
1159 integer nrhs = bc;
1160
1161 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1162 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1163 int k;
1164 for (k=0; k<n; k++) {
1165 auxipiv[k] = (integer)ipivp[k];
1166 }
1167 integer res;
1168 memcpy(xp,bp,mrhs*nrhs*sizeof(double));
1169 dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res);
1170 CHECK(res,res);
1171 free(auxipiv);
1172 OK
1173}
1174
1175
1176/* Subroutine */ int zgetrs_(char *trans, integer *n, integer *nrhs,
1177 doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b,
1178 integer *ldb, integer *info);
1179
1180int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) {
1181 integer m = ar;
1182 integer n = ac;
1183 integer mrhs = br;
1184 integer nrhs = bc;
1185
1186 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1187 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1188 int k;
1189 for (k=0; k<n; k++) {
1190 auxipiv[k] = (integer)ipivp[k];
1191 }
1192 integer res;
1193 memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex));
1194 zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,xp,&mrhs,&res);
1195 CHECK(res,res);
1196 free(auxipiv);
1197 OK
1198}
1199
1200//////////////////// Matrix Product /////////////////////////
1201
1202void dgemm_(char *, char *, integer *, integer *, integer *,
1203 double *, const double *, integer *, const double *,
1204 integer *, double *, double *, integer *);
1205
1206int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) {
1207 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1208 DEBUGMSG("dgemm_");
1209 CHECKNANR(a,"NaN multR Input\n")
1210 CHECKNANR(b,"NaN multR Input\n")
1211 integer m = ta?ac:ar;
1212 integer n = tb?br:bc;
1213 integer k = ta?ar:ac;
1214 integer lda = ar;
1215 integer ldb = br;
1216 integer ldc = rr;
1217 double alpha = 1;
1218 double beta = 0;
1219 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
1220 CHECKNANR(r,"NaN multR Output\n")
1221 OK
1222}
1223
1224void zgemm_(char *, char *, integer *, integer *, integer *,
1225 doublecomplex *, const doublecomplex *, integer *, const doublecomplex *,
1226 integer *, doublecomplex *, doublecomplex *, integer *);
1227
1228int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) {
1229 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1230 DEBUGMSG("zgemm_");
1231 CHECKNANC(a,"NaN multC Input\n")
1232 CHECKNANC(b,"NaN multC Input\n")
1233 integer m = ta?ac:ar;
1234 integer n = tb?br:bc;
1235 integer k = ta?ar:ac;
1236 integer lda = ar;
1237 integer ldb = br;
1238 integer ldc = rr;
1239 doublecomplex alpha = {1,0};
1240 doublecomplex beta = {0,0};
1241 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
1242 ap,&lda,
1243 bp,&ldb,&beta,
1244 rp,&ldc);
1245 CHECKNANC(r,"NaN multC Output\n")
1246 OK
1247}
1248
1249void sgemm_(char *, char *, integer *, integer *, integer *,
1250 float *, const float *, integer *, const float *,
1251 integer *, float *, float *, integer *);
1252
1253int multiplyF(int ta, int tb, KFMAT(a),KFMAT(b),FMAT(r)) {
1254 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1255 DEBUGMSG("sgemm_");
1256 integer m = ta?ac:ar;
1257 integer n = tb?br:bc;
1258 integer k = ta?ar:ac;
1259 integer lda = ar;
1260 integer ldb = br;
1261 integer ldc = rr;
1262 float alpha = 1;
1263 float beta = 0;
1264 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
1265 OK
1266}
1267
1268void cgemm_(char *, char *, integer *, integer *, integer *,
1269 complex *, const complex *, integer *, const complex *,
1270 integer *, complex *, complex *, integer *);
1271
1272int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) {
1273 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1274 DEBUGMSG("cgemm_");
1275 integer m = ta?ac:ar;
1276 integer n = tb?br:bc;
1277 integer k = ta?ar:ac;
1278 integer lda = ar;
1279 integer ldb = br;
1280 integer ldc = rr;
1281 complex alpha = {1,0};
1282 complex beta = {0,0};
1283 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
1284 ap,&lda,
1285 bp,&ldb,&beta,
1286 rp,&ldc);
1287 OK
1288}
1289
1290//////////////////// transpose /////////////////////////
1291
1292int transF(KFMAT(x),FMAT(t)) {
1293 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1294 DEBUGMSG("transF");
1295 int i,j;
1296 for (i=0; i<tr; i++) {
1297 for (j=0; j<tc; j++) {
1298 tp[i*tc+j] = xp[j*xc+i];
1299 }
1300 }
1301 OK
1302}
1303
1304int transR(KDMAT(x),DMAT(t)) {
1305 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1306 DEBUGMSG("transR");
1307 int i,j;
1308 for (i=0; i<tr; i++) {
1309 for (j=0; j<tc; j++) {
1310 tp[i*tc+j] = xp[j*xc+i];
1311 }
1312 }
1313 OK
1314}
1315
1316int transQ(KQMAT(x),QMAT(t)) {
1317 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1318 DEBUGMSG("transQ");
1319 int i,j;
1320 for (i=0; i<tr; i++) {
1321 for (j=0; j<tc; j++) {
1322 tp[i*tc+j] = xp[j*xc+i];
1323 }
1324 }
1325 OK
1326}
1327
1328int transC(KCMAT(x),CMAT(t)) {
1329 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1330 DEBUGMSG("transC");
1331 int i,j;
1332 for (i=0; i<tr; i++) {
1333 for (j=0; j<tc; j++) {
1334 tp[i*tc+j] = xp[j*xc+i];
1335 }
1336 }
1337 OK
1338}
1339
1340int transP(KPMAT(x), PMAT(t)) {
1341 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1342 REQUIRES(xs==ts,NOCONVER);
1343 DEBUGMSG("transP");
1344 int i,j;
1345 for (i=0; i<tr; i++) {
1346 for (j=0; j<tc; j++) {
1347 memcpy(tp+(i*tc+j)*xs,xp +(j*xc+i)*xs,xs);
1348 }
1349 }
1350 OK
1351}
1352
1353//////////////////// constant /////////////////////////
1354
1355int constantF(float * pval, FVEC(r)) {
1356 DEBUGMSG("constantF")
1357 int k;
1358 double val = *pval;
1359 for(k=0;k<rn;k++) {
1360 rp[k]=val;
1361 }
1362 OK
1363}
1364
1365int constantR(double * pval, DVEC(r)) {
1366 DEBUGMSG("constantR")
1367 int k;
1368 double val = *pval;
1369 for(k=0;k<rn;k++) {
1370 rp[k]=val;
1371 }
1372 OK
1373}
1374
1375int constantQ(complex* pval, QVEC(r)) {
1376 DEBUGMSG("constantQ")
1377 int k;
1378 complex val = *pval;
1379 for(k=0;k<rn;k++) {
1380 rp[k]=val;
1381 }
1382 OK
1383}
1384
1385int constantC(doublecomplex* pval, CVEC(r)) {
1386 DEBUGMSG("constantC")
1387 int k;
1388 doublecomplex val = *pval;
1389 for(k=0;k<rn;k++) {
1390 rp[k]=val;
1391 }
1392 OK
1393}
1394
1395int constantP(void* pval, PVEC(r)) {
1396 DEBUGMSG("constantP")
1397 int k;
1398 for(k=0;k<rn;k++) {
1399 memcpy(rp+k*rs,pval,rs);
1400 }
1401 OK
1402}
1403
1404//////////////////// float-double conversion /////////////////////////
1405
1406int float2double(FVEC(x),DVEC(y)) {
1407 DEBUGMSG("float2double")
1408 int k;
1409 for(k=0;k<xn;k++) {
1410 yp[k]=xp[k];
1411 }
1412 OK
1413}
1414
1415int double2float(DVEC(x),FVEC(y)) {
1416 DEBUGMSG("double2float")
1417 int k;
1418 for(k=0;k<xn;k++) {
1419 yp[k]=xp[k];
1420 }
1421 OK
1422}
1423
1424//////////////////// conjugate /////////////////////////
1425
1426int conjugateQ(KQVEC(x),QVEC(t)) {
1427 REQUIRES(xn==tn,BAD_SIZE);
1428 DEBUGMSG("conjugateQ");
1429 int k;
1430 for(k=0;k<xn;k++) {
1431 tp[k].r = xp[k].r;
1432 tp[k].i = -xp[k].i;
1433 }
1434 OK
1435}
1436
1437int conjugateC(KCVEC(x),CVEC(t)) {
1438 REQUIRES(xn==tn,BAD_SIZE);
1439 DEBUGMSG("conjugateC");
1440 int k;
1441 for(k=0;k<xn;k++) {
1442 tp[k].r = xp[k].r;
1443 tp[k].i = -xp[k].i;
1444 }
1445 OK
1446}
1447
1448//////////////////// step /////////////////////////
1449
1450int stepF(FVEC(x),FVEC(y)) {
1451 DEBUGMSG("stepF")
1452 int k;
1453 for(k=0;k<xn;k++) {
1454 yp[k]=xp[k]>0;
1455 }
1456 OK
1457}
1458
1459int stepD(DVEC(x),DVEC(y)) {
1460 DEBUGMSG("stepD")
1461 int k;
1462 for(k=0;k<xn;k++) {
1463 yp[k]=xp[k]>0;
1464 }
1465 OK
1466}
1467
1468//////////////////// cond /////////////////////////
1469
1470int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)) {
1471 REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE);
1472 DEBUGMSG("condF")
1473 int k;
1474 for(k=0;k<xn;k++) {
1475 rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]);
1476 }
1477 OK
1478}
1479
1480int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)) {
1481 REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE);
1482 DEBUGMSG("condD")
1483 int k;
1484 for(k=0;k<xn;k++) {
1485 rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]);
1486 }
1487 OK
1488}
1489
diff --git a/packages/base/src/C/lapack-aux.h b/packages/base/src/C/lapack-aux.h
new file mode 100644
index 0000000..a3f1899
--- /dev/null
+++ b/packages/base/src/C/lapack-aux.h
@@ -0,0 +1,60 @@
1/*
2 * We have copied the definitions in f2c.h required
3 * to compile clapack.h, modified to support both
4 * 32 and 64 bit
5
6 http://opengrok.creo.hu/dragonfly/xref/src/contrib/gcc-3.4/libf2c/readme.netlib
7 http://www.ibm.com/developerworks/library/l-port64.html
8 */
9
10#ifdef _LP64
11typedef int integer;
12typedef unsigned int uinteger;
13typedef int logical;
14typedef long longint; /* system-dependent */
15typedef unsigned long ulongint; /* system-dependent */
16#else
17typedef long int integer;
18typedef unsigned long int uinteger;
19typedef long int logical;
20typedef long long longint; /* system-dependent */
21typedef unsigned long long ulongint; /* system-dependent */
22#endif
23
24typedef char *address;
25typedef short int shortint;
26typedef float real;
27typedef double doublereal;
28typedef struct { real r, i; } complex;
29typedef struct { doublereal r, i; } doublecomplex;
30typedef short int shortlogical;
31typedef char logical1;
32typedef char integer1;
33
34typedef logical (*L_fp)();
35typedef short ftnlen;
36
37/********************************************************/
38
39#define FVEC(A) int A##n, float*A##p
40#define DVEC(A) int A##n, double*A##p
41#define QVEC(A) int A##n, complex*A##p
42#define CVEC(A) int A##n, doublecomplex*A##p
43#define PVEC(A) int A##n, void* A##p, int A##s
44#define FMAT(A) int A##r, int A##c, float* A##p
45#define DMAT(A) int A##r, int A##c, double* A##p
46#define QMAT(A) int A##r, int A##c, complex* A##p
47#define CMAT(A) int A##r, int A##c, doublecomplex* A##p
48#define PMAT(A) int A##r, int A##c, void* A##p, int A##s
49
50#define KFVEC(A) int A##n, const float*A##p
51#define KDVEC(A) int A##n, const double*A##p
52#define KQVEC(A) int A##n, const complex*A##p
53#define KCVEC(A) int A##n, const doublecomplex*A##p
54#define KPVEC(A) int A##n, const void* A##p, int A##s
55#define KFMAT(A) int A##r, int A##c, const float* A##p
56#define KDMAT(A) int A##r, int A##c, const double* A##p
57#define KQMAT(A) int A##r, int A##c, const complex* A##p
58#define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p
59#define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s
60
diff --git a/packages/base/src/Data/Packed.hs b/packages/base/src/Data/Packed.hs
new file mode 100644
index 0000000..c66718a
--- /dev/null
+++ b/packages/base/src/Data/Packed.hs
@@ -0,0 +1,25 @@
1-----------------------------------------------------------------------------
2{- |
3Module : Data.Packed
4Copyright : (c) Alberto Ruiz 2006-2014
5License : BSD3
6Maintainer : Alberto Ruiz
7Stability : provisional
8
9Types for dense 'Vector' and 'Matrix' of 'Storable' elements.
10
11-}
12-----------------------------------------------------------------------------
13
14module Data.Packed (
15 -- * Vector
16 --
17 -- | Vectors are @Data.Vector.Storable.Vector@ from the \"vector\" package.
18 module Data.Packed.Vector,
19 -- * Matrix
20 module Data.Packed.Matrix,
21) where
22
23import Data.Packed.Vector
24import Data.Packed.Matrix
25
diff --git a/packages/base/src/Data/Packed/Development.hs b/packages/base/src/Data/Packed/Development.hs
new file mode 100644
index 0000000..777b6c5
--- /dev/null
+++ b/packages/base/src/Data/Packed/Development.hs
@@ -0,0 +1,31 @@
1
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Development
5-- Copyright : (c) Alberto Ruiz 2009
6-- License : BSD3
7-- Maintainer : Alberto Ruiz
8-- Stability : provisional
9-- Portability : portable
10--
11-- The library can be easily extended with additional foreign functions
12-- using the tools in this module. Illustrative usage examples can be found
13-- in the @examples\/devel@ folder included in the package.
14--
15-----------------------------------------------------------------------------
16
17module Data.Packed.Development (
18 createVector, createMatrix,
19 vec, mat,
20 app1, app2, app3, app4,
21 app5, app6, app7, app8, app9, app10,
22 MatrixOrder(..), orderOf, cmat, fmat,
23 matrixFromVector,
24 unsafeFromForeignPtr,
25 unsafeToForeignPtr,
26 check, (//),
27 at', atM'
28) where
29
30import Data.Packed.Internal
31
diff --git a/packages/base/src/Data/Packed/Foreign.hs b/packages/base/src/Data/Packed/Foreign.hs
new file mode 100644
index 0000000..efa51ca
--- /dev/null
+++ b/packages/base/src/Data/Packed/Foreign.hs
@@ -0,0 +1,99 @@
1{-# LANGUAGE MagicHash, UnboxedTuples #-}
2-- | FFI and hmatrix helpers.
3--
4-- Sample usage, to upload a perspective matrix to a shader.
5--
6-- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3)
7-- @
8--
9module Data.Packed.Foreign
10 ( app
11 , appVector, appVectorLen
12 , appMatrix, appMatrixLen, appMatrixRaw, appMatrixRawLen
13 , unsafeMatrixToVector, unsafeMatrixToForeignPtr
14 ) where
15import Data.Packed.Internal
16import qualified Data.Vector.Storable as S
17import Foreign (Ptr, ForeignPtr, Storable)
18import Foreign.C.Types (CInt)
19import GHC.Base (IO(..), realWorld#)
20
21{-# INLINE unsafeInlinePerformIO #-}
22-- | If we use unsafePerformIO, it may not get inlined, so in a function that returns IO (which are all safe uses of app* in this module), there would be
23-- unecessary calls to unsafePerformIO or its internals.
24unsafeInlinePerformIO :: IO a -> a
25unsafeInlinePerformIO (IO f) = case f realWorld# of
26 (# _, x #) -> x
27
28{-# INLINE app #-}
29-- | Only useful since it is left associated with a precedence of 1, unlike 'Prelude.$', which is right associative.
30-- e.g.
31--
32-- @
33-- someFunction
34-- \`appMatrixLen\` m
35-- \`appVectorLen\` v
36-- \`app\` other
37-- \`app\` arguments
38-- \`app\` go here
39-- @
40--
41-- One could also write:
42--
43-- @
44-- (someFunction
45-- \`appMatrixLen\` m
46-- \`appVectorLen\` v)
47-- other
48-- arguments
49-- (go here)
50-- @
51--
52app :: (a -> b) -> a -> b
53app f = f
54
55{-# INLINE appVector #-}
56appVector :: Storable a => (Ptr a -> b) -> Vector a -> b
57appVector f x = unsafeInlinePerformIO (S.unsafeWith x (return . f))
58
59{-# INLINE appVectorLen #-}
60appVectorLen :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b
61appVectorLen f x = unsafeInlinePerformIO (S.unsafeWith x (return . f (fromIntegral (S.length x))))
62
63{-# INLINE appMatrix #-}
64appMatrix :: Element a => (Ptr a -> b) -> Matrix a -> b
65appMatrix f x = unsafeInlinePerformIO (S.unsafeWith (flatten x) (return . f))
66
67{-# INLINE appMatrixLen #-}
68appMatrixLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
69appMatrixLen f x = unsafeInlinePerformIO (S.unsafeWith (flatten x) (return . f r c))
70 where
71 r = fromIntegral (rows x)
72 c = fromIntegral (cols x)
73
74{-# INLINE appMatrixRaw #-}
75appMatrixRaw :: Storable a => (Ptr a -> b) -> Matrix a -> b
76appMatrixRaw f x = unsafeInlinePerformIO (S.unsafeWith (xdat x) (return . f))
77
78{-# INLINE appMatrixRawLen #-}
79appMatrixRawLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
80appMatrixRawLen f x = unsafeInlinePerformIO (S.unsafeWith (xdat x) (return . f r c))
81 where
82 r = fromIntegral (rows x)
83 c = fromIntegral (cols x)
84
85infixl 1 `app`
86infixl 1 `appVector`
87infixl 1 `appMatrix`
88infixl 1 `appMatrixRaw`
89
90{-# INLINE unsafeMatrixToVector #-}
91-- | This will disregard the order of the matrix, and simply return it as-is.
92-- If the order of the matrix is RowMajor, this function is identical to 'flatten'.
93unsafeMatrixToVector :: Matrix a -> Vector a
94unsafeMatrixToVector = xdat
95
96{-# INLINE unsafeMatrixToForeignPtr #-}
97unsafeMatrixToForeignPtr :: Storable a => Matrix a -> (ForeignPtr a, Int)
98unsafeMatrixToForeignPtr m = S.unsafeToForeignPtr0 (xdat m)
99
diff --git a/packages/base/src/Data/Packed/Internal.hs b/packages/base/src/Data/Packed/Internal.hs
new file mode 100644
index 0000000..537e51e
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal.hs
@@ -0,0 +1,26 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Data.Packed.Internal
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable
10--
11-- Reexports all internal modules
12--
13-----------------------------------------------------------------------------
14-- #hide
15
16module Data.Packed.Internal (
17 module Data.Packed.Internal.Common,
18 module Data.Packed.Internal.Signatures,
19 module Data.Packed.Internal.Vector,
20 module Data.Packed.Internal.Matrix,
21) where
22
23import Data.Packed.Internal.Common
24import Data.Packed.Internal.Signatures
25import Data.Packed.Internal.Vector
26import Data.Packed.Internal.Matrix
diff --git a/packages/base/src/Data/Packed/Internal/Common.hs b/packages/base/src/Data/Packed/Internal/Common.hs
new file mode 100644
index 0000000..615bbdf
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Common.hs
@@ -0,0 +1,160 @@
1{-# LANGUAGE CPP #-}
2-- |
3-- Module : Data.Packed.Internal.Common
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : BSD3
6-- Maintainer : Alberto Ruiz
7-- Stability : provisional
8--
9--
10-- Development utilities.
11--
12
13
14module Data.Packed.Internal.Common(
15 Adapt,
16 app1, app2, app3, app4,
17 app5, app6, app7, app8, app9, app10,
18 (//), check, mbCatch,
19 splitEvery, common, compatdim,
20 fi,
21 table,
22 finit
23) where
24
25import Control.Monad(when)
26import Foreign.C.Types
27import Foreign.Storable.Complex()
28import Data.List(transpose,intersperse)
29import Control.Exception as E
30
31-- | @splitEvery 3 [1..9] == [[1,2,3],[4,5,6],[7,8,9]]@
32splitEvery :: Int -> [a] -> [[a]]
33splitEvery _ [] = []
34splitEvery k l = take k l : splitEvery k (drop k l)
35
36-- | obtains the common value of a property of a list
37common :: (Eq a) => (b->a) -> [b] -> Maybe a
38common f = commonval . map f where
39 commonval :: (Eq a) => [a] -> Maybe a
40 commonval [] = Nothing
41 commonval [a] = Just a
42 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
43
44-- | common value with \"adaptable\" 1
45compatdim :: [Int] -> Maybe Int
46compatdim [] = Nothing
47compatdim [a] = Just a
48compatdim (a:b:xs)
49 | a==b = compatdim (b:xs)
50 | a==1 = compatdim (b:xs)
51 | b==1 = compatdim (a:xs)
52 | otherwise = Nothing
53
54-- | Formatting tool
55table :: String -> [[String]] -> String
56table sep as = unlines . map unwords' $ transpose mtp where
57 mt = transpose as
58 longs = map (maximum . map length) mt
59 mtp = zipWith (\a b -> map (pad a) b) longs mt
60 pad n str = replicate (n - length str) ' ' ++ str
61 unwords' = concat . intersperse sep
62
63-- | postfix function application (@flip ($)@)
64(//) :: x -> (x -> y) -> y
65infixl 0 //
66(//) = flip ($)
67
68-- | specialized fromIntegral
69fi :: Int -> CInt
70fi = fromIntegral
71
72-- hmm..
73ww2 w1 o1 w2 o2 f = w1 o1 $ w2 o2 . f
74ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ ww2 w2 o2 w3 o3 . f
75ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ ww3 w2 o2 w3 o3 w4 o4 . f
76ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 f = w1 o1 $ ww4 w2 o2 w3 o3 w4 o4 w5 o5 . f
77ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 f = w1 o1 $ ww5 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 . f
78ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 f = w1 o1 $ ww6 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 . f
79ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 f = w1 o1 $ ww7 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 . f
80ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 f = w1 o1 $ ww8 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 . f
81ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 f = w1 o1 $ ww9 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 . f
82
83type Adapt f t r = t -> ((f -> r) -> IO()) -> IO()
84
85type Adapt1 f t1 = Adapt f t1 (IO CInt) -> t1 -> String -> IO()
86type Adapt2 f t1 r1 t2 = Adapt f t1 r1 -> t1 -> Adapt1 r1 t2
87type Adapt3 f t1 r1 t2 r2 t3 = Adapt f t1 r1 -> t1 -> Adapt2 r1 t2 r2 t3
88type Adapt4 f t1 r1 t2 r2 t3 r3 t4 = Adapt f t1 r1 -> t1 -> Adapt3 r1 t2 r2 t3 r3 t4
89type Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 = Adapt f t1 r1 -> t1 -> Adapt4 r1 t2 r2 t3 r3 t4 r4 t5
90type Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 = Adapt f t1 r1 -> t1 -> Adapt5 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
91type Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 = Adapt f t1 r1 -> t1 -> Adapt6 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
92type Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 = Adapt f t1 r1 -> t1 -> Adapt7 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
93type Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 = Adapt f t1 r1 -> t1 -> Adapt8 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
94type Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 = Adapt f t1 r1 -> t1 -> Adapt9 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10
95
96app1 :: f -> Adapt1 f t1
97app2 :: f -> Adapt2 f t1 r1 t2
98app3 :: f -> Adapt3 f t1 r1 t2 r2 t3
99app4 :: f -> Adapt4 f t1 r1 t2 r2 t3 r3 t4
100app5 :: f -> Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5
101app6 :: f -> Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
102app7 :: f -> Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
103app8 :: f -> Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
104app9 :: f -> Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
105app10 :: f -> Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10
106
107app1 f w1 o1 s = w1 o1 $ \a1 -> f // a1 // check s
108app2 f w1 o1 w2 o2 s = ww2 w1 o1 w2 o2 $ \a1 a2 -> f // a1 // a2 // check s
109app3 f w1 o1 w2 o2 w3 o3 s = ww3 w1 o1 w2 o2 w3 o3 $
110 \a1 a2 a3 -> f // a1 // a2 // a3 // check s
111app4 f w1 o1 w2 o2 w3 o3 w4 o4 s = ww4 w1 o1 w2 o2 w3 o3 w4 o4 $
112 \a1 a2 a3 a4 -> f // a1 // a2 // a3 // a4 // check s
113app5 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 s = ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 $
114 \a1 a2 a3 a4 a5 -> f // a1 // a2 // a3 // a4 // a5 // check s
115app6 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 s = ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 $
116 \a1 a2 a3 a4 a5 a6 -> f // a1 // a2 // a3 // a4 // a5 // a6 // check s
117app7 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 s = ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 $
118 \a1 a2 a3 a4 a5 a6 a7 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // check s
119app8 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 s = ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 $
120 \a1 a2 a3 a4 a5 a6 a7 a8 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // check s
121app9 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 s = ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 $
122 \a1 a2 a3 a4 a5 a6 a7 a8 a9 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // check s
123app10 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 s = ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 $
124 \a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // a10 // check s
125
126
127
128-- GSL error codes are <= 1024
129-- | error codes for the auxiliary functions required by the wrappers
130errorCode :: CInt -> String
131errorCode 2000 = "bad size"
132errorCode 2001 = "bad function code"
133errorCode 2002 = "memory problem"
134errorCode 2003 = "bad file"
135errorCode 2004 = "singular"
136errorCode 2005 = "didn't converge"
137errorCode 2006 = "the input matrix is not positive definite"
138errorCode 2007 = "not yet supported in this OS"
139errorCode n = "code "++show n
140
141
142-- | clear the fpu
143foreign import ccall unsafe "asm_finit" finit :: IO ()
144
145-- | check the error code
146check :: String -> IO CInt -> IO ()
147check msg f = do
148#if FINIT
149 finit
150#endif
151 err <- f
152 when (err/=0) $ error (msg++": "++errorCode err)
153 return ()
154
155-- | Error capture and conversion to Maybe
156mbCatch :: IO x -> IO (Maybe x)
157mbCatch act = E.catch (Just `fmap` act) f
158 where f :: SomeException -> IO (Maybe x)
159 f _ = return Nothing
160
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
new file mode 100644
index 0000000..9b831cc
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -0,0 +1,422 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-}
5
6-- |
7-- Module : Data.Packed.Internal.Matrix
8-- Copyright : (c) Alberto Ruiz 2007
9-- License : BSD3
10-- Maintainer : Alberto Ruiz
11-- Stability : provisional
12--
13-- Internal matrix representation
14--
15
16module Data.Packed.Internal.Matrix(
17 Matrix(..), rows, cols, cdat, fdat,
18 MatrixOrder(..), orderOf,
19 createMatrix, mat,
20 cmat, fmat,
21 toLists, flatten, reshape,
22 Element(..),
23 trans,
24 fromRows, toRows, fromColumns, toColumns,
25 matrixFromVector,
26 subMatrix,
27 liftMatrix, liftMatrix2,
28 (@@>), atM',
29 singleton,
30 emptyM,
31 size, shSize, conformVs, conformMs, conformVTo, conformMTo
32) where
33
34import Data.Packed.Internal.Common
35import Data.Packed.Internal.Signatures
36import Data.Packed.Internal.Vector
37
38import Foreign.Marshal.Alloc(alloca, free)
39import Foreign.Marshal.Array(newArray)
40import Foreign.Ptr(Ptr, castPtr)
41import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
42import Data.Complex(Complex)
43import Foreign.C.Types
44import System.IO.Unsafe(unsafePerformIO)
45import Control.DeepSeq
46
47-----------------------------------------------------------------
48
49{- Design considerations for the Matrix Type
50 -----------------------------------------
51
52- we must easily handle both row major and column major order,
53 for bindings to LAPACK and GSL/C
54
55- we'd like to simplify redundant matrix transposes:
56 - Some of them arise from the order requirements of some functions
57 - some functions (matrix product) admit transposed arguments
58
59- maybe we don't really need this kind of simplification:
60 - more complex code
61 - some computational overhead
62 - only appreciable gain in code with a lot of redundant transpositions
63 and cheap matrix computations
64
65- we could carry both the matrix and its (lazily computed) transpose.
66 This may save some transpositions, but it is necessary to keep track of the
67 data which is actually computed to be used by functions like the matrix product
68 which admit both orders.
69
70- but if we need the transposed data and it is not in the structure, we must make
71 sure that we touch the same foreignptr that is used in the computation.
72
73- a reasonable solution is using two constructors for a matrix. Transposition just
74 "flips" the constructor. Actual data transposition is not done if followed by a
75 matrix product or another transpose.
76
77-}
78
79data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
80
81transOrder RowMajor = ColumnMajor
82transOrder ColumnMajor = RowMajor
83{- | Matrix representation suitable for GSL and LAPACK computations.
84
85The elements are stored in a continuous memory array.
86
87-}
88
89data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
90 , icols :: {-# UNPACK #-} !Int
91 , xdat :: {-# UNPACK #-} !(Vector t)
92 , order :: !MatrixOrder }
93-- RowMajor: preferred by C, fdat may require a transposition
94-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
95
96cdat = xdat
97fdat = xdat
98
99rows :: Matrix t -> Int
100rows = irows
101
102cols :: Matrix t -> Int
103cols = icols
104
105orderOf :: Matrix t -> MatrixOrder
106orderOf = order
107
108
109-- | Matrix transpose.
110trans :: Matrix t -> Matrix t
111trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
112
113cmat :: (Element t) => Matrix t -> Matrix t
114cmat m@Matrix{order = RowMajor} = m
115cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
116
117fmat :: (Element t) => Matrix t -> Matrix t
118fmat m@Matrix{order = ColumnMajor} = m
119fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
120
121-- C-Haskell matrix adapter
122-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
123
124mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
125mat a f =
126 unsafeWith (xdat a) $ \p -> do
127 let m g = do
128 g (fi (rows a)) (fi (cols a)) p
129 f m
130
131{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
132
133>>> flatten (ident 3)
134fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
135
136-}
137flatten :: Element t => Matrix t -> Vector t
138flatten = xdat . cmat
139
140{-
141type Mt t s = Int -> Int -> Ptr t -> s
142
143infixr 6 ::>
144type t ::> s = Mt t s
145-}
146
147-- | the inverse of 'Data.Packed.Matrix.fromLists'
148toLists :: (Element t) => Matrix t -> [[t]]
149toLists m = splitEvery (cols m) . toList . flatten $ m
150
151-- | Create a matrix from a list of vectors.
152-- All vectors must have the same dimension,
153-- or dimension 1, which is are automatically expanded.
154fromRows :: Element t => [Vector t] -> Matrix t
155fromRows [] = emptyM 0 0
156fromRows vs = case compatdim (map dim vs) of
157 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
158 Just 0 -> emptyM r 0
159 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
160 where
161 r = length vs
162 adapt c v
163 | c == 0 = fromList[]
164 | dim v == c = v
165 | otherwise = constantD (v@>0) c
166
167-- | extracts the rows of a matrix as a list of vectors
168toRows :: Element t => Matrix t -> [Vector t]
169toRows m
170 | c == 0 = replicate r (fromList[])
171 | otherwise = toRows' 0
172 where
173 v = flatten m
174 r = rows m
175 c = cols m
176 toRows' k | k == r*c = []
177 | otherwise = subVector k c v : toRows' (k+c)
178
179-- | Creates a matrix from a list of vectors, as columns
180fromColumns :: Element t => [Vector t] -> Matrix t
181fromColumns m = trans . fromRows $ m
182
183-- | Creates a list of vectors from the columns of a matrix
184toColumns :: Element t => Matrix t -> [Vector t]
185toColumns m = toRows . trans $ m
186
187-- | Reads a matrix position.
188(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
189infixl 9 @@>
190m@Matrix {irows = r, icols = c} @@> (i,j)
191 | safe = if i<0 || i>=r || j<0 || j>=c
192 then error "matrix indexing out of range"
193 else atM' m i j
194 | otherwise = atM' m i j
195{-# INLINE (@@>) #-}
196
197-- Unsafe matrix access without range checking
198atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
199atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
200{-# INLINE atM' #-}
201
202------------------------------------------------------------------
203
204matrixFromVector o r c v
205 | r * c == dim v = m
206 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
207 where
208 m = Matrix { irows = r, icols = c, xdat = v, order = o }
209
210-- allocates memory for a new matrix
211createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
212createMatrix ord r c = do
213 p <- createVector (r*c)
214 return (matrixFromVector ord r c p)
215
216{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
217where r is the desired number of rows.)
218
219>>> reshape 4 (fromList [1..12])
220(3><4)
221 [ 1.0, 2.0, 3.0, 4.0
222 , 5.0, 6.0, 7.0, 8.0
223 , 9.0, 10.0, 11.0, 12.0 ]
224
225-}
226reshape :: Storable t => Int -> Vector t -> Matrix t
227reshape 0 v = matrixFromVector RowMajor 0 0 v
228reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
229
230singleton x = reshape 1 (fromList [x])
231
232-- | application of a vector function on the flattened matrix elements
233liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
234liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
235
236-- | application of a vector function on the flattened matrices elements
237liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
238liftMatrix2 f m1 m2
239 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
240 | otherwise = case orderOf m1 of
241 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
242 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
243
244
245compat :: Matrix a -> Matrix b -> Bool
246compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
247
248------------------------------------------------------------------
249
250{- | Supported matrix elements.
251
252 This class provides optimized internal
253 operations for selected element types.
254 It provides unoptimised defaults for any 'Storable' type,
255 so you can create instances simply as:
256 @instance Element Foo@.
257-}
258class (Storable a) => Element a where
259 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
260 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
261 -> Matrix a -> Matrix a
262 subMatrixD = subMatrix'
263 transdata :: Int -> Vector a -> Int -> Vector a
264 transdata = transdataP -- transdata'
265 constantD :: a -> Int -> Vector a
266 constantD = constantP -- constant'
267
268
269instance Element Float where
270 transdata = transdataAux ctransF
271 constantD = constantAux cconstantF
272
273instance Element Double where
274 transdata = transdataAux ctransR
275 constantD = constantAux cconstantR
276
277instance Element (Complex Float) where
278 transdata = transdataAux ctransQ
279 constantD = constantAux cconstantQ
280
281instance Element (Complex Double) where
282 transdata = transdataAux ctransC
283 constantD = constantAux cconstantC
284
285-------------------------------------------------------------------
286
287transdataAux fun c1 d c2 =
288 if noneed
289 then d
290 else unsafePerformIO $ do
291 v <- createVector (dim d)
292 unsafeWith d $ \pd ->
293 unsafeWith v $ \pv ->
294 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
295 return v
296 where r1 = dim d `div` c1
297 r2 = dim d `div` c2
298 noneed = dim d == 0 || r1 == 1 || c1 == 1
299
300transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
301transdataP c1 d c2 =
302 if noneed
303 then d
304 else unsafePerformIO $ do
305 v <- createVector (dim d)
306 unsafeWith d $ \pd ->
307 unsafeWith v $ \pv ->
308 ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP"
309 return v
310 where r1 = dim d `div` c1
311 r2 = dim d `div` c2
312 sz = sizeOf (d @> 0)
313 noneed = dim d == 0 || r1 == 1 || c1 == 1
314
315foreign import ccall unsafe "transF" ctransF :: TFMFM
316foreign import ccall unsafe "transR" ctransR :: TMM
317foreign import ccall unsafe "transQ" ctransQ :: TQMQM
318foreign import ccall unsafe "transC" ctransC :: TCMCM
319foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
320
321----------------------------------------------------------------------
322
323constantAux fun x n = unsafePerformIO $ do
324 v <- createVector n
325 px <- newArray [x]
326 app1 (fun px) vec v "constantAux"
327 free px
328 return v
329
330foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF
331
332foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV
333
334foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
335
336foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
337
338constantP :: Storable a => a -> Int -> Vector a
339constantP a n = unsafePerformIO $ do
340 let sz = sizeOf a
341 v <- createVector n
342 unsafeWith v $ \p -> do
343 alloca $ \k -> do
344 poke k a
345 cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP"
346 return v
347foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
348
349----------------------------------------------------------------------
350
351-- | Extracts a submatrix from a matrix.
352subMatrix :: Element a
353 => (Int,Int) -- ^ (r0,c0) starting position
354 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
355 -> Matrix a -- ^ input matrix
356 -> Matrix a -- ^ result
357subMatrix (r0,c0) (rt,ct) m
358 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
359 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
360 | otherwise = error $ "wrong subMatrix "++
361 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
362
363subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
364 w <- createVector (rt*ct)
365 unsafeWith v $ \p ->
366 unsafeWith w $ \q -> do
367 let go (-1) _ = return ()
368 go !i (-1) = go (i-1) (ct-1)
369 go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0)
370 pokeElemOff q (i*ct+j) x
371 go i (j-1)
372 go (rt-1) (ct-1)
373 return w
374
375subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
376subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
377
378--------------------------------------------------------------------------
379
380maxZ xs = if minimum xs == 0 then 0 else maximum xs
381
382conformMs ms = map (conformMTo (r,c)) ms
383 where
384 r = maxZ (map rows ms)
385 c = maxZ (map cols ms)
386
387
388conformVs vs = map (conformVTo n) vs
389 where
390 n = maxZ (map dim vs)
391
392conformMTo (r,c) m
393 | size m == (r,c) = m
394 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
395 | size m == (r,1) = repCols c m
396 | size m == (1,c) = repRows r m
397 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
398
399conformVTo n v
400 | dim v == n = v
401 | dim v == 1 = constantD (v@>0) n
402 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
403
404repRows n x = fromRows (replicate n (flatten x))
405repCols n x = fromColumns (replicate n (flatten x))
406
407size m = (rows m, cols m)
408
409shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
410
411emptyM r c = matrixFromVector RowMajor r c (fromList[])
412
413----------------------------------------------------------------------
414
415instance (Storable t, NFData t) => NFData (Matrix t)
416 where
417 rnf m | d > 0 = rnf (v @> 0)
418 | otherwise = ()
419 where
420 d = dim v
421 v = xdat m
422
diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs
new file mode 100644
index 0000000..acc3070
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Signatures.hs
@@ -0,0 +1,70 @@
1-- |
2-- Module : Data.Packed.Internal.Signatures
3-- Copyright : (c) Alberto Ruiz 2009
4-- License : BSD3
5-- Maintainer : Alberto Ruiz
6-- Stability : provisional
7--
8-- Signatures of the C functions.
9--
10
11
12module Data.Packed.Internal.Signatures where
13
14import Foreign.Ptr(Ptr)
15import Data.Complex(Complex)
16import Foreign.C.Types(CInt)
17
18type PF = Ptr Float --
19type PD = Ptr Double --
20type PQ = Ptr (Complex Float) --
21type PC = Ptr (Complex Double) --
22type TF = CInt -> PF -> IO CInt --
23type TFF = CInt -> PF -> TF --
24type TFV = CInt -> PF -> TV --
25type TVF = CInt -> PD -> TF --
26type TFFF = CInt -> PF -> TFF --
27type TV = CInt -> PD -> IO CInt --
28type TVV = CInt -> PD -> TV --
29type TVVV = CInt -> PD -> TVV --
30type TFM = CInt -> CInt -> PF -> IO CInt --
31type TFMFM = CInt -> CInt -> PF -> TFM --
32type TFMFMFM = CInt -> CInt -> PF -> TFMFM --
33type TM = CInt -> CInt -> PD -> IO CInt --
34type TMM = CInt -> CInt -> PD -> TM --
35type TVMM = CInt -> PD -> TMM --
36type TMVMM = CInt -> CInt -> PD -> TVMM --
37type TMMM = CInt -> CInt -> PD -> TMM --
38type TVM = CInt -> PD -> TM --
39type TVVM = CInt -> PD -> TVM --
40type TMV = CInt -> CInt -> PD -> TV --
41type TMMV = CInt -> CInt -> PD -> TMV --
42type TMVM = CInt -> CInt -> PD -> TVM --
43type TMMVM = CInt -> CInt -> PD -> TMVM --
44type TCM = CInt -> CInt -> PC -> IO CInt --
45type TCVCM = CInt -> PC -> TCM --
46type TCMCVCM = CInt -> CInt -> PC -> TCVCM --
47type TMCMCVCM = CInt -> CInt -> PD -> TCMCVCM --
48type TCMCMCVCM = CInt -> CInt -> PC -> TCMCVCM --
49type TCMCM = CInt -> CInt -> PC -> TCM --
50type TVCM = CInt -> PD -> TCM --
51type TCMVCM = CInt -> CInt -> PC -> TVCM --
52type TCMCMVCM = CInt -> CInt -> PC -> TCMVCM --
53type TCMCMCM = CInt -> CInt -> PC -> TCMCM --
54type TCV = CInt -> PC -> IO CInt --
55type TCVCV = CInt -> PC -> TCV --
56type TCVCVCV = CInt -> PC -> TCVCV --
57type TCVV = CInt -> PC -> TV --
58type TQV = CInt -> PQ -> IO CInt --
59type TQVQV = CInt -> PQ -> TQV --
60type TQVQVQV = CInt -> PQ -> TQVQV --
61type TQVF = CInt -> PQ -> TF --
62type TQM = CInt -> CInt -> PQ -> IO CInt --
63type TQMQM = CInt -> CInt -> PQ -> TQM --
64type TQMQMQM = CInt -> CInt -> PQ -> TQMQM --
65type TCMCV = CInt -> CInt -> PC -> TCV --
66type TVCV = CInt -> PD -> TCV --
67type TCVM = CInt -> PC -> TM --
68type TMCVM = CInt -> CInt -> PD -> TCVM --
69type TMMCVM = CInt -> CInt -> PD -> TMCVM --
70
diff --git a/packages/base/src/Data/Packed/Internal/Vector.hs b/packages/base/src/Data/Packed/Internal/Vector.hs
new file mode 100644
index 0000000..d0bc143
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Vector.hs
@@ -0,0 +1,471 @@
1{-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-}
2-- |
3-- Module : Data.Packed.Internal.Vector
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : BSD3
6-- Maintainer : Alberto Ruiz
7-- Stability : provisional
8--
9-- Vector implementation
10--
11--------------------------------------------------------------------------------
12
13module Data.Packed.Internal.Vector (
14 Vector, dim,
15 fromList, toList, (|>),
16 vjoin, (@>), safe, at, at', subVector, takesV,
17 mapVector, mapVectorWithIndex, zipVectorWith, unzipVectorWith,
18 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
19 foldVector, foldVectorG, foldLoop, foldVectorWithIndex,
20 createVector, vec,
21 asComplex, asReal, float2DoubleV, double2FloatV,
22 stepF, stepD, condF, condD,
23 conjugateQ, conjugateC,
24 cloneVector,
25 unsafeToForeignPtr,
26 unsafeFromForeignPtr,
27 unsafeWith
28) where
29
30import Data.Packed.Internal.Common
31import Data.Packed.Internal.Signatures
32import Foreign.Marshal.Array(peekArray, copyArray, advancePtr)
33import Foreign.ForeignPtr(ForeignPtr, castForeignPtr)
34import Foreign.Ptr(Ptr)
35import Foreign.Storable(Storable, peekElemOff, pokeElemOff, sizeOf)
36import Foreign.C.Types
37import Data.Complex
38import Control.Monad(when)
39import System.IO.Unsafe(unsafePerformIO)
40
41#if __GLASGOW_HASKELL__ >= 605
42import GHC.ForeignPtr (mallocPlainForeignPtrBytes)
43#else
44import Foreign.ForeignPtr (mallocForeignPtrBytes)
45#endif
46
47import GHC.Base
48#if __GLASGOW_HASKELL__ < 612
49import GHC.IOBase hiding (liftIO)
50#endif
51
52import qualified Data.Vector.Storable as Vector
53import Data.Vector.Storable(Vector,
54 fromList,
55 unsafeToForeignPtr,
56 unsafeFromForeignPtr,
57 unsafeWith)
58
59
60-- | Number of elements
61dim :: (Storable t) => Vector t -> Int
62dim = Vector.length
63
64
65-- C-Haskell vector adapter
66-- vec :: Adapt (CInt -> Ptr t -> r) (Vector t) r
67vec :: (Storable t) => Vector t -> (((CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
68vec x f = unsafeWith x $ \p -> do
69 let v g = do
70 g (fi $ dim x) p
71 f v
72{-# INLINE vec #-}
73
74
75-- allocates memory for a new vector
76createVector :: Storable a => Int -> IO (Vector a)
77createVector n = do
78 when (n < 0) $ error ("trying to createVector of negative dim: "++show n)
79 fp <- doMalloc undefined
80 return $ unsafeFromForeignPtr fp 0 n
81 where
82 --
83 -- Use the much cheaper Haskell heap allocated storage
84 -- for foreign pointer space we control
85 --
86 doMalloc :: Storable b => b -> IO (ForeignPtr b)
87 doMalloc dummy = do
88#if __GLASGOW_HASKELL__ >= 605
89 mallocPlainForeignPtrBytes (n * sizeOf dummy)
90#else
91 mallocForeignPtrBytes (n * sizeOf dummy)
92#endif
93
94{- | creates a Vector from a list:
95
96@> fromList [2,3,5,7]
974 |> [2.0,3.0,5.0,7.0]@
98
99-}
100
101safeRead v = inlinePerformIO . unsafeWith v
102{-# INLINE safeRead #-}
103
104inlinePerformIO :: IO a -> a
105inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r
106{-# INLINE inlinePerformIO #-}
107
108{- | extracts the Vector elements to a list
109
110>>> toList (linspace 5 (1,10))
111[1.0,3.25,5.5,7.75,10.0]
112
113-}
114toList :: Storable a => Vector a -> [a]
115toList v = safeRead v $ peekArray (dim v)
116
117{- | Create a vector from a list of elements and explicit dimension. The input
118 list is explicitly truncated if it is too long, so it may safely
119 be used, for instance, with infinite lists.
120
121>>> 5 |> [1..]
122fromList [1.0,2.0,3.0,4.0,5.0]
123
124-}
125(|>) :: (Storable a) => Int -> [a] -> Vector a
126infixl 9 |>
127n |> l = if length l' == n
128 then fromList l'
129 else error "list too short for |>"
130 where l' = take n l
131
132
133-- | access to Vector elements without range checking
134at' :: Storable a => Vector a -> Int -> a
135at' v n = safeRead v $ flip peekElemOff n
136{-# INLINE at' #-}
137
138--
139-- turn off bounds checking with -funsafe at configure time.
140-- ghc will optimise away the salways true case at compile time.
141--
142#if defined(UNSAFE)
143safe :: Bool
144safe = False
145#else
146safe = True
147#endif
148
149-- | access to Vector elements with range checking.
150at :: Storable a => Vector a -> Int -> a
151at v n
152 | safe = if n >= 0 && n < dim v
153 then at' v n
154 else error "vector index out of range"
155 | otherwise = at' v n
156{-# INLINE at #-}
157
158{- | takes a number of consecutive elements from a Vector
159
160>>> subVector 2 3 (fromList [1..10])
161fromList [3.0,4.0,5.0]
162
163-}
164subVector :: Storable t => Int -- ^ index of the starting element
165 -> Int -- ^ number of elements to extract
166 -> Vector t -- ^ source
167 -> Vector t -- ^ result
168subVector = Vector.slice
169
170
171{- | Reads a vector position:
172
173>>> fromList [0..9] @> 7
1747.0
175
176-}
177(@>) :: Storable t => Vector t -> Int -> t
178infixl 9 @>
179(@>) = at
180
181
182{- | concatenate a list of vectors
183
184>>> vjoin [fromList [1..5::Double], konst 1 3]
185fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0]
186
187-}
188vjoin :: Storable t => [Vector t] -> Vector t
189vjoin [] = fromList []
190vjoin [v] = v
191vjoin as = unsafePerformIO $ do
192 let tot = sum (map dim as)
193 r <- createVector tot
194 unsafeWith r $ \ptr ->
195 joiner as tot ptr
196 return r
197 where joiner [] _ _ = return ()
198 joiner (v:cs) _ p = do
199 let n = dim v
200 unsafeWith v $ \pb -> copyArray p pb n
201 joiner cs 0 (advancePtr p n)
202
203
204{- | Extract consecutive subvectors of the given sizes.
205
206>>> takesV [3,4] (linspace 10 (1,10::Double))
207[fromList [1.0,2.0,3.0],fromList [4.0,5.0,6.0,7.0]]
208
209-}
210takesV :: Storable t => [Int] -> Vector t -> [Vector t]
211takesV ms w | sum ms > dim w = error $ "takesV " ++ show ms ++ " on dim = " ++ (show $ dim w)
212 | otherwise = go ms w
213 where go [] _ = []
214 go (n:ns) v = subVector 0 n v
215 : go ns (subVector n (dim v - n) v)
216
217---------------------------------------------------------------
218
219-- | transforms a complex vector into a real vector with alternating real and imaginary parts
220asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a
221asReal v = unsafeFromForeignPtr (castForeignPtr fp) (2*i) (2*n)
222 where (fp,i,n) = unsafeToForeignPtr v
223
224-- | transforms a real vector into a complex vector with alternating real and imaginary parts
225asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a)
226asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2)
227 where (fp,i,n) = unsafeToForeignPtr v
228
229---------------------------------------------------------------
230
231float2DoubleV :: Vector Float -> Vector Double
232float2DoubleV v = unsafePerformIO $ do
233 r <- createVector (dim v)
234 app2 c_float2double vec v vec r "float2double"
235 return r
236
237double2FloatV :: Vector Double -> Vector Float
238double2FloatV v = unsafePerformIO $ do
239 r <- createVector (dim v)
240 app2 c_double2float vec v vec r "double2float2"
241 return r
242
243
244foreign import ccall unsafe "float2double" c_float2double:: TFV
245foreign import ccall unsafe "double2float" c_double2float:: TVF
246
247---------------------------------------------------------------
248
249stepF :: Vector Float -> Vector Float
250stepF v = unsafePerformIO $ do
251 r <- createVector (dim v)
252 app2 c_stepF vec v vec r "stepF"
253 return r
254
255stepD :: Vector Double -> Vector Double
256stepD v = unsafePerformIO $ do
257 r <- createVector (dim v)
258 app2 c_stepD vec v vec r "stepD"
259 return r
260
261foreign import ccall unsafe "stepF" c_stepF :: TFF
262foreign import ccall unsafe "stepD" c_stepD :: TVV
263
264---------------------------------------------------------------
265
266condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float
267condF x y l e g = unsafePerformIO $ do
268 r <- createVector (dim x)
269 app6 c_condF vec x vec y vec l vec e vec g vec r "condF"
270 return r
271
272condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double
273condD x y l e g = unsafePerformIO $ do
274 r <- createVector (dim x)
275 app6 c_condD vec x vec y vec l vec e vec g vec r "condD"
276 return r
277
278foreign import ccall unsafe "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF
279foreign import ccall unsafe "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV
280
281--------------------------------------------------------------------------------
282
283conjugateAux fun x = unsafePerformIO $ do
284 v <- createVector (dim x)
285 app2 fun vec x vec v "conjugateAux"
286 return v
287
288conjugateQ :: Vector (Complex Float) -> Vector (Complex Float)
289conjugateQ = conjugateAux c_conjugateQ
290foreign import ccall unsafe "conjugateQ" c_conjugateQ :: TQVQV
291
292conjugateC :: Vector (Complex Double) -> Vector (Complex Double)
293conjugateC = conjugateAux c_conjugateC
294foreign import ccall unsafe "conjugateC" c_conjugateC :: TCVCV
295
296--------------------------------------------------------------------------------
297
298cloneVector :: Storable t => Vector t -> IO (Vector t)
299cloneVector v = do
300 let n = dim v
301 r <- createVector n
302 let f _ s _ d = copyArray d s n >> return 0
303 app2 f vec v vec r "cloneVector"
304 return r
305
306------------------------------------------------------------------
307
308-- | map on Vectors
309mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b
310mapVector f v = unsafePerformIO $ do
311 w <- createVector (dim v)
312 unsafeWith v $ \p ->
313 unsafeWith w $ \q -> do
314 let go (-1) = return ()
315 go !k = do x <- peekElemOff p k
316 pokeElemOff q k (f x)
317 go (k-1)
318 go (dim v -1)
319 return w
320{-# INLINE mapVector #-}
321
322-- | zipWith for Vectors
323zipVectorWith :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c
324zipVectorWith f u v = unsafePerformIO $ do
325 let n = min (dim u) (dim v)
326 w <- createVector n
327 unsafeWith u $ \pu ->
328 unsafeWith v $ \pv ->
329 unsafeWith w $ \pw -> do
330 let go (-1) = return ()
331 go !k = do x <- peekElemOff pu k
332 y <- peekElemOff pv k
333 pokeElemOff pw k (f x y)
334 go (k-1)
335 go (n -1)
336 return w
337{-# INLINE zipVectorWith #-}
338
339-- | unzipWith for Vectors
340unzipVectorWith :: (Storable (a,b), Storable c, Storable d)
341 => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d)
342unzipVectorWith f u = unsafePerformIO $ do
343 let n = dim u
344 v <- createVector n
345 w <- createVector n
346 unsafeWith u $ \pu ->
347 unsafeWith v $ \pv ->
348 unsafeWith w $ \pw -> do
349 let go (-1) = return ()
350 go !k = do z <- peekElemOff pu k
351 let (x,y) = f z
352 pokeElemOff pv k x
353 pokeElemOff pw k y
354 go (k-1)
355 go (n-1)
356 return (v,w)
357{-# INLINE unzipVectorWith #-}
358
359foldVector :: Storable a => (a -> b -> b) -> b -> Vector a -> b
360foldVector f x v = unsafePerformIO $
361 unsafeWith v $ \p -> do
362 let go (-1) s = return s
363 go !k !s = do y <- peekElemOff p k
364 go (k-1::Int) (f y s)
365 go (dim v -1) x
366{-# INLINE foldVector #-}
367
368-- the zero-indexed index is passed to the folding function
369foldVectorWithIndex :: Storable a => (Int -> a -> b -> b) -> b -> Vector a -> b
370foldVectorWithIndex f x v = unsafePerformIO $
371 unsafeWith v $ \p -> do
372 let go (-1) s = return s
373 go !k !s = do y <- peekElemOff p k
374 go (k-1::Int) (f k y s)
375 go (dim v -1) x
376{-# INLINE foldVectorWithIndex #-}
377
378foldLoop f s0 d = go (d - 1) s0
379 where
380 go 0 s = f (0::Int) s
381 go !j !s = go (j - 1) (f j s)
382
383foldVectorG f s0 v = foldLoop g s0 (dim v)
384 where g !k !s = f k (at' v) s
385 {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479)
386{-# INLINE foldVectorG #-}
387
388-------------------------------------------------------------------
389
390-- | monadic map over Vectors
391-- the monad @m@ must be strict
392mapVectorM :: (Storable a, Storable b, Monad m) => (a -> m b) -> Vector a -> m (Vector b)
393mapVectorM f v = do
394 w <- return $! unsafePerformIO $! createVector (dim v)
395 mapVectorM' w 0 (dim v -1)
396 return w
397 where mapVectorM' w' !k !t
398 | k == t = do
399 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
400 y <- f x
401 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
402 | otherwise = do
403 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
404 y <- f x
405 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
406 mapVectorM' w' (k+1) t
407{-# INLINE mapVectorM #-}
408
409-- | monadic map over Vectors
410mapVectorM_ :: (Storable a, Monad m) => (a -> m ()) -> Vector a -> m ()
411mapVectorM_ f v = do
412 mapVectorM' 0 (dim v -1)
413 where mapVectorM' !k !t
414 | k == t = do
415 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
416 f x
417 | otherwise = do
418 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
419 _ <- f x
420 mapVectorM' (k+1) t
421{-# INLINE mapVectorM_ #-}
422
423-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
424-- the monad @m@ must be strict
425mapVectorWithIndexM :: (Storable a, Storable b, Monad m) => (Int -> a -> m b) -> Vector a -> m (Vector b)
426mapVectorWithIndexM f v = do
427 w <- return $! unsafePerformIO $! createVector (dim v)
428 mapVectorM' w 0 (dim v -1)
429 return w
430 where mapVectorM' w' !k !t
431 | k == t = do
432 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
433 y <- f k x
434 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
435 | otherwise = do
436 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
437 y <- f k x
438 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
439 mapVectorM' w' (k+1) t
440{-# INLINE mapVectorWithIndexM #-}
441
442-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
443mapVectorWithIndexM_ :: (Storable a, Monad m) => (Int -> a -> m ()) -> Vector a -> m ()
444mapVectorWithIndexM_ f v = do
445 mapVectorM' 0 (dim v -1)
446 where mapVectorM' !k !t
447 | k == t = do
448 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
449 f k x
450 | otherwise = do
451 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
452 _ <- f k x
453 mapVectorM' (k+1) t
454{-# INLINE mapVectorWithIndexM_ #-}
455
456
457mapVectorWithIndex :: (Storable a, Storable b) => (Int -> a -> b) -> Vector a -> Vector b
458--mapVectorWithIndex g = head . mapVectorWithIndexM (\a b -> [g a b])
459mapVectorWithIndex f v = unsafePerformIO $ do
460 w <- createVector (dim v)
461 unsafeWith v $ \p ->
462 unsafeWith w $ \q -> do
463 let go (-1) = return ()
464 go !k = do x <- peekElemOff p k
465 pokeElemOff q k (f k x)
466 go (k-1)
467 go (dim v -1)
468 return w
469{-# INLINE mapVectorWithIndex #-}
470
471
diff --git a/packages/base/src/Data/Packed/Matrix.hs b/packages/base/src/Data/Packed/Matrix.hs
new file mode 100644
index 0000000..d94d167
--- /dev/null
+++ b/packages/base/src/Data/Packed/Matrix.hs
@@ -0,0 +1,490 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE CPP #-}
6
7-----------------------------------------------------------------------------
8-- |
9-- Module : Data.Packed.Matrix
10-- Copyright : (c) Alberto Ruiz 2007-10
11-- License : GPL
12--
13-- Maintainer : Alberto Ruiz <aruiz@um.es>
14-- Stability : provisional
15--
16-- A Matrix representation suitable for numerical computations using LAPACK and GSL.
17--
18-- This module provides basic functions for manipulation of structure.
19
20-----------------------------------------------------------------------------
21{-# OPTIONS_HADDOCK hide #-}
22
23module Data.Packed.Matrix (
24 Matrix,
25 Element,
26 rows,cols,
27 (><),
28 trans,
29 reshape, flatten,
30 fromLists, toLists, buildMatrix,
31 (@@>),
32 asRow, asColumn,
33 fromRows, toRows, fromColumns, toColumns,
34 fromBlocks, diagBlock, toBlocks, toBlocksEvery,
35 repmat,
36 flipud, fliprl,
37 subMatrix, takeRows, dropRows, takeColumns, dropColumns,
38 extractRows, extractColumns,
39 diagRect, takeDiag,
40 mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_,
41 liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D
42) where
43
44import Data.Packed.Internal
45import qualified Data.Packed.ST as ST
46import Data.Array
47
48import Data.List(transpose,intersperse)
49import Foreign.Storable(Storable)
50import Control.Monad(liftM)
51
52-------------------------------------------------------------------
53
54#ifdef BINARY
55
56import Data.Binary
57import Control.Monad(replicateM)
58
59instance (Binary a, Element a, Storable a) => Binary (Matrix a) where
60 put m = do
61 let r = rows m
62 let c = cols m
63 put r
64 put c
65 mapM_ (\i -> mapM_ (\j -> put $ m @@> (i,j)) [0..(c-1)]) [0..(r-1)]
66 get = do
67 r <- get
68 c <- get
69 xs <- replicateM r $ replicateM c get
70 return $ fromLists xs
71
72#endif
73
74-------------------------------------------------------------------
75
76instance (Show a, Element a) => (Show (Matrix a)) where
77 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
78 show m = (sizes m++) . dsp . map (map show) . toLists $ m
79
80sizes m = "("++show (rows m)++"><"++show (cols m)++")\n"
81
82dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
83 where
84 mt = transpose as
85 longs = map (maximum . map length) mt
86 mtp = zipWith (\a b -> map (pad a) b) longs mt
87 pad n str = replicate (n - length str) ' ' ++ str
88 unwords' = concat . intersperse ", "
89
90------------------------------------------------------------------
91
92instance (Element a, Read a) => Read (Matrix a) where
93 readsPrec _ s = [((rs><cs) . read $ listnums, rest)]
94 where (thing,rest) = breakAt ']' s
95 (dims,listnums) = breakAt ')' thing
96 cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims
97 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims
98
99
100breakAt c l = (a++[c],tail b) where
101 (a,b) = break (==c) l
102
103------------------------------------------------------------------
104
105-- | creates a matrix from a vertical list of matrices
106joinVert :: Element t => [Matrix t] -> Matrix t
107joinVert [] = emptyM 0 0
108joinVert ms = case common cols ms of
109 Nothing -> error "(impossible) joinVert on matrices with different number of columns"
110 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms)
111
112-- | creates a matrix from a horizontal list of matrices
113joinHoriz :: Element t => [Matrix t] -> Matrix t
114joinHoriz ms = trans. joinVert . map trans $ ms
115
116{- | Create a matrix from blocks given as a list of lists of matrices.
117
118Single row-column components are automatically expanded to match the
119corresponding common row and column:
120
121@
122disp = putStr . dispf 2
123@
124
125>>> disp $ fromBlocks [[ident 5, 7, row[10,20]], [3, diagl[1,2,3], 0]]
1268x10
1271 0 0 0 0 7 7 7 10 20
1280 1 0 0 0 7 7 7 10 20
1290 0 1 0 0 7 7 7 10 20
1300 0 0 1 0 7 7 7 10 20
1310 0 0 0 1 7 7 7 10 20
1323 3 3 3 3 1 0 0 0 0
1333 3 3 3 3 0 2 0 0 0
1343 3 3 3 3 0 0 3 0 0
135
136-}
137fromBlocks :: Element t => [[Matrix t]] -> Matrix t
138fromBlocks = fromBlocksRaw . adaptBlocks
139
140fromBlocksRaw mms = joinVert . map joinHoriz $ mms
141
142adaptBlocks ms = ms' where
143 bc = case common length ms of
144 Just c -> c
145 Nothing -> error "fromBlocks requires rectangular [[Matrix]]"
146 rs = map (compatdim . map rows) ms
147 cs = map (compatdim . map cols) (transpose ms)
148 szs = sequence [rs,cs]
149 ms' = splitEvery bc $ zipWith g szs (concat ms)
150
151 g [Just nr,Just nc] m
152 | nr == r && nc == c = m
153 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc))
154 | r == 1 = fromRows (replicate nr (flatten m))
155 | otherwise = fromColumns (replicate nc (flatten m))
156 where
157 r = rows m
158 c = cols m
159 x = m@@>(0,0)
160 g _ _ = error "inconsistent dimensions in fromBlocks"
161
162
163--------------------------------------------------------------------------------
164
165{- | create a block diagonal matrix
166
167>>> disp 2 $ diagBlock [konst 1 (2,2), konst 2 (3,5), col [5,7]]
1687x8
1691 1 0 0 0 0 0 0
1701 1 0 0 0 0 0 0
1710 0 2 2 2 2 2 0
1720 0 2 2 2 2 2 0
1730 0 2 2 2 2 2 0
1740 0 0 0 0 0 0 5
1750 0 0 0 0 0 0 7
176
177>>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double
178(2><7)
179 [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0
180 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ]
181
182-}
183diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t
184diagBlock ms = fromBlocks $ zipWith f ms [0..]
185 where
186 f m k = take n $ replicate k z ++ m : repeat z
187 n = length ms
188 z = (1><1) [0]
189
190--------------------------------------------------------------------------------
191
192
193-- | Reverse rows
194flipud :: Element t => Matrix t -> Matrix t
195flipud m = extractRows [r-1,r-2 .. 0] $ m
196 where
197 r = rows m
198
199-- | Reverse columns
200fliprl :: Element t => Matrix t -> Matrix t
201fliprl m = extractColumns [c-1,c-2 .. 0] $ m
202 where
203 c = cols m
204
205------------------------------------------------------------
206
207{- | creates a rectangular diagonal matrix:
208
209>>> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double
210(4><5)
211 [ 10.0, 7.0, 7.0, 7.0, 7.0
212 , 7.0, 20.0, 7.0, 7.0, 7.0
213 , 7.0, 7.0, 30.0, 7.0, 7.0
214 , 7.0, 7.0, 7.0, 7.0, 7.0 ]
215
216-}
217diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t
218diagRect z v r c = ST.runSTMatrix $ do
219 m <- ST.newMatrix z r c
220 let d = min r c `min` (dim v)
221 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1]
222 return m
223
224-- | extracts the diagonal from a rectangular matrix
225takeDiag :: (Element t) => Matrix t -> Vector t
226takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
227
228------------------------------------------------------------
229
230{- | An easy way to create a matrix:
231
232>>> (2><3)[2,4,7,-3,11,0]
233(2><3)
234 [ 2.0, 4.0, 7.0
235 , -3.0, 11.0, 0.0 ]
236
237This is the format produced by the instances of Show (Matrix a), which
238can also be used for input.
239
240The input list is explicitly truncated, so that it can
241safely be used with lists that are too long (like infinite lists).
242
243>>> (2><3)[1..]
244(2><3)
245 [ 1.0, 2.0, 3.0
246 , 4.0, 5.0, 6.0 ]
247
248
249-}
250(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a
251r >< c = f where
252 f l | dim v == r*c = matrixFromVector RowMajor r c v
253 | otherwise = error $ "inconsistent list size = "
254 ++show (dim v) ++" in ("++show r++"><"++show c++")"
255 where v = fromList $ take (r*c) l
256
257----------------------------------------------------------------
258
259-- | Creates a matrix with the first n rows of another matrix
260takeRows :: Element t => Int -> Matrix t -> Matrix t
261takeRows n mt = subMatrix (0,0) (n, cols mt) mt
262-- | Creates a copy of a matrix without the first n rows
263dropRows :: Element t => Int -> Matrix t -> Matrix t
264dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt
265-- |Creates a matrix with the first n columns of another matrix
266takeColumns :: Element t => Int -> Matrix t -> Matrix t
267takeColumns n mt = subMatrix (0,0) (rows mt, n) mt
268-- | Creates a copy of a matrix without the first n columns
269dropColumns :: Element t => Int -> Matrix t -> Matrix t
270dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt
271
272----------------------------------------------------------------
273
274{- | Creates a 'Matrix' from a list of lists (considered as rows).
275
276>>> fromLists [[1,2],[3,4],[5,6]]
277(3><2)
278 [ 1.0, 2.0
279 , 3.0, 4.0
280 , 5.0, 6.0 ]
281
282-}
283fromLists :: Element t => [[t]] -> Matrix t
284fromLists = fromRows . map fromList
285
286-- | creates a 1-row matrix from a vector
287--
288-- >>> asRow (fromList [1..5])
289-- (1><5)
290-- [ 1.0, 2.0, 3.0, 4.0, 5.0 ]
291--
292asRow :: Storable a => Vector a -> Matrix a
293asRow v = reshape (dim v) v
294
295-- | creates a 1-column matrix from a vector
296--
297-- >>> asColumn (fromList [1..5])
298-- (5><1)
299-- [ 1.0
300-- , 2.0
301-- , 3.0
302-- , 4.0
303-- , 5.0 ]
304--
305asColumn :: Storable a => Vector a -> Matrix a
306asColumn = trans . asRow
307
308
309
310{- | creates a Matrix of the specified size using the supplied function to
311 to map the row\/column position to the value at that row\/column position.
312
313@> buildMatrix 3 4 (\\(r,c) -> fromIntegral r * fromIntegral c)
314(3><4)
315 [ 0.0, 0.0, 0.0, 0.0, 0.0
316 , 0.0, 1.0, 2.0, 3.0, 4.0
317 , 0.0, 2.0, 4.0, 6.0, 8.0]@
318
319Hilbert matrix of order N:
320
321@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@
322
323-}
324buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a
325buildMatrix rc cc f =
326 fromLists $ map (map f)
327 $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)]
328
329-----------------------------------------------------
330
331fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e
332fromArray2D m = (r><c) (elems m)
333 where ((r0,c0),(r1,c1)) = bounds m
334 r = r1-r0+1
335 c = c1-c0+1
336
337
338-- | rearranges the rows of a matrix according to the order given in a list of integers.
339extractRows :: Element t => [Int] -> Matrix t -> Matrix t
340extractRows [] m = emptyM 0 (cols m)
341extractRows l m = fromRows $ extract (toRows m) l
342 where
343 extract l' is = [l'!!i | i<- map verify is]
344 verify k
345 | k >= 0 && k < rows m = k
346 | otherwise = error $ "can't extract row "
347 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
348
349-- | rearranges the rows of a matrix according to the order given in a list of integers.
350extractColumns :: Element t => [Int] -> Matrix t -> Matrix t
351extractColumns l m = trans . extractRows (map verify l) . trans $ m
352 where
353 verify k
354 | k >= 0 && k < cols m = k
355 | otherwise = error $ "can't extract column "
356 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
357
358
359
360{- | creates matrix by repetition of a matrix a given number of rows and columns
361
362>>> repmat (ident 2) 2 3
363(4><6)
364 [ 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
365 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
366 , 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
367 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ]
368
369-}
370repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t
371repmat m r c
372 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m)
373 | otherwise = fromBlocks $ replicate r $ replicate c $ m
374
375-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
376liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
377liftMatrix2Auto f m1 m2
378 | compat' m1 m2 = lM f m1 m2
379 | ok = lM f m1' m2'
380 | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2
381 where
382 (r1,c1) = size m1
383 (r2,c2) = size m2
384 r = max r1 r2
385 c = max c1 c2
386 r0 = min r1 r2
387 c0 = min c1 c2
388 ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2
389 m1' = conformMTo (r,c) m1
390 m2' = conformMTo (r,c) m2
391
392-- FIXME do not flatten if equal order
393lM f m1 m2 = matrixFromVector
394 RowMajor
395 (max (rows m1) (rows m2))
396 (max (cols m1) (cols m2))
397 (f (flatten m1) (flatten m2))
398
399compat' :: Matrix a -> Matrix b -> Bool
400compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
401 where
402 s1 = size m1
403 s2 = size m2
404
405------------------------------------------------------------
406
407toBlockRows [r] m | r == rows m = [m]
408toBlockRows rs m = map (reshape (cols m)) (takesV szs (flatten m))
409 where szs = map (* cols m) rs
410
411toBlockCols [c] m | c == cols m = [m]
412toBlockCols cs m = map trans . toBlockRows cs . trans $ m
413
414-- | Partition a matrix into blocks with the given numbers of rows and columns.
415-- The remaining rows and columns are discarded.
416toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]]
417toBlocks rs cs m = map (toBlockCols cs) . toBlockRows rs $ m
418
419-- | Fully partition a matrix into blocks of the same size. If the dimensions are not
420-- a multiple of the given size the last blocks will be smaller.
421toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]]
422toBlocksEvery r c m
423 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c
424 | otherwise = toBlocks rs cs m
425 where
426 (qr,rr) = rows m `divMod` r
427 (qc,rc) = cols m `divMod` c
428 rs = replicate qr r ++ if rr > 0 then [rr] else []
429 cs = replicate qc c ++ if rc > 0 then [rc] else []
430
431-------------------------------------------------------------------
432
433-- Given a column number and a function taking matrix indexes, returns
434-- a function which takes vector indexes (that can be used on the
435-- flattened matrix).
436mk :: Int -> ((Int, Int) -> t) -> (Int -> t)
437mk c g = \k -> g (divMod k c)
438
439{- |
440
441>>> mapMatrixWithIndexM_ (\(i,j) v -> printf "m[%d,%d] = %.f\n" i j v :: IO()) ((2><3)[1 :: Double ..])
442m[0,0] = 1
443m[0,1] = 2
444m[0,2] = 3
445m[1,0] = 4
446m[1,1] = 5
447m[1,2] = 6
448
449-}
450mapMatrixWithIndexM_
451 :: (Element a, Num a, Monad m) =>
452 ((Int, Int) -> a -> m ()) -> Matrix a -> m ()
453mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m
454 where
455 c = cols m
456
457{- |
458
459>>> mapMatrixWithIndexM (\(i,j) v -> Just $ 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
460Just (3><3)
461 [ 100.0, 1.0, 2.0
462 , 10.0, 111.0, 12.0
463 , 20.0, 21.0, 122.0 ]
464
465-}
466mapMatrixWithIndexM
467 :: (Element a, Storable b, Monad m) =>
468 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
469mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m
470 where
471 c = cols m
472
473{- |
474
475>>> mapMatrixWithIndex (\\(i,j) v -> 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
476(3><3)
477 [ 100.0, 1.0, 2.0
478 , 10.0, 111.0, 12.0
479 , 20.0, 21.0, 122.0 ]
480
481 -}
482mapMatrixWithIndex
483 :: (Element a, Storable b) =>
484 ((Int, Int) -> a -> b) -> Matrix a -> Matrix b
485mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
486 where
487 c = cols m
488
489mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b
490mapMatrix f = liftMatrix (mapVector f)
diff --git a/packages/base/src/Data/Packed/ST.hs b/packages/base/src/Data/Packed/ST.hs
new file mode 100644
index 0000000..dae457c
--- /dev/null
+++ b/packages/base/src/Data/Packed/ST.hs
@@ -0,0 +1,178 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeOperators #-}
3{-# LANGUAGE Rank2Types #-}
4{-# LANGUAGE BangPatterns #-}
5-----------------------------------------------------------------------------
6-- |
7-- Module : Data.Packed.ST
8-- Copyright : (c) Alberto Ruiz 2008
9-- License : BSD3
10-- Maintainer : Alberto Ruiz
11-- Stability : provisional
12-- Portability : portable
13--
14-- In-place manipulation inside the ST monad.
15-- See examples/inplace.hs in the distribution.
16--
17-----------------------------------------------------------------------------
18
19module Data.Packed.ST (
20 -- * Mutable Vectors
21 STVector, newVector, thawVector, freezeVector, runSTVector,
22 readVector, writeVector, modifyVector, liftSTVector,
23 -- * Mutable Matrices
24 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
25 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
26 -- * Unsafe functions
27 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector,
29 unsafeThawVector, unsafeFreezeVector,
30 newUndefinedMatrix,
31 unsafeReadMatrix, unsafeWriteMatrix,
32 unsafeThawMatrix, unsafeFreezeMatrix
33) where
34
35import Data.Packed.Internal
36
37import Control.Monad.ST(ST, runST)
38import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
39
40#if MIN_VERSION_base(4,4,0)
41import Control.Monad.ST.Unsafe(unsafeIOToST)
42#else
43import Control.Monad.ST(unsafeIOToST)
44#endif
45
46{-# INLINE ioReadV #-}
47ioReadV :: Storable t => Vector t -> Int -> IO t
48ioReadV v k = unsafeWith v $ \s -> peekElemOff s k
49
50{-# INLINE ioWriteV #-}
51ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
52ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x
53
54newtype STVector s t = STVector (Vector t)
55
56thawVector :: Storable t => Vector t -> ST s (STVector s t)
57thawVector = unsafeIOToST . fmap STVector . cloneVector
58
59unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
60unsafeThawVector = unsafeIOToST . return . STVector
61
62runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
63runSTVector st = runST (st >>= unsafeFreezeVector)
64
65{-# INLINE unsafeReadVector #-}
66unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
67unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x
68
69{-# INLINE unsafeWriteVector #-}
70unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
71unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
72
73{-# INLINE modifyVector #-}
74modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
75modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
76
77liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
78liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
79
80freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
81freezeVector v = liftSTVector id v
82
83unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
84unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
85
86{-# INLINE safeIndexV #-}
87safeIndexV f (STVector v) k
88 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
89 ++show (dim v)++", pos="++show k++")"
90 | otherwise = f (STVector v) k
91
92{-# INLINE readVector #-}
93readVector :: Storable t => STVector s t -> Int -> ST s t
94readVector = safeIndexV unsafeReadVector
95
96{-# INLINE writeVector #-}
97writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
98writeVector = safeIndexV unsafeWriteVector
99
100newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
101newUndefinedVector = unsafeIOToST . fmap STVector . createVector
102
103{-# INLINE newVector #-}
104newVector :: Storable t => t -> Int -> ST s (STVector s t)
105newVector x n = do
106 v <- newUndefinedVector n
107 let go (-1) = return v
108 go !k = unsafeWriteVector v k x >> go (k-1 :: Int)
109 go (n-1)
110
111-------------------------------------------------------------------------
112
113{-# INLINE ioReadM #-}
114ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
115ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c)
116ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r)
117
118{-# INLINE ioWriteM #-}
119ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
120ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val
121ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val
122
123newtype STMatrix s t = STMatrix (Matrix t)
124
125thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
126thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
127
128unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
129unsafeThawMatrix = unsafeIOToST . return . STMatrix
130
131runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
132runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
133
134{-# INLINE unsafeReadMatrix #-}
135unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
136unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
137
138{-# INLINE unsafeWriteMatrix #-}
139unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
140unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
141
142{-# INLINE modifyMatrix #-}
143modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
144modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
145
146liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
147liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
148
149unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
150unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
151
152freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
153freezeMatrix m = liftSTMatrix id m
154
155cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
156
157{-# INLINE safeIndexM #-}
158safeIndexM f (STMatrix m) r c
159 | r<0 || r>=rows m ||
160 c<0 || c>=cols m = error $ "out of range error in matrix (size="
161 ++show (rows m,cols m)++", pos="++show (r,c)++")"
162 | otherwise = f (STMatrix m) r c
163
164{-# INLINE readMatrix #-}
165readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
166readMatrix = safeIndexM unsafeReadMatrix
167
168{-# INLINE writeMatrix #-}
169writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
170writeMatrix = safeIndexM unsafeWriteMatrix
171
172newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
173newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
174
175{-# NOINLINE newMatrix #-}
176newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
177newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
178
diff --git a/packages/base/src/Data/Packed/Vector.hs b/packages/base/src/Data/Packed/Vector.hs
new file mode 100644
index 0000000..b5a4318
--- /dev/null
+++ b/packages/base/src/Data/Packed/Vector.hs
@@ -0,0 +1,96 @@
1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE CPP #-}
3-----------------------------------------------------------------------------
4-- |
5-- Module : Data.Packed.Vector
6-- Copyright : (c) Alberto Ruiz 2007-10
7-- License : GPL
8--
9-- Maintainer : Alberto Ruiz <aruiz@um.es>
10-- Stability : provisional
11--
12-- 1D arrays suitable for numeric computations using external libraries.
13--
14-- This module provides basic functions for manipulation of structure.
15--
16-----------------------------------------------------------------------------
17{-# OPTIONS_HADDOCK hide #-}
18
19module Data.Packed.Vector (
20 Vector,
21 fromList, (|>), toList, buildVector,
22 dim, (@>),
23 subVector, takesV, vjoin, join,
24 mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith,
25 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
26 foldLoop, foldVector, foldVectorG, foldVectorWithIndex
27) where
28
29import Data.Packed.Internal.Vector
30import Foreign.Storable
31
32-------------------------------------------------------------------
33
34#ifdef BINARY
35
36import Data.Binary
37import Control.Monad(replicateM)
38
39-- a 64K cache, with a Double taking 13 bytes in Bytestring,
40-- implies a chunk size of 5041
41chunk :: Int
42chunk = 5000
43
44chunks :: Int -> [Int]
45chunks d = let c = d `div` chunk
46 m = d `mod` chunk
47 in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk)
48
49putVector v = do
50 let d = dim v
51 mapM_ (\i -> put $ v @> i) [0..(d-1)]
52
53getVector d = do
54 xs <- replicateM d get
55 return $! fromList xs
56
57instance (Binary a, Storable a) => Binary (Vector a) where
58 put v = do
59 let d = dim v
60 put d
61 mapM_ putVector $! takesV (chunks d) v
62 get = do
63 d <- get
64 vs <- mapM getVector $ chunks d
65 return $! vjoin vs
66
67#endif
68
69-------------------------------------------------------------------
70
71{- | creates a Vector of the specified length using the supplied function to
72 to map the index to the value at that index.
73
74@> buildVector 4 fromIntegral
754 |> [0.0,1.0,2.0,3.0]@
76
77-}
78buildVector :: Storable a => Int -> (Int -> a) -> Vector a
79buildVector len f =
80 fromList $ map f [0 .. (len - 1)]
81
82
83-- | zip for Vectors
84zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b)
85zipVector = zipVectorWith (,)
86
87-- | unzip for Vectors
88unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b)
89unzipVector = unzipVectorWith id
90
91-------------------------------------------------------------------
92
93{-# DEPRECATED join "use vjoin or Data.Vector.concat" #-}
94join :: Storable t => [Vector t] -> Vector t
95join = vjoin
96