diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-05-28 09:17:40 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-05-28 12:25:20 +0200 |
commit | 3f68d78613ed61540c38548fb3b7e8fca77a85d2 (patch) | |
tree | 229661ff3c051036b0ce422059fe0517a3dd5366 /packages/base | |
parent | f05ef81b63e4ee6403433919ce48f223cf0b1e45 (diff) |
use omat for multiplyI
Diffstat (limited to 'packages/base')
-rw-r--r-- | packages/base/src/C/lapack-aux.c | 25 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/LAPACK.hs | 10 |
2 files changed, 15 insertions, 20 deletions
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c index a977d5f..ac03120 100644 --- a/packages/base/src/C/lapack-aux.c +++ b/packages/base/src/C/lapack-aux.c | |||
@@ -1287,28 +1287,19 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | |||
1287 | OK | 1287 | OK |
1288 | } | 1288 | } |
1289 | 1289 | ||
1290 | int multiplyI(int ta, int tb, KIMAT(a), KIMAT(b), IMAT(r)) { | 1290 | |
1291 | int i,j,k; | 1291 | int multiplyI(KOIMAT(a), KOIMAT(b), OIMAT(r)) { |
1292 | int n; | 1292 | { TRAV(r,i,j) { |
1293 | int ai,ak,bk,bj; | 1293 | int k; |
1294 | 1294 | AT(r,i,j) = 0; | |
1295 | n = ta ? ar : ac; | 1295 | for (k=0;k<ac;k++) { |
1296 | 1296 | AT(r,i,j) += AT(a,i,k) * AT(b,k,j); | |
1297 | if (ta==0) { ai = 1; ak = ar; } else { ai = ar; ak = 1; } | ||
1298 | if (tb==0) { bk = 1; bj = br; } else { bk = br; bj = 1; } | ||
1299 | |||
1300 | for (i=0;i<rr;i++) { | ||
1301 | for (j=0;j<rc;j++) { | ||
1302 | rp[i+rr*j] = 0; | ||
1303 | for (k=0; k<n; k++) { | ||
1304 | rp[i+rr*j] += ap[ai*i+ak*k] * bp[bk*k+bj*j]; | ||
1305 | } | ||
1306 | } | 1297 | } |
1298 | } | ||
1307 | } | 1299 | } |
1308 | OK | 1300 | OK |
1309 | } | 1301 | } |
1310 | 1302 | ||
1311 | |||
1312 | //////////////////// transpose ///////////////////////// | 1303 | //////////////////// transpose ///////////////////////// |
1313 | 1304 | ||
1314 | int transF(KFMAT(x),FMAT(t)) { | 1305 | int transF(KFMAT(x),FMAT(t)) { |
diff --git a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs b/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs index 219d996..6fb2b13 100644 --- a/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs +++ b/packages/base/src/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -58,7 +58,7 @@ foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM | |||
58 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM | 58 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM |
59 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM | 59 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM |
60 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM | 60 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM |
61 | foreign import ccall unsafe "multiplyI" c_multiplyI :: CInt -> CInt -> (CM CInt (CM CInt (CM CInt (IO CInt)))) | 61 | foreign import ccall unsafe "multiplyI" c_multiplyI :: OM CInt (OM CInt (OM CInt (IO CInt))) |
62 | 62 | ||
63 | isT Matrix{order = ColumnMajor} = 0 | 63 | isT Matrix{order = ColumnMajor} = 0 |
64 | isT Matrix{order = RowMajor} = 1 | 64 | isT Matrix{order = RowMajor} = 1 |
@@ -90,8 +90,12 @@ multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex | |||
90 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b | 90 | multiplyQ a b = multiplyAux cgemmc "cgemmc" a b |
91 | 91 | ||
92 | multiplyI :: Matrix CInt -> Matrix CInt -> Matrix CInt | 92 | multiplyI :: Matrix CInt -> Matrix CInt -> Matrix CInt |
93 | multiplyI = multiplyAux c_multiplyI "c_multiplyI" | 93 | multiplyI a b = unsafePerformIO $ do |
94 | 94 | when (cols a /= rows b) $ error $ | |
95 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | ||
96 | s <- createMatrix ColumnMajor (rows a) (cols b) | ||
97 | app3 c_multiplyI omat a omat b omat s "c_multiplyI" | ||
98 | return s | ||
95 | 99 | ||
96 | ----------------------------------------------------------------------------- | 100 | ----------------------------------------------------------------------------- |
97 | foreign import ccall unsafe "svd_l_R" dgesvd :: TMMVM | 101 | foreign import ccall unsafe "svd_l_R" dgesvd :: TMMVM |