summaryrefslogtreecommitdiff
path: root/lib/Numeric
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric')
-rw-r--r--lib/Numeric/GSL/Matrix.hs311
-rw-r--r--lib/Numeric/GSL/gsl-aux.c286
-rw-r--r--lib/Numeric/GSL/gsl-aux.h19
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs171
-rw-r--r--lib/Numeric/LinearAlgebra/Interface.hs4
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c74
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h6
-rw-r--r--lib/Numeric/LinearAlgebra/Linear.hs54
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs5
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Instances.hs16
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs6
11 files changed, 279 insertions, 673 deletions
diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs
deleted file mode 100644
index f62bb82..0000000
--- a/lib/Numeric/GSL/Matrix.hs
+++ /dev/null
@@ -1,311 +0,0 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.GSL.Matrix
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable (uses FFI)
10--
11-- A few linear algebra computations based on GSL.
12--
13-----------------------------------------------------------------------------
14-- #hide
15
16module Numeric.GSL.Matrix(
17 eigSg, eigHg,
18 svdg,
19 qr, qrPacked, unpackQR,
20 cholR, cholC,
21 luSolveR, luSolveC,
22 luR, luC
23) where
24
25import Data.Packed.Internal
26import Data.Packed.Matrix(ident)
27import Numeric.GSL.Vector
28import Foreign
29import Complex
30
31{- | eigendecomposition of a real symmetric matrix using /gsl_eigen_symmv/.
32
33> > let (l,v) = eigS $ 'fromLists' [[1,2],[2,1]]
34> > l
35> 3.000 -1.000
36>
37> > v
38> 0.707 -0.707
39> 0.707 0.707
40>
41> > v <> diag l <> trans v
42> 1.000 2.000
43> 2.000 1.000
44
45-}
46eigSg :: Matrix Double -> (Vector Double, Matrix Double)
47eigSg = eigSg' . cmat
48
49eigSg' m
50 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
51 | otherwise = unsafePerformIO $ do
52 l <- createVector r
53 v <- createMatrix RowMajor r r
54 app3 c_eigS mat m vec l mat v "eigSg"
55 return (l,v)
56 where r = rows m
57foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
58
59------------------------------------------------------------------
60
61
62
63{- | eigendecomposition of a complex hermitian matrix using /gsl_eigen_hermv/
64
65> > let (l,v) = eigH $ 'fromLists' [[1,2+i],[2-i,3]]
66>
67> > l
68> 4.449 -0.449
69>
70> > v
71> -0.544 0.839
72> (-0.751,0.375) (-0.487,0.243)
73>
74> > v <> diag l <> (conjTrans) v
75> 1.000 (2.000,1.000)
76> (2.000,-1.000) 3.000
77
78-}
79eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double))
80eigHg = eigHg' . cmat
81
82eigHg' m
83 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1)
84 | otherwise = unsafePerformIO $ do
85 l <- createVector r
86 v <- createMatrix RowMajor r r
87 app3 c_eigH mat m vec l mat v "eigHg"
88 return (l,v)
89 where r = rows m
90foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
91
92
93{- | Singular value decomposition of a real matrix, using /gsl_linalg_SV_decomp_mod/:
94
95
96-}
97svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
98svdg x = if rows x >= cols x
99 then svd' (cmat x)
100 else (v, s, u) where (u,s,v) = svd' (cmat (trans x))
101
102svd' x = unsafePerformIO $ do
103 u <- createMatrix RowMajor r c
104 s <- createVector c
105 v <- createMatrix RowMajor c c
106 app4 c_svd mat x mat u vec s mat v "svdg"
107 return (u,s,v)
108 where r = rows x
109 c = cols x
110foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM
111
112{- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/.
113
114-}
115qr :: Matrix Double -> (Matrix Double, Matrix Double)
116qr = qr' . cmat
117
118qr' x = unsafePerformIO $ do
119 q <- createMatrix RowMajor r r
120 rot <- createMatrix RowMajor r c
121 app3 c_qr mat x mat q mat rot "qr"
122 return (q,rot)
123 where r = rows x
124 c = cols x
125foreign import ccall "gsl-aux.h QR" c_qr :: TMMM
126
127qrPacked :: Matrix Double -> (Matrix Double, Vector Double)
128qrPacked = qrPacked' . cmat
129
130qrPacked' x = unsafePerformIO $ do
131 qrp <- createMatrix RowMajor r c
132 tau <- createVector (min r c)
133 app3 c_qrPacked mat x mat qrp vec tau "qrUnpacked"
134 return (qrp,tau)
135 where r = rows x
136 c = cols x
137foreign import ccall "gsl-aux.h QRpacked" c_qrPacked :: TMMV
138
139unpackQR :: (Matrix Double, Vector Double) -> (Matrix Double, Matrix Double)
140unpackQR (qrp,tau) = unpackQR' (cmat qrp, tau)
141
142unpackQR' (qrp,tau) = unsafePerformIO $ do
143 q <- createMatrix RowMajor r r
144 res <- createMatrix RowMajor r c
145 app4 c_qrUnpack mat qrp vec tau mat q mat res "qrUnpack"
146 return (q,res)
147 where r = rows qrp
148 c = cols qrp
149foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM
150
151{- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/.
152
153@\> chol $ (2><2) [1,2,
154 2,9::Double]
155(2><2)
156 [ 1.0, 0.0
157 , 2.0, 2.23606797749979 ]@
158
159-}
160cholR :: Matrix Double -> Matrix Double
161cholR = cholR' . cmat
162
163cholR' x = unsafePerformIO $ do
164 r <- createMatrix RowMajor n n
165 app2 c_cholR mat x mat r "cholR"
166 return r
167 where n = rows x
168foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM
169
170cholC :: Matrix (Complex Double) -> Matrix (Complex Double)
171cholC = cholC' . cmat
172
173cholC' x = unsafePerformIO $ do
174 r <- createMatrix RowMajor n n
175 app2 c_cholC mat x mat r "cholC"
176 return r
177 where n = rows x
178foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM
179
180
181--------------------------------------------------------
182
183{- -| efficient multiplication by the inverse of a matrix (for real matrices)
184-}
185luSolveR :: Matrix Double -> Matrix Double -> Matrix Double
186luSolveR a b = luSolveR' (cmat a) (cmat b)
187
188luSolveR' a b
189 | n1==n2 && n1==r = unsafePerformIO $ do
190 s <- createMatrix RowMajor r c
191 app3 c_luSolveR mat a mat b mat s "luSolveR"
192 return s
193 | otherwise = error "luSolveR of nonsquare matrix"
194 where n1 = rows a
195 n2 = cols a
196 r = rows b
197 c = cols b
198foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM
199
200{- -| efficient multiplication by the inverse of a matrix (for complex matrices).
201-}
202luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
203luSolveC a b = luSolveC' (cmat a) (cmat b)
204
205luSolveC' a b
206 | n1==n2 && n1==r = unsafePerformIO $ do
207 s <- createMatrix RowMajor r c
208 app3 c_luSolveC mat a mat b mat s "luSolveC"
209 return s
210 | otherwise = error "luSolveC of nonsquare matrix"
211 where n1 = rows a
212 n2 = cols a
213 r = rows b
214 c = cols b
215foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM
216
217{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign)
218-}
219luRaux :: Matrix Double -> Vector Double
220luRaux = luRaux' . cmat
221
222luRaux' x = unsafePerformIO $ do
223 res <- createVector (r*r+r+1)
224 app2 c_luRaux mat x vec res "luRaux"
225 return res
226 where r = rows x
227foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV
228
229{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign)
230-}
231luCaux :: Matrix (Complex Double) -> Vector (Complex Double)
232luCaux = luCaux' . cmat
233
234luCaux' x = unsafePerformIO $ do
235 res <- createVector (r*r+r+1)
236 app2 c_luCaux mat x vec res "luCaux"
237 return res
238 where r = rows x
239foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV
240
241{- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/Numeric.GSL/manual/Numeric.GSL-ref_13.html#SEC223>.
242
243@\> let m = 'fromLists' [[1,2,-3],[2+3*i,-7,0],[1,-i,2*i]]
244\> let (l,u,p,s) = luR m@
245
246L is the lower triangular:
247
248@\> l
249 1. 0. 0.
2500.154-0.231i 1. 0.
2510.154-0.231i 0.624-0.522i 1.@
252
253U is the upper triangular:
254
255@\> u
2562.+3.i -7. 0.
257 0. 3.077-1.615i -3.
258 0. 0. 1.873+0.433i@
259
260p is a permutation:
261
262@\> p
263[1,0,2]@
264
265L \* U obtains a permuted version of the original matrix:
266
267@\> extractRows p m
268 2.+3.i -7. 0.
269 1. 2. -3.
270 1. -1.i 2.i
271\ -- CPP
272\> l \<\> u
273 2.+3.i -7. 0.
274 1. 2. -3.
275 1. -1.i 2.i@
276
277s is the sign of the permutation, required to obtain sign of the determinant:
278
279@\> s * product ('toList' $ 'takeDiag' u)
280(-18.0) :+ (-16.000000000000004)
281\> 'LinearAlgebra.Algorithms.det' m
282(-18.0) :+ (-16.000000000000004)@
283
284 -}
285luR :: Matrix Double -> (Matrix Double, Matrix Double, [Int], Double)
286luR m = (l,u,p, fromIntegral s') where
287 r = rows m
288 v = luRaux m
289 lu = reshape r $ subVector 0 (r*r) v
290 s':p = map round . toList . subVector (r*r) (r+1) $ v
291 u = triang r r 0 1`mul` lu
292 l = (triang r r 0 0 `mul` lu) `add` ident r
293 add = liftMatrix2 $ vectorZipR Add
294 mul = liftMatrix2 $ vectorZipR Mul
295
296-- | Complex version of 'luR'.
297luC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double), [Int], Complex Double)
298luC m = (l,u,p, fromIntegral s') where
299 r = rows m
300 v = luCaux m
301 lu = reshape r $ subVector 0 (r*r) v
302 s':p = map (round.realPart) . toList . subVector (r*r) (r+1) $ v
303 u = triang r r 0 1 `mul` lu
304 l = (triang r r 0 0 `mul` lu) `add` liftMatrix comp (ident r)
305 add = liftMatrix2 $ vectorZipC Add
306 mul = liftMatrix2 $ vectorZipC Mul
307
308{- auxiliary function to get triangular matrices
309-}
310triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]]
311 where el i j = if j-i>=h then v else 1 - v
diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c
index bd0a6bd..052cafd 100644
--- a/lib/Numeric/GSL/gsl-aux.c
+++ b/lib/Numeric/GSL/gsl-aux.c
@@ -1,11 +1,8 @@
1#include "gsl-aux.h" 1#include "gsl-aux.h"
2#include <gsl/gsl_blas.h> 2#include <gsl/gsl_blas.h>
3#include <gsl/gsl_linalg.h>
4#include <gsl/gsl_matrix.h>
5#include <gsl/gsl_math.h> 3#include <gsl/gsl_math.h>
6#include <gsl/gsl_errno.h> 4#include <gsl/gsl_errno.h>
7#include <gsl/gsl_fft_complex.h> 5#include <gsl/gsl_fft_complex.h>
8#include <gsl/gsl_eigen.h>
9#include <gsl/gsl_integration.h> 6#include <gsl/gsl_integration.h>
10#include <gsl/gsl_deriv.h> 7#include <gsl/gsl_deriv.h>
11#include <gsl/gsl_poly.h> 8#include <gsl/gsl_poly.h>
@@ -161,47 +158,6 @@ int mapC(int code, KCVEC(x), CVEC(r)) {
161} 158}
162 159
163 160
164/*
165int scaleR(double* alpha, KRVEC(x), RVEC(r)) {
166 REQUIRES(xn == rn,BAD_SIZE);
167 DEBUGMSG("scaleR");
168 KDVVIEW(x);
169 DVVIEW(r);
170 CHECK( gsl_vector_memcpy(V(r),V(x)) , MEM);
171 int res = gsl_vector_scale(V(r),*alpha);
172 CHECK(res,res);
173 OK
174}
175
176int scaleC(gsl_complex *alpha, KCVEC(x), CVEC(r)) {
177 REQUIRES(xn == rn,BAD_SIZE);
178 DEBUGMSG("scaleC");
179 //KCVVIEW(x);
180 CVVIEW(r);
181 gsl_vector_const_view vrx = gsl_vector_const_view_array((double*)xp,xn*2);
182 gsl_vector_view vrr = gsl_vector_view_array((double*)rp,rn*2);
183 CHECK(gsl_vector_memcpy(V(vrr),V(vrx)) , MEM);
184 gsl_blas_zscal(*alpha,V(r)); // void !
185 int res = 0;
186 CHECK(res,res);
187 OK
188}
189
190int addConstantR(double offs, KRVEC(x), RVEC(r)) {
191 REQUIRES(xn == rn,BAD_SIZE);
192 DEBUGMSG("addConstantR");
193 KDVVIEW(x);
194 DVVIEW(r);
195 CHECK(gsl_vector_memcpy(V(r),V(x)), MEM);
196 int res = gsl_vector_add_constant(V(r),offs);
197 CHECK(res,res);
198 OK
199}
200
201*/
202
203
204
205int mapValR(int code, double* pval, KRVEC(x), RVEC(r)) { 161int mapValR(int code, double* pval, KRVEC(x), RVEC(r)) {
206 int k; 162 int k;
207 double val = *pval; 163 double val = *pval;
@@ -291,248 +247,6 @@ int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r)) {
291 247
292 248
293 249
294
295int luSolveR(KRMAT(a),KRMAT(b),RMAT(r)) {
296 REQUIRES(ar==ac && ac==br && ar==rr && bc==rc,BAD_SIZE);
297 DEBUGMSG("luSolveR");
298 KDMVIEW(a);
299 KDMVIEW(b);
300 DMVIEW(r);
301 int res;
302 gsl_matrix *LU = gsl_matrix_alloc(ar,ar);
303 CHECK(!LU,MEM);
304 int s;
305 gsl_permutation * p = gsl_permutation_alloc (ar);
306 CHECK(!p,MEM);
307 CHECK(gsl_matrix_memcpy(LU,M(a)),MEM);
308 res = gsl_linalg_LU_decomp(LU, p, &s);
309 CHECK(res,res);
310 int c;
311
312 for (c=0; c<bc; c++) {
313 gsl_vector_const_view colb = gsl_matrix_const_column (M(b), c);
314 gsl_vector_view colr = gsl_matrix_column (M(r), c);
315 res = gsl_linalg_LU_solve (LU, p, V(colb), V(colr));
316 CHECK(res,res);
317 }
318 gsl_permutation_free(p);
319 gsl_matrix_free(LU);
320 OK
321}
322
323
324int luSolveC(KCMAT(a),KCMAT(b),CMAT(r)) {
325 REQUIRES(ar==ac && ac==br && ar==rr && bc==rc,BAD_SIZE);
326 DEBUGMSG("luSolveC");
327 KCMVIEW(a);
328 KCMVIEW(b);
329 CMVIEW(r);
330 gsl_matrix_complex *LU = gsl_matrix_complex_alloc(ar,ar);
331 CHECK(!LU,MEM);
332 int s;
333 gsl_permutation * p = gsl_permutation_alloc (ar);
334 CHECK(!p,MEM);
335 CHECK(gsl_matrix_complex_memcpy(LU,M(a)),MEM);
336 int res;
337 res = gsl_linalg_complex_LU_decomp(LU, p, &s);
338 CHECK(res,res);
339 int c;
340 for (c=0; c<bc; c++) {
341 gsl_vector_complex_const_view colb = gsl_matrix_complex_const_column (M(b), c);
342 gsl_vector_complex_view colr = gsl_matrix_complex_column (M(r), c);
343 res = gsl_linalg_complex_LU_solve (LU, p, V(colb), V(colr));
344 CHECK(res,res);
345 }
346 gsl_permutation_free(p);
347 gsl_matrix_complex_free(LU);
348 OK
349}
350
351
352int luRaux(KRMAT(a),RVEC(b)) {
353 REQUIRES(ar==ac && bn==ar*ar+ar+1,BAD_SIZE);
354 DEBUGMSG("luRaux");
355 KDMVIEW(a);
356 //DVVIEW(b);
357 gsl_matrix_view LU = gsl_matrix_view_array(bp,ar,ac);
358 int s;
359 gsl_permutation * p = gsl_permutation_alloc (ar);
360 CHECK(!p,MEM);
361 CHECK(gsl_matrix_memcpy(M(LU),M(a)),MEM);
362 gsl_linalg_LU_decomp(M(LU), p, &s);
363 bp[ar*ar] = s;
364 int k;
365 for (k=0; k<ar; k++) {
366 bp[ar*ar+k+1] = gsl_permutation_get(p,k);
367 }
368 gsl_permutation_free(p);
369 OK
370}
371
372int luCaux(KCMAT(a),CVEC(b)) {
373 REQUIRES(ar==ac && bn==ar*ar+ar+1,BAD_SIZE);
374 DEBUGMSG("luCaux");
375 KCMVIEW(a);
376 //DVVIEW(b);
377 gsl_matrix_complex_view LU = gsl_matrix_complex_view_array((double*)bp,ar,ac);
378 int s;
379 gsl_permutation * p = gsl_permutation_alloc (ar);
380 CHECK(!p,MEM);
381 CHECK(gsl_matrix_complex_memcpy(M(LU),M(a)),MEM);
382 int res;
383 res = gsl_linalg_complex_LU_decomp(M(LU), p, &s);
384 CHECK(res,res);
385 ((double*)bp)[2*ar*ar] = s;
386 ((double*)bp)[2*ar*ar+1] = 0;
387 int k;
388 for (k=0; k<ar; k++) {
389 ((double*)bp)[2*ar*ar+2*k+2] = gsl_permutation_get(p,k);
390 ((double*)bp)[2*ar*ar+2*k+2+1] = 0;
391 }
392 gsl_permutation_free(p);
393 OK
394}
395
396int svd(KRMAT(a),RMAT(u), RVEC(s),RMAT(v)) {
397 REQUIRES(ar==ur && ac==uc && ac==sn && ac==vr && ac==vc,BAD_SIZE);
398 DEBUGMSG("svd");
399 KDMVIEW(a);
400 DMVIEW(u);
401 DVVIEW(s);
402 DMVIEW(v);
403 gsl_vector *workv = gsl_vector_alloc(ac);
404 CHECK(!workv,MEM);
405 gsl_matrix *workm = gsl_matrix_alloc(ac,ac);
406 CHECK(!workm,MEM);
407 CHECK(gsl_matrix_memcpy(M(u),M(a)),MEM);
408 // int res = gsl_linalg_SV_decomp_jacobi (&U.matrix, &V.matrix, &S.vector);
409 // doesn't work
410 //int res = gsl_linalg_SV_decomp (&U.matrix, &V.matrix, &S.vector, workv);
411 int res = gsl_linalg_SV_decomp_mod (M(u), workm, M(v), V(s), workv);
412 CHECK(res,res);
413 //gsl_matrix_transpose(M(v));
414 gsl_vector_free(workv);
415 gsl_matrix_free(workm);
416 OK
417}
418
419
420// for real symmetric matrices
421int eigensystemR(KRMAT(x),RVEC(l),RMAT(v)) {
422 REQUIRES(xr==xc && xr==ln && xr==vr && vr==vc,BAD_SIZE);
423 DEBUGMSG("eigensystemR (gsl_eigen_symmv)");
424 KDMVIEW(x);
425 DVVIEW(l);
426 DMVIEW(v);
427 gsl_matrix *XC = gsl_matrix_alloc(xr,xr);
428 gsl_matrix_memcpy(XC,M(x)); // needed because the argument is destroyed
429 // many thanks to Nico Mahlo for the bug report
430 gsl_eigen_symmv_workspace * w = gsl_eigen_symmv_alloc (xc);
431 int res = gsl_eigen_symmv (XC, V(l), M(v), w);
432 CHECK(res,res);
433 gsl_eigen_symmv_free (w);
434 gsl_matrix_free(XC);
435 gsl_eigen_symmv_sort (V(l), M(v), GSL_EIGEN_SORT_ABS_DESC);
436 OK
437}
438
439// for hermitian matrices
440int eigensystemC(KCMAT(x),RVEC(l),CMAT(v)) {
441 REQUIRES(xr==xc && xr==ln && xr==vr && vr==vc,BAD_SIZE);
442 DEBUGMSG("eigensystemC");
443 KCMVIEW(x);
444 DVVIEW(l);
445 CMVIEW(v);
446 gsl_matrix_complex *XC = gsl_matrix_complex_alloc(xr,xr);
447 gsl_matrix_complex_memcpy(XC,M(x)); // again needed because the argument is destroyed
448 gsl_eigen_hermv_workspace * w = gsl_eigen_hermv_alloc (xc);
449 int res = gsl_eigen_hermv (XC, V(l), M(v), w);
450 CHECK(res,res);
451 gsl_eigen_hermv_free (w);
452 gsl_matrix_complex_free(XC);
453 gsl_eigen_hermv_sort (V(l), M(v), GSL_EIGEN_SORT_ABS_DESC);
454 OK
455}
456
457int QR(KRMAT(x),RMAT(q),RMAT(r)) {
458 REQUIRES(xr==rr && xc==rc && qr==qc && xr==qr,BAD_SIZE);
459 DEBUGMSG("QR");
460 KDMVIEW(x);
461 DMVIEW(q);
462 DMVIEW(r);
463 gsl_matrix * a = gsl_matrix_alloc(xr,xc);
464 gsl_vector * tau = gsl_vector_alloc(MIN(xr,xc));
465 gsl_matrix_memcpy(a,M(x));
466 int res = gsl_linalg_QR_decomp(a,tau);
467 CHECK(res,res);
468 gsl_linalg_QR_unpack(a,tau,M(q),M(r));
469 gsl_vector_free(tau);
470 gsl_matrix_free(a);
471 OK
472}
473
474int QRpacked(KRMAT(x),RMAT(qr),RVEC(tau)) {
475 //REQUIRES(xr==rr && xc==rc && qr==qc && xr==qr,BAD_SIZE);
476 DEBUGMSG("QRpacked");
477 KDMVIEW(x);
478 DMVIEW(qr);
479 DVVIEW(tau);
480 //gsl_matrix * a = gsl_matrix_alloc(xr,xc);
481 //gsl_vector * tau = gsl_vector_alloc(MIN(xr,xc));
482 gsl_matrix_memcpy(M(qr),M(x));
483 int res = gsl_linalg_QR_decomp(M(qr),V(tau));
484 CHECK(res,res);
485 OK
486}
487
488
489int QRunpack(KRMAT(xqr),KRVEC(tau),RMAT(q),RMAT(r)) {
490 //REQUIRES(xr==rr && xc==rc && qr==qc && xr==qr,BAD_SIZE);
491 DEBUGMSG("QRunpack");
492 KDMVIEW(xqr);
493 KDVVIEW(tau);
494 DMVIEW(q);
495 DMVIEW(r);
496 gsl_linalg_QR_unpack(M(xqr),V(tau),M(q),M(r));
497 OK
498}
499
500
501int cholR(KRMAT(x),RMAT(l)) {
502 REQUIRES(xr==xc && lr==xr && lr==lc,BAD_SIZE);
503 DEBUGMSG("cholR");
504 KDMVIEW(x);
505 DMVIEW(l);
506 gsl_matrix_memcpy(M(l),M(x));
507 int res = gsl_linalg_cholesky_decomp(M(l));
508 CHECK(res,res);
509 int r,c;
510 for (r=0; r<xr-1; r++) {
511 for(c=r+1; c<xc; c++) {
512 lp[r*lc+c] = 0.;
513 }
514 }
515 OK
516}
517
518int cholC(KCMAT(x),CMAT(l)) {
519 REQUIRES(xr==xc && lr==xr && lr==lc,BAD_SIZE);
520 DEBUGMSG("cholC");
521 KCMVIEW(x);
522 CMVIEW(l);
523 gsl_matrix_complex_memcpy(M(l),M(x));
524 int res = 0; // gsl_linalg_complex_cholesky_decomp(M(l));
525 CHECK(res,res);
526 gsl_complex zero = {0.,0.};
527 int r,c;
528 for (r=0; r<xr-1; r++) {
529 for(c=r+1; c<xc; c++) {
530 lp[r*lc+c] = zero;
531 }
532 }
533 OK
534}
535
536int fft(int code, KCVEC(X), CVEC(R)) { 250int fft(int code, KCVEC(X), CVEC(R)) {
537 REQUIRES(Xn == Rn,BAD_SIZE); 251 REQUIRES(Xn == Rn,BAD_SIZE);
538 DEBUGMSG("fft"); 252 DEBUGMSG("fft");
diff --git a/lib/Numeric/GSL/gsl-aux.h b/lib/Numeric/GSL/gsl-aux.h
index eee15e7..cd17ef0 100644
--- a/lib/Numeric/GSL/gsl-aux.h
+++ b/lib/Numeric/GSL/gsl-aux.h
@@ -26,25 +26,6 @@ int zipR(int code, KRVEC(a), KRVEC(b), RVEC(r));
26int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r)); 26int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r));
27 27
28 28
29int luSolveR(KRMAT(a),KRMAT(b),RMAT(r));
30int luSolveC(KCMAT(a),KCMAT(b),CMAT(r));
31int luRaux(KRMAT(a),RVEC(b));
32int luCaux(KCMAT(a),CVEC(b));
33
34int svd(KRMAT(x),RMAT(u), RVEC(s),RMAT(v));
35
36int eigensystemR(KRMAT(x),RVEC(l),RMAT(v));
37int eigensystemC(KCMAT(x),RVEC(l),CMAT(v));
38
39int QR(KRMAT(x),RMAT(q),RMAT(r));
40
41int QRpacked(KRMAT(x),RMAT(qr),RVEC(tau));
42int QRunpack(KRMAT(qr),KRVEC(tau),RMAT(q),RMAT(r));
43
44int cholR(KRMAT(x),RMAT(l));
45
46int cholC(KCMAT(x),CMAT(l));
47
48int fft(int code, KCVEC(a), CVEC(b)); 29int fft(int code, KCVEC(a), CVEC(b));
49 30
50int integrate_qng(double f(double, void*), double a, double b, double prec, 31int integrate_qng(double f(double, void*), double a, double b, double prec,
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index bbc5986..c7118c1 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -20,6 +20,7 @@ imported from "Numeric.LinearAlgebra.LAPACK".
20 20
21module Numeric.LinearAlgebra.Algorithms ( 21module Numeric.LinearAlgebra.Algorithms (
22-- * Linear Systems 22-- * Linear Systems
23 multiply, dot,
23 linearSolve, 24 linearSolve,
24 inv, pinv, 25 inv, pinv,
25 pinvTol, det, rank, rcond, 26 pinvTol, det, rank, rcond,
@@ -51,6 +52,8 @@ module Numeric.LinearAlgebra.Algorithms (
51-- * Misc 52-- * Misc
52 ctrans, 53 ctrans,
53 eps, i, 54 eps, i,
55 outer, kronecker,
56 mulH,
54-- * Util 57-- * Util
55 haussholder, 58 haussholder,
56 unpackQR, unpackHess, 59 unpackQR, unpackHess,
@@ -60,13 +63,14 @@ module Numeric.LinearAlgebra.Algorithms (
60 63
61import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) 64import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//))
62import Data.Packed 65import Data.Packed
63import qualified Numeric.GSL.Matrix as GSL
64import Numeric.GSL.Vector 66import Numeric.GSL.Vector
65import Numeric.LinearAlgebra.LAPACK as LAPACK 67import Numeric.LinearAlgebra.LAPACK as LAPACK
66import Complex 68import Complex
67import Numeric.LinearAlgebra.Linear 69import Numeric.LinearAlgebra.Linear
68import Data.List(foldl1') 70import Data.List(foldl1')
69import Data.Array 71import Data.Array
72import Foreign
73import Foreign.C.Types
70 74
71-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 75-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
72class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 76class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
@@ -105,6 +109,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
105 schur :: Matrix t -> (Matrix t, Matrix t) 109 schur :: Matrix t -> (Matrix t, Matrix t)
106 -- | Conjugate transpose. 110 -- | Conjugate transpose.
107 ctrans :: Matrix t -> Matrix t 111 ctrans :: Matrix t -> Matrix t
112 multiply :: Matrix t -> Matrix t -> Matrix t
108 113
109 114
110instance Field Double where 115instance Field Double where
@@ -116,9 +121,10 @@ instance Field Double where
116 eig = eigR 121 eig = eigR
117 eigSH' = eigS 122 eigSH' = eigS
118 cholSH = cholS 123 cholSH = cholS
119 qr = GSL.unpackQR . qrR 124 qr = unpackQR . qrR
120 hess = unpackHess hessR 125 hess = unpackHess hessR
121 schur = schurR 126 schur = schurR
127 multiply = multiplyR3
122 128
123instance Field (Complex Double) where 129instance Field (Complex Double) where
124 svd = svdC 130 svd = svdC
@@ -132,6 +138,8 @@ instance Field (Complex Double) where
132 qr = unpackQR . qrC 138 qr = unpackQR . qrC
133 hess = unpackHess hessC 139 hess = unpackHess hessC
134 schur = schurC 140 schur = schurC
141 multiply = mulCW -- workaround
142 -- multiplyC3
135 143
136-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. 144-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev.
137-- 145--
@@ -501,3 +509,162 @@ luFact (lu,perm) | r <= c = (l ,u ,p, s)
501 u' = takeRows c (lu |*| tu) 509 u' = takeRows c (lu |*| tu)
502 (|+|) = add 510 (|+|) = add
503 (|*|) = mul 511 (|*|) = mul
512
513--------------------------------------------------
514
515-- | euclidean inner product
516dot :: (Field t) => Vector t -> Vector t -> t
517dot u v = multiply r c @@> (0,0)
518 where r = asRow u
519 c = asColumn v
520
521
522{- | Outer product of two vectors.
523
524@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
525(3><3)
526 [ 5.0, 2.0, 3.0
527 , 10.0, 4.0, 6.0
528 , 15.0, 6.0, 9.0 ]@
529-}
530outer :: (Field t) => Vector t -> Vector t -> Matrix t
531outer u v = asColumn u `multiply` asRow v
532
533{- | Kronecker product of two matrices.
534
535@m1=(2><3)
536 [ 1.0, 2.0, 0.0
537 , 0.0, -1.0, 3.0 ]
538m2=(4><3)
539 [ 1.0, 2.0, 3.0
540 , 4.0, 5.0, 6.0
541 , 7.0, 8.0, 9.0
542 , 10.0, 11.0, 12.0 ]@
543
544@\> kronecker m1 m2
545(8><9)
546 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
547 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
548 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
549 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
550 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
551 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
552 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
553 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
554-}
555kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t
556kronecker a b = fromBlocks
557 . partit (cols a)
558 . map (reshape (cols b))
559 . toRows
560 $ flatten a `outer` flatten b
561
562---------------------------------------------------------------------
563-- reference multiply
564---------------------------------------------------------------------
565
566mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ]
567 where dot u v = sum $ zipWith (*) (toList u) (toList v)
568
569-----------------------------------------------------------------------------------
570-- workaround
571-----------------------------------------------------------------------------------
572
573mulCW a b = toComplex (rr,ri)
574 where rr = multiply ar br `sub` multiply ai bi
575 ri = multiply ar bi `add` multiply ai br
576 (ar,ai) = fromComplex a
577 (br,bi) = fromComplex b
578
579-----------------------------------------------------------------------------------
580-- Direct CBLAS
581-----------------------------------------------------------------------------------
582
583newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show)
584newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show)
585
586rowMajor, colMajor :: CBLASOrder
587rowMajor = CBLASOrder 101
588colMajor = CBLASOrder 102
589
590noTrans, trans', conjTrans :: CBLASTrans
591noTrans = CBLASTrans 111
592trans' = CBLASTrans 112
593conjTrans = CBLASTrans 113
594
595foreign import ccall "cblas.h cblas_dgemm"
596 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO ()
597
598
599multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
600multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b)
601 where
602 multiply3 f st a b
603 | cols a == rows b = unsafePerformIO $ do
604 s <- createMatrix ColumnMajor (rows a) (cols b)
605 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0
606 app3 g mat a mat b mat s st
607 return s
608 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
609
610
611foreign import ccall "cblas.h cblas_zgemm"
612 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO ()
613
614multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
615multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b)
616 where
617 multiply3 f st a b
618 | cols a == rows b = do
619 s <- createMatrix ColumnMajor (rows a) (cols b)
620 palpha <- new 1
621 pbeta <- new 0
622 let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0
623 app3 g mat a mat b mat s st
624 free palpha
625 free pbeta
626 return s
627 -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s))
628 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
629
630-----------------------------------------------------------------------------------
631-- BLAS via auxiliary C
632-----------------------------------------------------------------------------------
633
634foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM
635foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM
636
637multiply2 f st a b
638 | cols a == rows b = unsafePerformIO $ do
639 s <- createMatrix ColumnMajor (rows a) (cols b)
640 app3 f mat a mat b mat s st
641 if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s))
642 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
643
644multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double
645multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b)
646
647multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
648multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b)
649
650-----------------------------------------------------------------------------------
651-- direct C multiplication
652-----------------------------------------------------------------------------------
653
654foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM
655foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM
656
657cmultiply f st a b
658-- | cols a == rows b =
659 = unsafePerformIO $ do
660 s <- createMatrix RowMajor (rows a) (cols b)
661 app3 f mat a mat b mat s st
662 if toLists s== toLists s then return s else error $ "BRUTAL " ++ (show (toLists s))
663 -- return s
664-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
665
666multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
667multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b)
668
669multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
670multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b)
diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs
index 4a9b309..0ae9698 100644
--- a/lib/Numeric/LinearAlgebra/Interface.hs
+++ b/lib/Numeric/LinearAlgebra/Interface.hs
@@ -29,7 +29,7 @@ import Numeric.LinearAlgebra.Algorithms
29class Mul a b c | a b -> c where 29class Mul a b c | a b -> c where
30 infixl 7 <> 30 infixl 7 <>
31 -- | matrix product 31 -- | matrix product
32 (<>) :: Element t => a t -> b t -> c t 32 (<>) :: Field t => a t -> b t -> c t
33 33
34instance Mul Matrix Matrix Matrix where 34instance Mul Matrix Matrix Matrix where
35 (<>) = multiply 35 (<>) = multiply
@@ -43,7 +43,7 @@ instance Mul Vector Matrix Vector where
43--------------------------------------------------- 43---------------------------------------------------
44 44
45-- | @u \<.\> v = dot u v@ 45-- | @u \<.\> v = dot u v@
46(<.>) :: (Element t) => Vector t -> Vector t -> t 46(<.>) :: (Field t) => Vector t -> Vector t -> t
47infixl 7 <.> 47infixl 7 <.>
48(<.>) = dot 48(<.>) = dot
49 49
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index 310f6ee..0dccea2 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -814,3 +814,77 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) {
814 free(auxipiv); 814 free(auxipiv);
815 OK 815 OK
816} 816}
817
818////////////////////////////////////////////////////////////
819
820int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) {
821 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
822 int i,j,k;
823 for (i=0;i<ar;i++) {
824 for(j=0;j<bc;j++) {
825 double temp = 0;
826 for(k=0;k<ac;k++) {
827 temp += ap[i*ac+k]*bp[k*bc+j];
828 }
829 rp[i*rc+j] = temp;
830 }
831 }
832 OK
833}
834
835int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) {
836 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
837 int i,j,k;
838 for (i=0;i<ar;i++) {
839 for(j=0;j<bc;j++) {
840 doublecomplex temp = {0,0};
841 for(k=0;k<ac;k++) {
842 doublecomplex aik = ((doublecomplex*)ap)[i*ac+k];
843 doublecomplex bkj = ((doublecomplex*)bp)[k*bc+j];
844 //double w = aik.r+aik.i+bkj.r+bkj.i;
845 //if (w>w) exit(1);
846 //printf("%d",w>w);
847 temp.r += aik.r * bkj.r - aik.i * bkj.i;
848 temp.i += aik.r * bkj.i + aik.i * bkj.r;
849 //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i);
850 //printf("%f %f %f \n",w,temp.r,temp.i);
851
852 }
853 ((doublecomplex*)rp)[i*rc+j] = temp;
854 //printf("%f %f\n",temp.r,temp.i);
855 }
856 }
857 OK
858}
859
860void dgemm_(char *, char *, integer *, integer *, integer *,
861 double *, const double *, integer *, const double *,
862 integer *, double *, double *, integer *);
863
864void zgemm_(char *, char *, integer *, integer *, integer *,
865 doublecomplex *, const doublecomplex *, integer *, const doublecomplex *,
866 integer *, doublecomplex *, doublecomplex *, integer *);
867
868
869int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) {
870 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
871 double alpha = 1;
872 double beta = 0;
873 integer m = ar;
874 integer n = bc;
875 integer k = ac;
876 int i,j;
877 dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m);
878 OK
879}
880
881int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) {
882 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
883 integer m = ar;
884 integer n = bc;
885 integer k = ac;
886 doublecomplex alpha = {1,0};
887 doublecomplex beta = {0,0};
888 zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m);
889 OK
890}
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index 79e52be..c0361a6 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -84,3 +84,9 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s));
84 84
85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); 85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r));
86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); 86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r));
87
88int multiplyR(KDMAT(a),KDMAT(b),DMAT(r));
89int multiplyC(KCMAT(a),KCMAT(b),CMAT(r));
90
91int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r));
92int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r));
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs
index 0ddbb55..1bf8b04 100644
--- a/lib/Numeric/LinearAlgebra/Linear.hs
+++ b/lib/Numeric/LinearAlgebra/Linear.hs
@@ -15,12 +15,11 @@ Basic optimized operations on vectors and matrices.
15----------------------------------------------------------------------------- 15-----------------------------------------------------------------------------
16 16
17module Numeric.LinearAlgebra.Linear ( 17module Numeric.LinearAlgebra.Linear (
18 Linear(..), 18 Linear(..)
19 multiply, dot, outer, kronecker
20) where 19) where
21 20
22 21
23import Data.Packed.Internal(multiply,partit) 22import Data.Packed.Internal(partit)
24import Data.Packed 23import Data.Packed
25import Numeric.GSL.Vector 24import Numeric.GSL.Vector
26import Complex 25import Complex
@@ -69,52 +68,3 @@ instance (Linear Vector a, Container Matrix a) => (Linear Matrix a) where
69 mul = liftMatrix2 mul 68 mul = liftMatrix2 mul
70 divide = liftMatrix2 divide 69 divide = liftMatrix2 divide
71 equal a b = cols a == cols b && flatten a `equal` flatten b 70 equal a b = cols a == cols b && flatten a `equal` flatten b
72
73--------------------------------------------------
74
75-- | euclidean inner product
76dot :: (Element t) => Vector t -> Vector t -> t
77dot u v = multiply r c @@> (0,0)
78 where r = asRow u
79 c = asColumn v
80
81
82{- | Outer product of two vectors.
83
84@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
85(3><3)
86 [ 5.0, 2.0, 3.0
87 , 10.0, 4.0, 6.0
88 , 15.0, 6.0, 9.0 ]@
89-}
90outer :: (Element t) => Vector t -> Vector t -> Matrix t
91outer u v = asColumn u `multiply` asRow v
92
93{- | Kronecker product of two matrices.
94
95@m1=(2><3)
96 [ 1.0, 2.0, 0.0
97 , 0.0, -1.0, 3.0 ]
98m2=(4><3)
99 [ 1.0, 2.0, 3.0
100 , 4.0, 5.0, 6.0
101 , 7.0, 8.0, 9.0
102 , 10.0, 11.0, 12.0 ]@
103
104@\> kronecker m1 m2
105(8><9)
106 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
107 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
108 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
109 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
110 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
111 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
112 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
113 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@
114-}
115kronecker :: (Element t) => Matrix t -> Matrix t -> Matrix t
116kronecker a b = fromBlocks
117 . partit (cols a)
118 . map (reshape (cols b))
119 . toRows
120 $ flatten a `outer` flatten b
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 7b28075..07b9f63 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -123,6 +123,11 @@ runTests :: Int -- ^ maximum dimension
123runTests n = do 123runTests n = do
124 setErrorHandlerOff 124 setErrorHandlerOff
125 let test p = qCheck n p 125 let test p = qCheck n p
126 putStrLn "------ mult"
127 test (multProp1 . rConsist)
128 test (multProp1 . cConsist)
129 test (multProp2 . rConsist)
130 test (multProp2 . cConsist)
126 putStrLn "------ lu" 131 putStrLn "------ lu"
127 test (luProp . rM) 132 test (luProp . rM)
128 test (luProp . cM) 133 test (luProp . cM)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
index af486c8..e7fecf2 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -20,6 +20,7 @@ module Numeric.LinearAlgebra.Tests.Instances(
20 WC(..), rWC,cWC, 20 WC(..), rWC,cWC,
21 SqWC(..), rSqWC, cSqWC, 21 SqWC(..), rSqWC, cSqWC,
22 PosDef(..), rPosDef, cPosDef, 22 PosDef(..), rPosDef, cPosDef,
23 Consistent(..), rConsist, cConsist,
23 RM,CM, rM,cM 24 RM,CM, rM,cM
24) where 25) where
25 26
@@ -116,6 +117,19 @@ instance (Field a, Arbitrary a) => Arbitrary (PosDef a) where
116 return $ PosDef (0.5 .* p + 0.5 .* ctrans p) 117 return $ PosDef (0.5 .* p + 0.5 .* ctrans p)
117 coarbitrary = undefined 118 coarbitrary = undefined
118 119
120-- a pair of matrices that can be multiplied
121newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show
122instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where
123 arbitrary = do
124 n <- chooseDim
125 k <- chooseDim
126 m <- chooseDim
127 la <- vector (n*k)
128 lb <- vector (k*m)
129 return $ Consistent ((n><k) la, (k><m) lb)
130 coarbitrary = undefined
131
132
119type RM = Matrix Double 133type RM = Matrix Double
120type CM = Matrix (Complex Double) 134type CM = Matrix (Complex Double)
121 135
@@ -140,3 +154,5 @@ cSqWC (SqWC m) = m :: CM
140rPosDef (PosDef m) = m :: RM 154rPosDef (PosDef m) = m :: RM
141cPosDef (PosDef m) = m :: CM 155cPosDef (PosDef m) = m :: CM
142 156
157rConsist (Consistent (a,b)) = (a,b::RM)
158cConsist (Consistent (a,b)) = (a,b::CM)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index 55e9a1b..5663b86 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -34,7 +34,8 @@ module Numeric.LinearAlgebra.Tests.Properties (
34 hessProp, 34 hessProp,
35 schurProp1, schurProp2, 35 schurProp1, schurProp2,
36 cholProp, 36 cholProp,
37 expmDiagProp 37 expmDiagProp,
38 multProp1, multProp2
38) where 39) where
39 40
40import Numeric.LinearAlgebra 41import Numeric.LinearAlgebra
@@ -151,3 +152,6 @@ cholProp m = m |~| ctrans c <> c && upperTriang c
151expmDiagProp m = expm (logm m) :~ 7 ~: complex m 152expmDiagProp m = expm (logm m) :~ 7 ~: complex m
152 where logm m = matFunc log m 153 where logm m = matFunc log m
153 154
155multProp1 (a,b) = a <> b |~| mulH a b
156
157multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a