diff options
Diffstat (limited to 'lib/Numeric')
-rw-r--r-- | lib/Numeric/GSL/Matrix.hs | 311 | ||||
-rw-r--r-- | lib/Numeric/GSL/gsl-aux.c | 286 | ||||
-rw-r--r-- | lib/Numeric/GSL/gsl-aux.h | 19 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 171 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Interface.hs | 4 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 74 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 6 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Linear.hs | 54 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 5 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Instances.hs | 16 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Properties.hs | 6 |
11 files changed, 279 insertions, 673 deletions
diff --git a/lib/Numeric/GSL/Matrix.hs b/lib/Numeric/GSL/Matrix.hs deleted file mode 100644 index f62bb82..0000000 --- a/lib/Numeric/GSL/Matrix.hs +++ /dev/null | |||
@@ -1,311 +0,0 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Numeric.GSL.Matrix | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable (uses FFI) | ||
10 | -- | ||
11 | -- A few linear algebra computations based on GSL. | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | -- #hide | ||
15 | |||
16 | module Numeric.GSL.Matrix( | ||
17 | eigSg, eigHg, | ||
18 | svdg, | ||
19 | qr, qrPacked, unpackQR, | ||
20 | cholR, cholC, | ||
21 | luSolveR, luSolveC, | ||
22 | luR, luC | ||
23 | ) where | ||
24 | |||
25 | import Data.Packed.Internal | ||
26 | import Data.Packed.Matrix(ident) | ||
27 | import Numeric.GSL.Vector | ||
28 | import Foreign | ||
29 | import Complex | ||
30 | |||
31 | {- | eigendecomposition of a real symmetric matrix using /gsl_eigen_symmv/. | ||
32 | |||
33 | > > let (l,v) = eigS $ 'fromLists' [[1,2],[2,1]] | ||
34 | > > l | ||
35 | > 3.000 -1.000 | ||
36 | > | ||
37 | > > v | ||
38 | > 0.707 -0.707 | ||
39 | > 0.707 0.707 | ||
40 | > | ||
41 | > > v <> diag l <> trans v | ||
42 | > 1.000 2.000 | ||
43 | > 2.000 1.000 | ||
44 | |||
45 | -} | ||
46 | eigSg :: Matrix Double -> (Vector Double, Matrix Double) | ||
47 | eigSg = eigSg' . cmat | ||
48 | |||
49 | eigSg' m | ||
50 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) | ||
51 | | otherwise = unsafePerformIO $ do | ||
52 | l <- createVector r | ||
53 | v <- createMatrix RowMajor r r | ||
54 | app3 c_eigS mat m vec l mat v "eigSg" | ||
55 | return (l,v) | ||
56 | where r = rows m | ||
57 | foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM | ||
58 | |||
59 | ------------------------------------------------------------------ | ||
60 | |||
61 | |||
62 | |||
63 | {- | eigendecomposition of a complex hermitian matrix using /gsl_eigen_hermv/ | ||
64 | |||
65 | > > let (l,v) = eigH $ 'fromLists' [[1,2+i],[2-i,3]] | ||
66 | > | ||
67 | > > l | ||
68 | > 4.449 -0.449 | ||
69 | > | ||
70 | > > v | ||
71 | > -0.544 0.839 | ||
72 | > (-0.751,0.375) (-0.487,0.243) | ||
73 | > | ||
74 | > > v <> diag l <> (conjTrans) v | ||
75 | > 1.000 (2.000,1.000) | ||
76 | > (2.000,-1.000) 3.000 | ||
77 | |||
78 | -} | ||
79 | eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) | ||
80 | eigHg = eigHg' . cmat | ||
81 | |||
82 | eigHg' m | ||
83 | | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) | ||
84 | | otherwise = unsafePerformIO $ do | ||
85 | l <- createVector r | ||
86 | v <- createMatrix RowMajor r r | ||
87 | app3 c_eigH mat m vec l mat v "eigHg" | ||
88 | return (l,v) | ||
89 | where r = rows m | ||
90 | foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM | ||
91 | |||
92 | |||
93 | {- | Singular value decomposition of a real matrix, using /gsl_linalg_SV_decomp_mod/: | ||
94 | |||
95 | |||
96 | -} | ||
97 | svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) | ||
98 | svdg x = if rows x >= cols x | ||
99 | then svd' (cmat x) | ||
100 | else (v, s, u) where (u,s,v) = svd' (cmat (trans x)) | ||
101 | |||
102 | svd' x = unsafePerformIO $ do | ||
103 | u <- createMatrix RowMajor r c | ||
104 | s <- createVector c | ||
105 | v <- createMatrix RowMajor c c | ||
106 | app4 c_svd mat x mat u vec s mat v "svdg" | ||
107 | return (u,s,v) | ||
108 | where r = rows x | ||
109 | c = cols x | ||
110 | foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM | ||
111 | |||
112 | {- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/. | ||
113 | |||
114 | -} | ||
115 | qr :: Matrix Double -> (Matrix Double, Matrix Double) | ||
116 | qr = qr' . cmat | ||
117 | |||
118 | qr' x = unsafePerformIO $ do | ||
119 | q <- createMatrix RowMajor r r | ||
120 | rot <- createMatrix RowMajor r c | ||
121 | app3 c_qr mat x mat q mat rot "qr" | ||
122 | return (q,rot) | ||
123 | where r = rows x | ||
124 | c = cols x | ||
125 | foreign import ccall "gsl-aux.h QR" c_qr :: TMMM | ||
126 | |||
127 | qrPacked :: Matrix Double -> (Matrix Double, Vector Double) | ||
128 | qrPacked = qrPacked' . cmat | ||
129 | |||
130 | qrPacked' x = unsafePerformIO $ do | ||
131 | qrp <- createMatrix RowMajor r c | ||
132 | tau <- createVector (min r c) | ||
133 | app3 c_qrPacked mat x mat qrp vec tau "qrUnpacked" | ||
134 | return (qrp,tau) | ||
135 | where r = rows x | ||
136 | c = cols x | ||
137 | foreign import ccall "gsl-aux.h QRpacked" c_qrPacked :: TMMV | ||
138 | |||
139 | unpackQR :: (Matrix Double, Vector Double) -> (Matrix Double, Matrix Double) | ||
140 | unpackQR (qrp,tau) = unpackQR' (cmat qrp, tau) | ||
141 | |||
142 | unpackQR' (qrp,tau) = unsafePerformIO $ do | ||
143 | q <- createMatrix RowMajor r r | ||
144 | res <- createMatrix RowMajor r c | ||
145 | app4 c_qrUnpack mat qrp vec tau mat q mat res "qrUnpack" | ||
146 | return (q,res) | ||
147 | where r = rows qrp | ||
148 | c = cols qrp | ||
149 | foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM | ||
150 | |||
151 | {- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/. | ||
152 | |||
153 | @\> chol $ (2><2) [1,2, | ||
154 | 2,9::Double] | ||
155 | (2><2) | ||
156 | [ 1.0, 0.0 | ||
157 | , 2.0, 2.23606797749979 ]@ | ||
158 | |||
159 | -} | ||
160 | cholR :: Matrix Double -> Matrix Double | ||
161 | cholR = cholR' . cmat | ||
162 | |||
163 | cholR' x = unsafePerformIO $ do | ||
164 | r <- createMatrix RowMajor n n | ||
165 | app2 c_cholR mat x mat r "cholR" | ||
166 | return r | ||
167 | where n = rows x | ||
168 | foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM | ||
169 | |||
170 | cholC :: Matrix (Complex Double) -> Matrix (Complex Double) | ||
171 | cholC = cholC' . cmat | ||
172 | |||
173 | cholC' x = unsafePerformIO $ do | ||
174 | r <- createMatrix RowMajor n n | ||
175 | app2 c_cholC mat x mat r "cholC" | ||
176 | return r | ||
177 | where n = rows x | ||
178 | foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM | ||
179 | |||
180 | |||
181 | -------------------------------------------------------- | ||
182 | |||
183 | {- -| efficient multiplication by the inverse of a matrix (for real matrices) | ||
184 | -} | ||
185 | luSolveR :: Matrix Double -> Matrix Double -> Matrix Double | ||
186 | luSolveR a b = luSolveR' (cmat a) (cmat b) | ||
187 | |||
188 | luSolveR' a b | ||
189 | | n1==n2 && n1==r = unsafePerformIO $ do | ||
190 | s <- createMatrix RowMajor r c | ||
191 | app3 c_luSolveR mat a mat b mat s "luSolveR" | ||
192 | return s | ||
193 | | otherwise = error "luSolveR of nonsquare matrix" | ||
194 | where n1 = rows a | ||
195 | n2 = cols a | ||
196 | r = rows b | ||
197 | c = cols b | ||
198 | foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM | ||
199 | |||
200 | {- -| efficient multiplication by the inverse of a matrix (for complex matrices). | ||
201 | -} | ||
202 | luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
203 | luSolveC a b = luSolveC' (cmat a) (cmat b) | ||
204 | |||
205 | luSolveC' a b | ||
206 | | n1==n2 && n1==r = unsafePerformIO $ do | ||
207 | s <- createMatrix RowMajor r c | ||
208 | app3 c_luSolveC mat a mat b mat s "luSolveC" | ||
209 | return s | ||
210 | | otherwise = error "luSolveC of nonsquare matrix" | ||
211 | where n1 = rows a | ||
212 | n2 = cols a | ||
213 | r = rows b | ||
214 | c = cols b | ||
215 | foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM | ||
216 | |||
217 | {- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) | ||
218 | -} | ||
219 | luRaux :: Matrix Double -> Vector Double | ||
220 | luRaux = luRaux' . cmat | ||
221 | |||
222 | luRaux' x = unsafePerformIO $ do | ||
223 | res <- createVector (r*r+r+1) | ||
224 | app2 c_luRaux mat x vec res "luRaux" | ||
225 | return res | ||
226 | where r = rows x | ||
227 | foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV | ||
228 | |||
229 | {- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) | ||
230 | -} | ||
231 | luCaux :: Matrix (Complex Double) -> Vector (Complex Double) | ||
232 | luCaux = luCaux' . cmat | ||
233 | |||
234 | luCaux' x = unsafePerformIO $ do | ||
235 | res <- createVector (r*r+r+1) | ||
236 | app2 c_luCaux mat x vec res "luCaux" | ||
237 | return res | ||
238 | where r = rows x | ||
239 | foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV | ||
240 | |||
241 | {- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/Numeric.GSL/manual/Numeric.GSL-ref_13.html#SEC223>. | ||
242 | |||
243 | @\> let m = 'fromLists' [[1,2,-3],[2+3*i,-7,0],[1,-i,2*i]] | ||
244 | \> let (l,u,p,s) = luR m@ | ||
245 | |||
246 | L is the lower triangular: | ||
247 | |||
248 | @\> l | ||
249 | 1. 0. 0. | ||
250 | 0.154-0.231i 1. 0. | ||
251 | 0.154-0.231i 0.624-0.522i 1.@ | ||
252 | |||
253 | U is the upper triangular: | ||
254 | |||
255 | @\> u | ||
256 | 2.+3.i -7. 0. | ||
257 | 0. 3.077-1.615i -3. | ||
258 | 0. 0. 1.873+0.433i@ | ||
259 | |||
260 | p is a permutation: | ||
261 | |||
262 | @\> p | ||
263 | [1,0,2]@ | ||
264 | |||
265 | L \* U obtains a permuted version of the original matrix: | ||
266 | |||
267 | @\> extractRows p m | ||
268 | 2.+3.i -7. 0. | ||
269 | 1. 2. -3. | ||
270 | 1. -1.i 2.i | ||
271 | \ -- CPP | ||
272 | \> l \<\> u | ||
273 | 2.+3.i -7. 0. | ||
274 | 1. 2. -3. | ||
275 | 1. -1.i 2.i@ | ||
276 | |||
277 | s is the sign of the permutation, required to obtain sign of the determinant: | ||
278 | |||
279 | @\> s * product ('toList' $ 'takeDiag' u) | ||
280 | (-18.0) :+ (-16.000000000000004) | ||
281 | \> 'LinearAlgebra.Algorithms.det' m | ||
282 | (-18.0) :+ (-16.000000000000004)@ | ||
283 | |||
284 | -} | ||
285 | luR :: Matrix Double -> (Matrix Double, Matrix Double, [Int], Double) | ||
286 | luR m = (l,u,p, fromIntegral s') where | ||
287 | r = rows m | ||
288 | v = luRaux m | ||
289 | lu = reshape r $ subVector 0 (r*r) v | ||
290 | s':p = map round . toList . subVector (r*r) (r+1) $ v | ||
291 | u = triang r r 0 1`mul` lu | ||
292 | l = (triang r r 0 0 `mul` lu) `add` ident r | ||
293 | add = liftMatrix2 $ vectorZipR Add | ||
294 | mul = liftMatrix2 $ vectorZipR Mul | ||
295 | |||
296 | -- | Complex version of 'luR'. | ||
297 | luC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double), [Int], Complex Double) | ||
298 | luC m = (l,u,p, fromIntegral s') where | ||
299 | r = rows m | ||
300 | v = luCaux m | ||
301 | lu = reshape r $ subVector 0 (r*r) v | ||
302 | s':p = map (round.realPart) . toList . subVector (r*r) (r+1) $ v | ||
303 | u = triang r r 0 1 `mul` lu | ||
304 | l = (triang r r 0 0 `mul` lu) `add` liftMatrix comp (ident r) | ||
305 | add = liftMatrix2 $ vectorZipC Add | ||
306 | mul = liftMatrix2 $ vectorZipC Mul | ||
307 | |||
308 | {- auxiliary function to get triangular matrices | ||
309 | -} | ||
310 | triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] | ||
311 | where el i j = if j-i>=h then v else 1 - v | ||
diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c index bd0a6bd..052cafd 100644 --- a/lib/Numeric/GSL/gsl-aux.c +++ b/lib/Numeric/GSL/gsl-aux.c | |||
@@ -1,11 +1,8 @@ | |||
1 | #include "gsl-aux.h" | 1 | #include "gsl-aux.h" |
2 | #include <gsl/gsl_blas.h> | 2 | #include <gsl/gsl_blas.h> |
3 | #include <gsl/gsl_linalg.h> | ||
4 | #include <gsl/gsl_matrix.h> | ||
5 | #include <gsl/gsl_math.h> | 3 | #include <gsl/gsl_math.h> |
6 | #include <gsl/gsl_errno.h> | 4 | #include <gsl/gsl_errno.h> |
7 | #include <gsl/gsl_fft_complex.h> | 5 | #include <gsl/gsl_fft_complex.h> |
8 | #include <gsl/gsl_eigen.h> | ||
9 | #include <gsl/gsl_integration.h> | 6 | #include <gsl/gsl_integration.h> |
10 | #include <gsl/gsl_deriv.h> | 7 | #include <gsl/gsl_deriv.h> |
11 | #include <gsl/gsl_poly.h> | 8 | #include <gsl/gsl_poly.h> |
@@ -161,47 +158,6 @@ int mapC(int code, KCVEC(x), CVEC(r)) { | |||
161 | } | 158 | } |
162 | 159 | ||
163 | 160 | ||
164 | /* | ||
165 | int scaleR(double* alpha, KRVEC(x), RVEC(r)) { | ||
166 | REQUIRES(xn == rn,BAD_SIZE); | ||
167 | DEBUGMSG("scaleR"); | ||
168 | KDVVIEW(x); | ||
169 | DVVIEW(r); | ||
170 | CHECK( gsl_vector_memcpy(V(r),V(x)) , MEM); | ||
171 | int res = gsl_vector_scale(V(r),*alpha); | ||
172 | CHECK(res,res); | ||
173 | OK | ||
174 | } | ||
175 | |||
176 | int scaleC(gsl_complex *alpha, KCVEC(x), CVEC(r)) { | ||
177 | REQUIRES(xn == rn,BAD_SIZE); | ||
178 | DEBUGMSG("scaleC"); | ||
179 | //KCVVIEW(x); | ||
180 | CVVIEW(r); | ||
181 | gsl_vector_const_view vrx = gsl_vector_const_view_array((double*)xp,xn*2); | ||
182 | gsl_vector_view vrr = gsl_vector_view_array((double*)rp,rn*2); | ||
183 | CHECK(gsl_vector_memcpy(V(vrr),V(vrx)) , MEM); | ||
184 | gsl_blas_zscal(*alpha,V(r)); // void ! | ||
185 | int res = 0; | ||
186 | CHECK(res,res); | ||
187 | OK | ||
188 | } | ||
189 | |||
190 | int addConstantR(double offs, KRVEC(x), RVEC(r)) { | ||
191 | REQUIRES(xn == rn,BAD_SIZE); | ||
192 | DEBUGMSG("addConstantR"); | ||
193 | KDVVIEW(x); | ||
194 | DVVIEW(r); | ||
195 | CHECK(gsl_vector_memcpy(V(r),V(x)), MEM); | ||
196 | int res = gsl_vector_add_constant(V(r),offs); | ||
197 | CHECK(res,res); | ||
198 | OK | ||
199 | } | ||
200 | |||
201 | */ | ||
202 | |||
203 | |||
204 | |||
205 | int mapValR(int code, double* pval, KRVEC(x), RVEC(r)) { | 161 | int mapValR(int code, double* pval, KRVEC(x), RVEC(r)) { |
206 | int k; | 162 | int k; |
207 | double val = *pval; | 163 | double val = *pval; |
@@ -291,248 +247,6 @@ int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r)) { | |||
291 | 247 | ||
292 | 248 | ||
293 | 249 | ||
294 | |||
295 | int luSolveR(KRMAT(a),KRMAT(b),RMAT(r)) { | ||
296 | REQUIRES(ar==ac && ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
297 | DEBUGMSG("luSolveR"); | ||
298 | KDMVIEW(a); | ||
299 | KDMVIEW(b); | ||
300 | DMVIEW(r); | ||
301 | int res; | ||
302 | gsl_matrix *LU = gsl_matrix_alloc(ar,ar); | ||
303 | CHECK(!LU,MEM); | ||
304 | int s; | ||
305 | gsl_permutation * p = gsl_permutation_alloc (ar); | ||
306 | CHECK(!p,MEM); | ||
307 | CHECK(gsl_matrix_memcpy(LU,M(a)),MEM); | ||
308 | res = gsl_linalg_LU_decomp(LU, p, &s); | ||
309 | CHECK(res,res); | ||
310 | int c; | ||
311 | |||
312 | for (c=0; c<bc; c++) { | ||
313 | gsl_vector_const_view colb = gsl_matrix_const_column (M(b), c); | ||
314 | gsl_vector_view colr = gsl_matrix_column (M(r), c); | ||
315 | res = gsl_linalg_LU_solve (LU, p, V(colb), V(colr)); | ||
316 | CHECK(res,res); | ||
317 | } | ||
318 | gsl_permutation_free(p); | ||
319 | gsl_matrix_free(LU); | ||
320 | OK | ||
321 | } | ||
322 | |||
323 | |||
324 | int luSolveC(KCMAT(a),KCMAT(b),CMAT(r)) { | ||
325 | REQUIRES(ar==ac && ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
326 | DEBUGMSG("luSolveC"); | ||
327 | KCMVIEW(a); | ||
328 | KCMVIEW(b); | ||
329 | CMVIEW(r); | ||
330 | gsl_matrix_complex *LU = gsl_matrix_complex_alloc(ar,ar); | ||
331 | CHECK(!LU,MEM); | ||
332 | int s; | ||
333 | gsl_permutation * p = gsl_permutation_alloc (ar); | ||
334 | CHECK(!p,MEM); | ||
335 | CHECK(gsl_matrix_complex_memcpy(LU,M(a)),MEM); | ||
336 | int res; | ||
337 | res = gsl_linalg_complex_LU_decomp(LU, p, &s); | ||
338 | CHECK(res,res); | ||
339 | int c; | ||
340 | for (c=0; c<bc; c++) { | ||
341 | gsl_vector_complex_const_view colb = gsl_matrix_complex_const_column (M(b), c); | ||
342 | gsl_vector_complex_view colr = gsl_matrix_complex_column (M(r), c); | ||
343 | res = gsl_linalg_complex_LU_solve (LU, p, V(colb), V(colr)); | ||
344 | CHECK(res,res); | ||
345 | } | ||
346 | gsl_permutation_free(p); | ||
347 | gsl_matrix_complex_free(LU); | ||
348 | OK | ||
349 | } | ||
350 | |||
351 | |||
352 | int luRaux(KRMAT(a),RVEC(b)) { | ||
353 | REQUIRES(ar==ac && bn==ar*ar+ar+1,BAD_SIZE); | ||
354 | DEBUGMSG("luRaux"); | ||
355 | KDMVIEW(a); | ||
356 | //DVVIEW(b); | ||
357 | gsl_matrix_view LU = gsl_matrix_view_array(bp,ar,ac); | ||
358 | int s; | ||
359 | gsl_permutation * p = gsl_permutation_alloc (ar); | ||
360 | CHECK(!p,MEM); | ||
361 | CHECK(gsl_matrix_memcpy(M(LU),M(a)),MEM); | ||
362 | gsl_linalg_LU_decomp(M(LU), p, &s); | ||
363 | bp[ar*ar] = s; | ||
364 | int k; | ||
365 | for (k=0; k<ar; k++) { | ||
366 | bp[ar*ar+k+1] = gsl_permutation_get(p,k); | ||
367 | } | ||
368 | gsl_permutation_free(p); | ||
369 | OK | ||
370 | } | ||
371 | |||
372 | int luCaux(KCMAT(a),CVEC(b)) { | ||
373 | REQUIRES(ar==ac && bn==ar*ar+ar+1,BAD_SIZE); | ||
374 | DEBUGMSG("luCaux"); | ||
375 | KCMVIEW(a); | ||
376 | //DVVIEW(b); | ||
377 | gsl_matrix_complex_view LU = gsl_matrix_complex_view_array((double*)bp,ar,ac); | ||
378 | int s; | ||
379 | gsl_permutation * p = gsl_permutation_alloc (ar); | ||
380 | CHECK(!p,MEM); | ||
381 | CHECK(gsl_matrix_complex_memcpy(M(LU),M(a)),MEM); | ||
382 | int res; | ||
383 | res = gsl_linalg_complex_LU_decomp(M(LU), p, &s); | ||
384 | CHECK(res,res); | ||
385 | ((double*)bp)[2*ar*ar] = s; | ||
386 | ((double*)bp)[2*ar*ar+1] = 0; | ||
387 | int k; | ||
388 | for (k=0; k<ar; k++) { | ||
389 | ((double*)bp)[2*ar*ar+2*k+2] = gsl_permutation_get(p,k); | ||
390 | ((double*)bp)[2*ar*ar+2*k+2+1] = 0; | ||
391 | } | ||
392 | gsl_permutation_free(p); | ||
393 | OK | ||
394 | } | ||
395 | |||
396 | int svd(KRMAT(a),RMAT(u), RVEC(s),RMAT(v)) { | ||
397 | REQUIRES(ar==ur && ac==uc && ac==sn && ac==vr && ac==vc,BAD_SIZE); | ||
398 | DEBUGMSG("svd"); | ||
399 | KDMVIEW(a); | ||
400 | DMVIEW(u); | ||
401 | DVVIEW(s); | ||
402 | DMVIEW(v); | ||
403 | gsl_vector *workv = gsl_vector_alloc(ac); | ||
404 | CHECK(!workv,MEM); | ||
405 | gsl_matrix *workm = gsl_matrix_alloc(ac,ac); | ||
406 | CHECK(!workm,MEM); | ||
407 | CHECK(gsl_matrix_memcpy(M(u),M(a)),MEM); | ||
408 | // int res = gsl_linalg_SV_decomp_jacobi (&U.matrix, &V.matrix, &S.vector); | ||
409 | // doesn't work | ||
410 | //int res = gsl_linalg_SV_decomp (&U.matrix, &V.matrix, &S.vector, workv); | ||
411 | int res = gsl_linalg_SV_decomp_mod (M(u), workm, M(v), V(s), workv); | ||
412 | CHECK(res,res); | ||
413 | //gsl_matrix_transpose(M(v)); | ||
414 | gsl_vector_free(workv); | ||
415 | gsl_matrix_free(workm); | ||
416 | OK | ||
417 | } | ||
418 | |||
419 | |||
420 | // for real symmetric matrices | ||
421 | int eigensystemR(KRMAT(x),RVEC(l),RMAT(v)) { | ||
422 | REQUIRES(xr==xc && xr==ln && xr==vr && vr==vc,BAD_SIZE); | ||
423 | DEBUGMSG("eigensystemR (gsl_eigen_symmv)"); | ||
424 | KDMVIEW(x); | ||
425 | DVVIEW(l); | ||
426 | DMVIEW(v); | ||
427 | gsl_matrix *XC = gsl_matrix_alloc(xr,xr); | ||
428 | gsl_matrix_memcpy(XC,M(x)); // needed because the argument is destroyed | ||
429 | // many thanks to Nico Mahlo for the bug report | ||
430 | gsl_eigen_symmv_workspace * w = gsl_eigen_symmv_alloc (xc); | ||
431 | int res = gsl_eigen_symmv (XC, V(l), M(v), w); | ||
432 | CHECK(res,res); | ||
433 | gsl_eigen_symmv_free (w); | ||
434 | gsl_matrix_free(XC); | ||
435 | gsl_eigen_symmv_sort (V(l), M(v), GSL_EIGEN_SORT_ABS_DESC); | ||
436 | OK | ||
437 | } | ||
438 | |||
439 | // for hermitian matrices | ||
440 | int eigensystemC(KCMAT(x),RVEC(l),CMAT(v)) { | ||
441 | REQUIRES(xr==xc && xr==ln && xr==vr && vr==vc,BAD_SIZE); | ||
442 | DEBUGMSG("eigensystemC"); | ||
443 | KCMVIEW(x); | ||
444 | DVVIEW(l); | ||
445 | CMVIEW(v); | ||
446 | gsl_matrix_complex *XC = gsl_matrix_complex_alloc(xr,xr); | ||
447 | gsl_matrix_complex_memcpy(XC,M(x)); // again needed because the argument is destroyed | ||
448 | gsl_eigen_hermv_workspace * w = gsl_eigen_hermv_alloc (xc); | ||
449 | int res = gsl_eigen_hermv (XC, V(l), M(v), w); | ||
450 | CHECK(res,res); | ||
451 | gsl_eigen_hermv_free (w); | ||
452 | gsl_matrix_complex_free(XC); | ||
453 | gsl_eigen_hermv_sort (V(l), M(v), GSL_EIGEN_SORT_ABS_DESC); | ||
454 | OK | ||
455 | } | ||
456 | |||
457 | int QR(KRMAT(x),RMAT(q),RMAT(r)) { | ||
458 | REQUIRES(xr==rr && xc==rc && qr==qc && xr==qr,BAD_SIZE); | ||
459 | DEBUGMSG("QR"); | ||
460 | KDMVIEW(x); | ||
461 | DMVIEW(q); | ||
462 | DMVIEW(r); | ||
463 | gsl_matrix * a = gsl_matrix_alloc(xr,xc); | ||
464 | gsl_vector * tau = gsl_vector_alloc(MIN(xr,xc)); | ||
465 | gsl_matrix_memcpy(a,M(x)); | ||
466 | int res = gsl_linalg_QR_decomp(a,tau); | ||
467 | CHECK(res,res); | ||
468 | gsl_linalg_QR_unpack(a,tau,M(q),M(r)); | ||
469 | gsl_vector_free(tau); | ||
470 | gsl_matrix_free(a); | ||
471 | OK | ||
472 | } | ||
473 | |||
474 | int QRpacked(KRMAT(x),RMAT(qr),RVEC(tau)) { | ||
475 | //REQUIRES(xr==rr && xc==rc && qr==qc && xr==qr,BAD_SIZE); | ||
476 | DEBUGMSG("QRpacked"); | ||
477 | KDMVIEW(x); | ||
478 | DMVIEW(qr); | ||
479 | DVVIEW(tau); | ||
480 | //gsl_matrix * a = gsl_matrix_alloc(xr,xc); | ||
481 | //gsl_vector * tau = gsl_vector_alloc(MIN(xr,xc)); | ||
482 | gsl_matrix_memcpy(M(qr),M(x)); | ||
483 | int res = gsl_linalg_QR_decomp(M(qr),V(tau)); | ||
484 | CHECK(res,res); | ||
485 | OK | ||
486 | } | ||
487 | |||
488 | |||
489 | int QRunpack(KRMAT(xqr),KRVEC(tau),RMAT(q),RMAT(r)) { | ||
490 | //REQUIRES(xr==rr && xc==rc && qr==qc && xr==qr,BAD_SIZE); | ||
491 | DEBUGMSG("QRunpack"); | ||
492 | KDMVIEW(xqr); | ||
493 | KDVVIEW(tau); | ||
494 | DMVIEW(q); | ||
495 | DMVIEW(r); | ||
496 | gsl_linalg_QR_unpack(M(xqr),V(tau),M(q),M(r)); | ||
497 | OK | ||
498 | } | ||
499 | |||
500 | |||
501 | int cholR(KRMAT(x),RMAT(l)) { | ||
502 | REQUIRES(xr==xc && lr==xr && lr==lc,BAD_SIZE); | ||
503 | DEBUGMSG("cholR"); | ||
504 | KDMVIEW(x); | ||
505 | DMVIEW(l); | ||
506 | gsl_matrix_memcpy(M(l),M(x)); | ||
507 | int res = gsl_linalg_cholesky_decomp(M(l)); | ||
508 | CHECK(res,res); | ||
509 | int r,c; | ||
510 | for (r=0; r<xr-1; r++) { | ||
511 | for(c=r+1; c<xc; c++) { | ||
512 | lp[r*lc+c] = 0.; | ||
513 | } | ||
514 | } | ||
515 | OK | ||
516 | } | ||
517 | |||
518 | int cholC(KCMAT(x),CMAT(l)) { | ||
519 | REQUIRES(xr==xc && lr==xr && lr==lc,BAD_SIZE); | ||
520 | DEBUGMSG("cholC"); | ||
521 | KCMVIEW(x); | ||
522 | CMVIEW(l); | ||
523 | gsl_matrix_complex_memcpy(M(l),M(x)); | ||
524 | int res = 0; // gsl_linalg_complex_cholesky_decomp(M(l)); | ||
525 | CHECK(res,res); | ||
526 | gsl_complex zero = {0.,0.}; | ||
527 | int r,c; | ||
528 | for (r=0; r<xr-1; r++) { | ||
529 | for(c=r+1; c<xc; c++) { | ||
530 | lp[r*lc+c] = zero; | ||
531 | } | ||
532 | } | ||
533 | OK | ||
534 | } | ||
535 | |||
536 | int fft(int code, KCVEC(X), CVEC(R)) { | 250 | int fft(int code, KCVEC(X), CVEC(R)) { |
537 | REQUIRES(Xn == Rn,BAD_SIZE); | 251 | REQUIRES(Xn == Rn,BAD_SIZE); |
538 | DEBUGMSG("fft"); | 252 | DEBUGMSG("fft"); |
diff --git a/lib/Numeric/GSL/gsl-aux.h b/lib/Numeric/GSL/gsl-aux.h index eee15e7..cd17ef0 100644 --- a/lib/Numeric/GSL/gsl-aux.h +++ b/lib/Numeric/GSL/gsl-aux.h | |||
@@ -26,25 +26,6 @@ int zipR(int code, KRVEC(a), KRVEC(b), RVEC(r)); | |||
26 | int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r)); | 26 | int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r)); |
27 | 27 | ||
28 | 28 | ||
29 | int luSolveR(KRMAT(a),KRMAT(b),RMAT(r)); | ||
30 | int luSolveC(KCMAT(a),KCMAT(b),CMAT(r)); | ||
31 | int luRaux(KRMAT(a),RVEC(b)); | ||
32 | int luCaux(KCMAT(a),CVEC(b)); | ||
33 | |||
34 | int svd(KRMAT(x),RMAT(u), RVEC(s),RMAT(v)); | ||
35 | |||
36 | int eigensystemR(KRMAT(x),RVEC(l),RMAT(v)); | ||
37 | int eigensystemC(KCMAT(x),RVEC(l),CMAT(v)); | ||
38 | |||
39 | int QR(KRMAT(x),RMAT(q),RMAT(r)); | ||
40 | |||
41 | int QRpacked(KRMAT(x),RMAT(qr),RVEC(tau)); | ||
42 | int QRunpack(KRMAT(qr),KRVEC(tau),RMAT(q),RMAT(r)); | ||
43 | |||
44 | int cholR(KRMAT(x),RMAT(l)); | ||
45 | |||
46 | int cholC(KCMAT(x),CMAT(l)); | ||
47 | |||
48 | int fft(int code, KCVEC(a), CVEC(b)); | 29 | int fft(int code, KCVEC(a), CVEC(b)); |
49 | 30 | ||
50 | int integrate_qng(double f(double, void*), double a, double b, double prec, | 31 | int integrate_qng(double f(double, void*), double a, double b, double prec, |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index bbc5986..c7118c1 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -20,6 +20,7 @@ imported from "Numeric.LinearAlgebra.LAPACK". | |||
20 | 20 | ||
21 | module Numeric.LinearAlgebra.Algorithms ( | 21 | module Numeric.LinearAlgebra.Algorithms ( |
22 | -- * Linear Systems | 22 | -- * Linear Systems |
23 | multiply, dot, | ||
23 | linearSolve, | 24 | linearSolve, |
24 | inv, pinv, | 25 | inv, pinv, |
25 | pinvTol, det, rank, rcond, | 26 | pinvTol, det, rank, rcond, |
@@ -51,6 +52,8 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
51 | -- * Misc | 52 | -- * Misc |
52 | ctrans, | 53 | ctrans, |
53 | eps, i, | 54 | eps, i, |
55 | outer, kronecker, | ||
56 | mulH, | ||
54 | -- * Util | 57 | -- * Util |
55 | haussholder, | 58 | haussholder, |
56 | unpackQR, unpackHess, | 59 | unpackQR, unpackHess, |
@@ -60,13 +63,14 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
60 | 63 | ||
61 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) | 64 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) |
62 | import Data.Packed | 65 | import Data.Packed |
63 | import qualified Numeric.GSL.Matrix as GSL | ||
64 | import Numeric.GSL.Vector | 66 | import Numeric.GSL.Vector |
65 | import Numeric.LinearAlgebra.LAPACK as LAPACK | 67 | import Numeric.LinearAlgebra.LAPACK as LAPACK |
66 | import Complex | 68 | import Complex |
67 | import Numeric.LinearAlgebra.Linear | 69 | import Numeric.LinearAlgebra.Linear |
68 | import Data.List(foldl1') | 70 | import Data.List(foldl1') |
69 | import Data.Array | 71 | import Data.Array |
72 | import Foreign | ||
73 | import Foreign.C.Types | ||
70 | 74 | ||
71 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. | 75 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. |
72 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | 76 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where |
@@ -105,6 +109,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | |||
105 | schur :: Matrix t -> (Matrix t, Matrix t) | 109 | schur :: Matrix t -> (Matrix t, Matrix t) |
106 | -- | Conjugate transpose. | 110 | -- | Conjugate transpose. |
107 | ctrans :: Matrix t -> Matrix t | 111 | ctrans :: Matrix t -> Matrix t |
112 | multiply :: Matrix t -> Matrix t -> Matrix t | ||
108 | 113 | ||
109 | 114 | ||
110 | instance Field Double where | 115 | instance Field Double where |
@@ -116,9 +121,10 @@ instance Field Double where | |||
116 | eig = eigR | 121 | eig = eigR |
117 | eigSH' = eigS | 122 | eigSH' = eigS |
118 | cholSH = cholS | 123 | cholSH = cholS |
119 | qr = GSL.unpackQR . qrR | 124 | qr = unpackQR . qrR |
120 | hess = unpackHess hessR | 125 | hess = unpackHess hessR |
121 | schur = schurR | 126 | schur = schurR |
127 | multiply = multiplyR3 | ||
122 | 128 | ||
123 | instance Field (Complex Double) where | 129 | instance Field (Complex Double) where |
124 | svd = svdC | 130 | svd = svdC |
@@ -132,6 +138,8 @@ instance Field (Complex Double) where | |||
132 | qr = unpackQR . qrC | 138 | qr = unpackQR . qrC |
133 | hess = unpackHess hessC | 139 | hess = unpackHess hessC |
134 | schur = schurC | 140 | schur = schurC |
141 | multiply = mulCW -- workaround | ||
142 | -- multiplyC3 | ||
135 | 143 | ||
136 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. | 144 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. |
137 | -- | 145 | -- |
@@ -501,3 +509,162 @@ luFact (lu,perm) | r <= c = (l ,u ,p, s) | |||
501 | u' = takeRows c (lu |*| tu) | 509 | u' = takeRows c (lu |*| tu) |
502 | (|+|) = add | 510 | (|+|) = add |
503 | (|*|) = mul | 511 | (|*|) = mul |
512 | |||
513 | -------------------------------------------------- | ||
514 | |||
515 | -- | euclidean inner product | ||
516 | dot :: (Field t) => Vector t -> Vector t -> t | ||
517 | dot u v = multiply r c @@> (0,0) | ||
518 | where r = asRow u | ||
519 | c = asColumn v | ||
520 | |||
521 | |||
522 | {- | Outer product of two vectors. | ||
523 | |||
524 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
525 | (3><3) | ||
526 | [ 5.0, 2.0, 3.0 | ||
527 | , 10.0, 4.0, 6.0 | ||
528 | , 15.0, 6.0, 9.0 ]@ | ||
529 | -} | ||
530 | outer :: (Field t) => Vector t -> Vector t -> Matrix t | ||
531 | outer u v = asColumn u `multiply` asRow v | ||
532 | |||
533 | {- | Kronecker product of two matrices. | ||
534 | |||
535 | @m1=(2><3) | ||
536 | [ 1.0, 2.0, 0.0 | ||
537 | , 0.0, -1.0, 3.0 ] | ||
538 | m2=(4><3) | ||
539 | [ 1.0, 2.0, 3.0 | ||
540 | , 4.0, 5.0, 6.0 | ||
541 | , 7.0, 8.0, 9.0 | ||
542 | , 10.0, 11.0, 12.0 ]@ | ||
543 | |||
544 | @\> kronecker m1 m2 | ||
545 | (8><9) | ||
546 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
547 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
548 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
549 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
550 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
551 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
552 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
553 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ | ||
554 | -} | ||
555 | kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t | ||
556 | kronecker a b = fromBlocks | ||
557 | . partit (cols a) | ||
558 | . map (reshape (cols b)) | ||
559 | . toRows | ||
560 | $ flatten a `outer` flatten b | ||
561 | |||
562 | --------------------------------------------------------------------- | ||
563 | -- reference multiply | ||
564 | --------------------------------------------------------------------- | ||
565 | |||
566 | mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] | ||
567 | where dot u v = sum $ zipWith (*) (toList u) (toList v) | ||
568 | |||
569 | ----------------------------------------------------------------------------------- | ||
570 | -- workaround | ||
571 | ----------------------------------------------------------------------------------- | ||
572 | |||
573 | mulCW a b = toComplex (rr,ri) | ||
574 | where rr = multiply ar br `sub` multiply ai bi | ||
575 | ri = multiply ar bi `add` multiply ai br | ||
576 | (ar,ai) = fromComplex a | ||
577 | (br,bi) = fromComplex b | ||
578 | |||
579 | ----------------------------------------------------------------------------------- | ||
580 | -- Direct CBLAS | ||
581 | ----------------------------------------------------------------------------------- | ||
582 | |||
583 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) | ||
584 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) | ||
585 | |||
586 | rowMajor, colMajor :: CBLASOrder | ||
587 | rowMajor = CBLASOrder 101 | ||
588 | colMajor = CBLASOrder 102 | ||
589 | |||
590 | noTrans, trans', conjTrans :: CBLASTrans | ||
591 | noTrans = CBLASTrans 111 | ||
592 | trans' = CBLASTrans 112 | ||
593 | conjTrans = CBLASTrans 113 | ||
594 | |||
595 | foreign import ccall "cblas.h cblas_dgemm" | ||
596 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () | ||
597 | |||
598 | |||
599 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | ||
600 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | ||
601 | where | ||
602 | multiply3 f st a b | ||
603 | | cols a == rows b = unsafePerformIO $ do | ||
604 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
605 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 | ||
606 | app3 g mat a mat b mat s st | ||
607 | return s | ||
608 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
609 | |||
610 | |||
611 | foreign import ccall "cblas.h cblas_zgemm" | ||
612 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () | ||
613 | |||
614 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
615 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) | ||
616 | where | ||
617 | multiply3 f st a b | ||
618 | | cols a == rows b = do | ||
619 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
620 | palpha <- new 1 | ||
621 | pbeta <- new 0 | ||
622 | let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 | ||
623 | app3 g mat a mat b mat s st | ||
624 | free palpha | ||
625 | free pbeta | ||
626 | return s | ||
627 | -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s)) | ||
628 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
629 | |||
630 | ----------------------------------------------------------------------------------- | ||
631 | -- BLAS via auxiliary C | ||
632 | ----------------------------------------------------------------------------------- | ||
633 | |||
634 | foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM | ||
635 | foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM | ||
636 | |||
637 | multiply2 f st a b | ||
638 | | cols a == rows b = unsafePerformIO $ do | ||
639 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
640 | app3 f mat a mat b mat s st | ||
641 | if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) | ||
642 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
643 | |||
644 | multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double | ||
645 | multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) | ||
646 | |||
647 | multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
648 | multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) | ||
649 | |||
650 | ----------------------------------------------------------------------------------- | ||
651 | -- direct C multiplication | ||
652 | ----------------------------------------------------------------------------------- | ||
653 | |||
654 | foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM | ||
655 | foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM | ||
656 | |||
657 | cmultiply f st a b | ||
658 | -- | cols a == rows b = | ||
659 | = unsafePerformIO $ do | ||
660 | s <- createMatrix RowMajor (rows a) (cols b) | ||
661 | app3 f mat a mat b mat s st | ||
662 | if toLists s== toLists s then return s else error $ "BRUTAL " ++ (show (toLists s)) | ||
663 | -- return s | ||
664 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | ||
665 | |||
666 | multiplyR :: Matrix Double -> Matrix Double -> Matrix Double | ||
667 | multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) | ||
668 | |||
669 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
670 | multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) | ||
diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs index 4a9b309..0ae9698 100644 --- a/lib/Numeric/LinearAlgebra/Interface.hs +++ b/lib/Numeric/LinearAlgebra/Interface.hs | |||
@@ -29,7 +29,7 @@ import Numeric.LinearAlgebra.Algorithms | |||
29 | class Mul a b c | a b -> c where | 29 | class Mul a b c | a b -> c where |
30 | infixl 7 <> | 30 | infixl 7 <> |
31 | -- | matrix product | 31 | -- | matrix product |
32 | (<>) :: Element t => a t -> b t -> c t | 32 | (<>) :: Field t => a t -> b t -> c t |
33 | 33 | ||
34 | instance Mul Matrix Matrix Matrix where | 34 | instance Mul Matrix Matrix Matrix where |
35 | (<>) = multiply | 35 | (<>) = multiply |
@@ -43,7 +43,7 @@ instance Mul Vector Matrix Vector where | |||
43 | --------------------------------------------------- | 43 | --------------------------------------------------- |
44 | 44 | ||
45 | -- | @u \<.\> v = dot u v@ | 45 | -- | @u \<.\> v = dot u v@ |
46 | (<.>) :: (Element t) => Vector t -> Vector t -> t | 46 | (<.>) :: (Field t) => Vector t -> Vector t -> t |
47 | infixl 7 <.> | 47 | infixl 7 <.> |
48 | (<.>) = dot | 48 | (<.>) = dot |
49 | 49 | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 310f6ee..0dccea2 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -814,3 +814,77 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | |||
814 | free(auxipiv); | 814 | free(auxipiv); |
815 | OK | 815 | OK |
816 | } | 816 | } |
817 | |||
818 | //////////////////////////////////////////////////////////// | ||
819 | |||
820 | int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) { | ||
821 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
822 | int i,j,k; | ||
823 | for (i=0;i<ar;i++) { | ||
824 | for(j=0;j<bc;j++) { | ||
825 | double temp = 0; | ||
826 | for(k=0;k<ac;k++) { | ||
827 | temp += ap[i*ac+k]*bp[k*bc+j]; | ||
828 | } | ||
829 | rp[i*rc+j] = temp; | ||
830 | } | ||
831 | } | ||
832 | OK | ||
833 | } | ||
834 | |||
835 | int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) { | ||
836 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
837 | int i,j,k; | ||
838 | for (i=0;i<ar;i++) { | ||
839 | for(j=0;j<bc;j++) { | ||
840 | doublecomplex temp = {0,0}; | ||
841 | for(k=0;k<ac;k++) { | ||
842 | doublecomplex aik = ((doublecomplex*)ap)[i*ac+k]; | ||
843 | doublecomplex bkj = ((doublecomplex*)bp)[k*bc+j]; | ||
844 | //double w = aik.r+aik.i+bkj.r+bkj.i; | ||
845 | //if (w>w) exit(1); | ||
846 | //printf("%d",w>w); | ||
847 | temp.r += aik.r * bkj.r - aik.i * bkj.i; | ||
848 | temp.i += aik.r * bkj.i + aik.i * bkj.r; | ||
849 | //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i); | ||
850 | //printf("%f %f %f \n",w,temp.r,temp.i); | ||
851 | |||
852 | } | ||
853 | ((doublecomplex*)rp)[i*rc+j] = temp; | ||
854 | //printf("%f %f\n",temp.r,temp.i); | ||
855 | } | ||
856 | } | ||
857 | OK | ||
858 | } | ||
859 | |||
860 | void dgemm_(char *, char *, integer *, integer *, integer *, | ||
861 | double *, const double *, integer *, const double *, | ||
862 | integer *, double *, double *, integer *); | ||
863 | |||
864 | void zgemm_(char *, char *, integer *, integer *, integer *, | ||
865 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, | ||
866 | integer *, doublecomplex *, doublecomplex *, integer *); | ||
867 | |||
868 | |||
869 | int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { | ||
870 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
871 | double alpha = 1; | ||
872 | double beta = 0; | ||
873 | integer m = ar; | ||
874 | integer n = bc; | ||
875 | integer k = ac; | ||
876 | int i,j; | ||
877 | dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); | ||
878 | OK | ||
879 | } | ||
880 | |||
881 | int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { | ||
882 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
883 | integer m = ar; | ||
884 | integer n = bc; | ||
885 | integer k = ac; | ||
886 | doublecomplex alpha = {1,0}; | ||
887 | doublecomplex beta = {0,0}; | ||
888 | zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); | ||
889 | OK | ||
890 | } | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 79e52be..c0361a6 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -84,3 +84,9 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); | |||
84 | 84 | ||
85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); | 85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); |
86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); | 86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); |
87 | |||
88 | int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); | ||
89 | int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); | ||
90 | |||
91 | int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); | ||
92 | int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)); | ||
diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs index 0ddbb55..1bf8b04 100644 --- a/lib/Numeric/LinearAlgebra/Linear.hs +++ b/lib/Numeric/LinearAlgebra/Linear.hs | |||
@@ -15,12 +15,11 @@ Basic optimized operations on vectors and matrices. | |||
15 | ----------------------------------------------------------------------------- | 15 | ----------------------------------------------------------------------------- |
16 | 16 | ||
17 | module Numeric.LinearAlgebra.Linear ( | 17 | module Numeric.LinearAlgebra.Linear ( |
18 | Linear(..), | 18 | Linear(..) |
19 | multiply, dot, outer, kronecker | ||
20 | ) where | 19 | ) where |
21 | 20 | ||
22 | 21 | ||
23 | import Data.Packed.Internal(multiply,partit) | 22 | import Data.Packed.Internal(partit) |
24 | import Data.Packed | 23 | import Data.Packed |
25 | import Numeric.GSL.Vector | 24 | import Numeric.GSL.Vector |
26 | import Complex | 25 | import Complex |
@@ -69,52 +68,3 @@ instance (Linear Vector a, Container Matrix a) => (Linear Matrix a) where | |||
69 | mul = liftMatrix2 mul | 68 | mul = liftMatrix2 mul |
70 | divide = liftMatrix2 divide | 69 | divide = liftMatrix2 divide |
71 | equal a b = cols a == cols b && flatten a `equal` flatten b | 70 | equal a b = cols a == cols b && flatten a `equal` flatten b |
72 | |||
73 | -------------------------------------------------- | ||
74 | |||
75 | -- | euclidean inner product | ||
76 | dot :: (Element t) => Vector t -> Vector t -> t | ||
77 | dot u v = multiply r c @@> (0,0) | ||
78 | where r = asRow u | ||
79 | c = asColumn v | ||
80 | |||
81 | |||
82 | {- | Outer product of two vectors. | ||
83 | |||
84 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
85 | (3><3) | ||
86 | [ 5.0, 2.0, 3.0 | ||
87 | , 10.0, 4.0, 6.0 | ||
88 | , 15.0, 6.0, 9.0 ]@ | ||
89 | -} | ||
90 | outer :: (Element t) => Vector t -> Vector t -> Matrix t | ||
91 | outer u v = asColumn u `multiply` asRow v | ||
92 | |||
93 | {- | Kronecker product of two matrices. | ||
94 | |||
95 | @m1=(2><3) | ||
96 | [ 1.0, 2.0, 0.0 | ||
97 | , 0.0, -1.0, 3.0 ] | ||
98 | m2=(4><3) | ||
99 | [ 1.0, 2.0, 3.0 | ||
100 | , 4.0, 5.0, 6.0 | ||
101 | , 7.0, 8.0, 9.0 | ||
102 | , 10.0, 11.0, 12.0 ]@ | ||
103 | |||
104 | @\> kronecker m1 m2 | ||
105 | (8><9) | ||
106 | [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 | ||
107 | , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 | ||
108 | , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 | ||
109 | , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 | ||
110 | , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 | ||
111 | , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 | ||
112 | , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 | ||
113 | , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ | ||
114 | -} | ||
115 | kronecker :: (Element t) => Matrix t -> Matrix t -> Matrix t | ||
116 | kronecker a b = fromBlocks | ||
117 | . partit (cols a) | ||
118 | . map (reshape (cols b)) | ||
119 | . toRows | ||
120 | $ flatten a `outer` flatten b | ||
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 7b28075..07b9f63 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -123,6 +123,11 @@ runTests :: Int -- ^ maximum dimension | |||
123 | runTests n = do | 123 | runTests n = do |
124 | setErrorHandlerOff | 124 | setErrorHandlerOff |
125 | let test p = qCheck n p | 125 | let test p = qCheck n p |
126 | putStrLn "------ mult" | ||
127 | test (multProp1 . rConsist) | ||
128 | test (multProp1 . cConsist) | ||
129 | test (multProp2 . rConsist) | ||
130 | test (multProp2 . cConsist) | ||
126 | putStrLn "------ lu" | 131 | putStrLn "------ lu" |
127 | test (luProp . rM) | 132 | test (luProp . rM) |
128 | test (luProp . cM) | 133 | test (luProp . cM) |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs index af486c8..e7fecf2 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -20,6 +20,7 @@ module Numeric.LinearAlgebra.Tests.Instances( | |||
20 | WC(..), rWC,cWC, | 20 | WC(..), rWC,cWC, |
21 | SqWC(..), rSqWC, cSqWC, | 21 | SqWC(..), rSqWC, cSqWC, |
22 | PosDef(..), rPosDef, cPosDef, | 22 | PosDef(..), rPosDef, cPosDef, |
23 | Consistent(..), rConsist, cConsist, | ||
23 | RM,CM, rM,cM | 24 | RM,CM, rM,cM |
24 | ) where | 25 | ) where |
25 | 26 | ||
@@ -116,6 +117,19 @@ instance (Field a, Arbitrary a) => Arbitrary (PosDef a) where | |||
116 | return $ PosDef (0.5 .* p + 0.5 .* ctrans p) | 117 | return $ PosDef (0.5 .* p + 0.5 .* ctrans p) |
117 | coarbitrary = undefined | 118 | coarbitrary = undefined |
118 | 119 | ||
120 | -- a pair of matrices that can be multiplied | ||
121 | newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show | ||
122 | instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where | ||
123 | arbitrary = do | ||
124 | n <- chooseDim | ||
125 | k <- chooseDim | ||
126 | m <- chooseDim | ||
127 | la <- vector (n*k) | ||
128 | lb <- vector (k*m) | ||
129 | return $ Consistent ((n><k) la, (k><m) lb) | ||
130 | coarbitrary = undefined | ||
131 | |||
132 | |||
119 | type RM = Matrix Double | 133 | type RM = Matrix Double |
120 | type CM = Matrix (Complex Double) | 134 | type CM = Matrix (Complex Double) |
121 | 135 | ||
@@ -140,3 +154,5 @@ cSqWC (SqWC m) = m :: CM | |||
140 | rPosDef (PosDef m) = m :: RM | 154 | rPosDef (PosDef m) = m :: RM |
141 | cPosDef (PosDef m) = m :: CM | 155 | cPosDef (PosDef m) = m :: CM |
142 | 156 | ||
157 | rConsist (Consistent (a,b)) = (a,b::RM) | ||
158 | cConsist (Consistent (a,b)) = (a,b::CM) | ||
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index 55e9a1b..5663b86 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -34,7 +34,8 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
34 | hessProp, | 34 | hessProp, |
35 | schurProp1, schurProp2, | 35 | schurProp1, schurProp2, |
36 | cholProp, | 36 | cholProp, |
37 | expmDiagProp | 37 | expmDiagProp, |
38 | multProp1, multProp2 | ||
38 | ) where | 39 | ) where |
39 | 40 | ||
40 | import Numeric.LinearAlgebra | 41 | import Numeric.LinearAlgebra |
@@ -151,3 +152,6 @@ cholProp m = m |~| ctrans c <> c && upperTriang c | |||
151 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m | 152 | expmDiagProp m = expm (logm m) :~ 7 ~: complex m |
152 | where logm m = matFunc log m | 153 | where logm m = matFunc log m |
153 | 154 | ||
155 | multProp1 (a,b) = a <> b |~| mulH a b | ||
156 | |||
157 | multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a | ||