summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/src/C/lapack-aux.c88
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs44
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs66
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs4
4 files changed, 98 insertions, 104 deletions
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c
index e76d31e..c2cdc62 100644
--- a/packages/base/src/C/lapack-aux.c
+++ b/packages/base/src/C/lapack-aux.c
@@ -1290,19 +1290,18 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) {
1290int multiplyI(int ta, int tb, KIMAT(a), KIMAT(b), IMAT(r)) { 1290int multiplyI(int ta, int tb, KIMAT(a), KIMAT(b), IMAT(r)) {
1291 int i,j,k; 1291 int i,j,k;
1292 int n; 1292 int n;
1293 int u, v; 1293 int ai,ak,bk,bj;
1294 if (ta==0) { 1294
1295 n = ac; 1295 n = ta ? ar : ac;
1296 } else { 1296
1297 n = ar; 1297 if (ta==0) { ai = 1; ak = ar; } else { ai = ar; ak = 1; }
1298 } 1298 if (tb==0) { bk = 1; bj = br; } else { bk = br; bj = 1; }
1299
1299 for (i=0;i<rr;i++) { 1300 for (i=0;i<rr;i++) {
1300 for (j=0;j<rc;j++) { 1301 for (j=0;j<rc;j++) {
1301 rp[i*rc+j] = 0; 1302 rp[i+rr*j] = 0;
1302 for (k=0; k<n; k++) { 1303 for (k=0; k<n; k++) {
1303 u = ta==0 ? ap[i*ac+k] : ap[k*ac+i]; 1304 rp[i+rr*j] += ap[ai*i+ak*k] * bp[bk*k+bj*j];
1304 v = tb==0 ? bp[k*bc+j] : bp[j*bc+k];
1305 rp[i*rc+j] += u*v;
1306 } 1305 }
1307 } 1306 }
1308 } 1307 }
@@ -1622,65 +1621,44 @@ int chooseD(KIVEC(cond),KDVEC(lt),KDVEC(eq),KDVEC(gt),DVEC(r)) {
1622 1621
1623//////////////////////// extract ///////////////////////////////// 1622//////////////////////// extract /////////////////////////////////
1624 1623
1625#define EXTRACT_IMP \ 1624#define EXTRACT_IMP \
1626 /*REQUIRES((tm == 0 && jn==rr && mc==rc) || (jn==rr && mr==rc) ,BAD_SIZE); */ \ 1625 int i,j,si,sj,ni,nj,ai,aj; \
1627 DEBUGMSG("extractRD") \ 1626 if (tm==0) { \
1628 int k,i,s; \ 1627 ai=mc; aj=1; \
1629 if (tm==0) { \ 1628 } else { \
1630 if (mode==0) { \ 1629 ai=1, aj=mc; \
1631 for (k=0; k<jp[1]-jp[0]+1; k++) { \ 1630 } \
1632 s = k + jp[0]; \ 1631 ni = modei ? in : ip[1]-ip[0]+1; \
1633 printf("%d\n",s); \ 1632 nj = modej ? jn : jp[1]-jp[0]+1; \
1634 for (i=0; i<mc; i++) { \ 1633 \
1635 rp[rc*k+i] = mp[mc*s+i]; \ 1634 for (i=0; i<ni; i++) { \
1636 } \ 1635 si = modei ? ip[i] : i+ip[0]; \
1637 } \ 1636 \
1638 } else { \ 1637 for (j=0; j<nj; j++) { \
1639 for (k=0;k<jn;k++) { \ 1638 sj = modej ? jp[j] : j+jp[0]; \
1640 s = jp[k]; \ 1639 \
1641 for (i=0; i<mc; i++) { \ 1640 rp[rc*i+j] = mp[ai*si+aj*sj]; \
1642 rp[rc*k+i] = mp[mc*s+i]; \ 1641 } \
1643 } \ 1642 } \
1644 } \
1645 } \
1646 } else { \
1647 if (mode==0) { \
1648 for (k=0; k<jp[1]-jp[0]+1; k++) { \
1649 s = k + jp[0]; \
1650 printf("%d\n",s); \
1651 for (i=0; i<mr; i++) { \
1652 rp[rc*k+i] = mp[mc*i+s]; \
1653 } \
1654 } \
1655 } else { \
1656 for (k=0;k<jn;k++) { \
1657 s = jp[k]; \
1658 for (i=0; i<mr; i++) { \
1659 rp[rc*k+i] = mp[mc*i+s]; \
1660 } \
1661 } \
1662 } \
1663 } \
1664 OK 1643 OK
1665 1644
1666 1645int extractD(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KDMAT(m), DMAT(r)) {
1667int extractRD(int mode, int tm, KIVEC(j), KDMAT(m), DMAT(r)) {
1668 EXTRACT_IMP 1646 EXTRACT_IMP
1669} 1647}
1670 1648
1671int extractRF(int mode, int tm, KIVEC(j), KFMAT(m), FMAT(r)) { 1649int extractF(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KFMAT(m), FMAT(r)) {
1672 EXTRACT_IMP 1650 EXTRACT_IMP
1673} 1651}
1674 1652
1675int extractRC(int mode, int tm, KIVEC(j), KCMAT(m), CMAT(r)) { 1653int extractC(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KCMAT(m), CMAT(r)) {
1676 EXTRACT_IMP 1654 EXTRACT_IMP
1677} 1655}
1678 1656
1679int extractRQ(int mode, int tm, KIVEC(j), KQMAT(m), QMAT(r)) { 1657int extractQ(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KQMAT(m), QMAT(r)) {
1680 EXTRACT_IMP 1658 EXTRACT_IMP
1681} 1659}
1682 1660
1683int extractRI(int mode, int tm, KIVEC(j), KIMAT(m), IMAT(r)) { 1661int extractI(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KIMAT(m), IMAT(r)) {
1684 EXTRACT_IMP 1662 EXTRACT_IMP
1685} 1663}
1686 1664
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
index 1aee7d3..76d2204 100644
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -265,32 +265,32 @@ class (Storable a) => Element a where
265 transdata = transdataP -- transdata' 265 transdata = transdataP -- transdata'
266 constantD :: a -> Int -> Vector a 266 constantD :: a -> Int -> Vector a
267 constantD = constantP -- constant' 267 constantD = constantP -- constant'
268 extractR :: Matrix a -> CInt -> Idxs -> Matrix a 268 extractR :: Matrix a -> CInt -> Idxs -> CInt -> Idxs -> Matrix a
269 269
270instance Element Float where 270instance Element Float where
271 transdata = transdataAux ctransF 271 transdata = transdataAux ctransF
272 constantD = constantAux cconstantF 272 constantD = constantAux cconstantF
273 extractR = extractAux c_extractRF 273 extractR = extractAux c_extractF
274 274
275instance Element Double where 275instance Element Double where
276 transdata = transdataAux ctransR 276 transdata = transdataAux ctransR
277 constantD = constantAux cconstantR 277 constantD = constantAux cconstantR
278 extractR = extractAux c_extractRD 278 extractR = extractAux c_extractD
279 279
280instance Element (Complex Float) where 280instance Element (Complex Float) where
281 transdata = transdataAux ctransQ 281 transdata = transdataAux ctransQ
282 constantD = constantAux cconstantQ 282 constantD = constantAux cconstantQ
283 extractR = extractAux c_extractRQ 283 extractR = extractAux c_extractQ
284 284
285instance Element (Complex Double) where 285instance Element (Complex Double) where
286 transdata = transdataAux ctransC 286 transdata = transdataAux ctransC
287 constantD = constantAux cconstantC 287 constantD = constantAux cconstantC
288 extractR = extractAux c_extractRC 288 extractR = extractAux c_extractC
289 289
290instance Element (CInt) where 290instance Element (CInt) where
291 transdata = transdataAux ctransI 291 transdata = transdataAux ctransI
292 constantD = constantAux cconstantI 292 constantD = constantAux cconstantI
293 extractR = extractAux c_extractRI 293 extractR = extractAux c_extractI
294 294
295 295
296------------------------------------------------------------------- 296-------------------------------------------------------------------
@@ -443,26 +443,26 @@ isT Matrix{order = RowMajor} = 0
443tt x@Matrix{order = ColumnMajor} = trans x 443tt x@Matrix{order = ColumnMajor} = trans x
444tt x@Matrix{order = RowMajor} = x 444tt x@Matrix{order = RowMajor} = x
445 445
446--extractAux :: Matrix Double -> Idxs -> Matrix Double 446
447extractAux f m mode v = unsafePerformIO $ do 447extractAux f m moder vr modec vc = unsafePerformIO $ do
448 let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) 448 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
449 | otherwise = dim v 449 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
450 r <- createMatrix RowMajor nr (cols m) 450 r <- createMatrix RowMajor nr nc
451 app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" 451 app4 (f moder modec (isT m)) vec vr vec vc mat (tt m) mat r "extractAux"
452 return r 452 return r
453 453
454foreign import ccall unsafe "extractRD" c_extractRD 454foreign import ccall unsafe "extractD" c_extractD
455 :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) 455 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Double (CM Double (IO CInt))))
456 456
457foreign import ccall unsafe "extractRF" c_extractRF 457foreign import ccall unsafe "extractF" c_extractF
458 :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) 458 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Float (CM Float (IO CInt))))
459 459
460foreign import ccall unsafe "extractRC" c_extractRC 460foreign import ccall unsafe "extractC" c_extractC
461 :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) 461 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))))
462 462
463foreign import ccall unsafe "extractRQ" c_extractRQ 463foreign import ccall unsafe "extractQ" c_extractQ
464 :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) 464 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))))
465 465
466foreign import ccall unsafe "extractRI" c_extractRI 466foreign import ccall unsafe "extractI" c_extractI
467 :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) 467 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM CInt (CM CInt (IO CInt))))
468 468
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
index 00ec70c..353877a 100644
--- a/packages/base/src/Data/Packed/Internal/Numeric.hs
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric (
39 roundVector, 39 roundVector,
40 RealOf, ComplexOf, SingleOf, DoubleOf, 40 RealOf, ComplexOf, SingleOf, DoubleOf,
41 IndexOf, 41 IndexOf,
42 CInt, Extractor(..), (??),(¿¿), 42 CInt, Extractor(..), (??),
43 module Data.Complex 43 module Data.Complex
44) where 44) where
45 45
@@ -69,27 +69,50 @@ type instance ArgOf Matrix a = a -> a -> a
69 69
70-------------------------------------------------------------------------- 70--------------------------------------------------------------------------
71 71
72data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int 72data Extractor
73 = All
74 | Range Int Int
75 | At [Int]
76 | AtCyc [Int]
77 | Take Int
78 | Drop Int
79 deriving Show
73 80
74idxs js = fromList (map fromIntegral js) :: Idxs 81idxs js = fromList (map fromIntegral js) :: Idxs
75 82
76infixl 9 ??, ¿¿ 83infixl 9 ??
77(??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t 84(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t
78 85
79m ?? All = m 86
80m ?? Take 0 = (0><cols m) [] 87extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m)
81m ?? Take n | abs n >= rows m = m 88
82m ?? Drop 0 = m 89m ?? e@(Range a b,_) | a < 0 || b >= rows m = extractError m e
83m ?? Drop n | abs n >= rows m = (0><cols m) [] 90m ?? e@(_,Range a b) | a < 0 || b >= cols m = extractError m e
84m ?? Range a b | a > b = m ?? Take 0 91m ?? e@(At ps,_) | minimum ps < 0 || maximum ps >= rows m = extractError m e
85m ?? Range a b | a < 0 || b >= cols m = error $ 92m ?? e@(_,At ps) | minimum ps < 0 || maximum ps >= cols m = extractError m e
86 printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) 93
87m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ 94m ?? (All,All) = m
88 printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) 95
89 96m ?? (Range a b,e) | a > b = m ?? (Take 0,e)
90m ?? er = extractR m mode js 97m ?? (e,Range a b) | a > b = m ?? (e,Take 0)
98
99m ?? (Take 0,e) = (0><cols m) [] ?? (All,e)
100m ?? (e,Take 0) = (rows m><0) [] ?? (e,All)
101
102m ?? (Take n,e) | abs n > rows m = m ?? (All,e)
103m ?? (e,Take n) | abs n > cols m = m ?? (e,All)
104
105m ?? (Drop 0,e) = m ?? (All,e)
106m ?? (e,Drop 0) = m ?? (e,All)
107
108m ?? (Drop n,e) | abs n > rows m = m ?? (Take 0,e)
109m ?? (e,Drop n) | abs n > cols m = m ?? (e,Take 0)
110
111
112m ?? (er,ec) = extractR m moder rs modec cs
91 where 113 where
92 (mode,js) = mkExt (rows m) er 114 (moder,rs) = mkExt (rows m) er
115 (modec,cs) = mkExt (cols m) ec
93 ran a b = (0, idxs [a,b]) 116 ran a b = (0, idxs [a,b])
94 pos ks = (1, idxs ks) 117 pos ks = (1, idxs ks)
95 mkExt _ (At ks) = pos ks 118 mkExt _ (At ks) = pos ks
@@ -104,13 +127,6 @@ m ?? er = extractR m mode js
104 | otherwise = mkExt n (Take (n+k)) 127 | otherwise = mkExt n (Take (n+k))
105 128
106 129
107m ¿¿ Range a b | a < 0 || b > cols m -1 = error $
108 printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m)
109
110m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $
111 printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m)
112m ¿¿ ec = trans (trans m ?? ec)
113
114------------------------------------------------------------------- 130-------------------------------------------------------------------
115 131
116 132
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index e4485fc..c94350f 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -37,7 +37,7 @@ module Numeric.LinearAlgebra.Data(
37 fromList, toList, subVector, takesV, vjoin, 37 fromList, toList, subVector, takesV, vjoin,
38 flatten, reshape, asRow, asColumn, row, col, 38 flatten, reshape, asRow, asColumn, row, col,
39 fromRows, toRows, fromColumns, toColumns, fromLists, toLists, fromArray2D, 39 fromRows, toRows, fromColumns, toColumns, fromLists, toLists, fromArray2D,
40 Extractor(..), (??), (¿¿), 40 Extractor(..), (??),
41 takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud, 41 takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud,
42 42
43 -- * Block matrix 43 -- * Block matrix
@@ -81,6 +81,6 @@ import Numeric.LinearAlgebra.Util hiding ((&),(#))
81import Data.Complex 81import Data.Complex
82import Numeric.Sparse 82import Numeric.Sparse
83import Data.Packed.Internal.Vector(Idxs) 83import Data.Packed.Internal.Vector(Idxs)
84import Data.Packed.Internal.Numeric(CInt,Extractor(..),(??),(¿¿)) 84import Data.Packed.Internal.Numeric(CInt,Extractor(..),(??))
85 85
86 86