diff options
Diffstat (limited to 'packages/base')
-rw-r--r-- | packages/base/src/C/lapack-aux.c | 88 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 44 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 66 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Data.hs | 4 |
4 files changed, 98 insertions, 104 deletions
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c index e76d31e..c2cdc62 100644 --- a/packages/base/src/C/lapack-aux.c +++ b/packages/base/src/C/lapack-aux.c | |||
@@ -1290,19 +1290,18 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | |||
1290 | int multiplyI(int ta, int tb, KIMAT(a), KIMAT(b), IMAT(r)) { | 1290 | int multiplyI(int ta, int tb, KIMAT(a), KIMAT(b), IMAT(r)) { |
1291 | int i,j,k; | 1291 | int i,j,k; |
1292 | int n; | 1292 | int n; |
1293 | int u, v; | 1293 | int ai,ak,bk,bj; |
1294 | if (ta==0) { | 1294 | |
1295 | n = ac; | 1295 | n = ta ? ar : ac; |
1296 | } else { | 1296 | |
1297 | n = ar; | 1297 | if (ta==0) { ai = 1; ak = ar; } else { ai = ar; ak = 1; } |
1298 | } | 1298 | if (tb==0) { bk = 1; bj = br; } else { bk = br; bj = 1; } |
1299 | |||
1299 | for (i=0;i<rr;i++) { | 1300 | for (i=0;i<rr;i++) { |
1300 | for (j=0;j<rc;j++) { | 1301 | for (j=0;j<rc;j++) { |
1301 | rp[i*rc+j] = 0; | 1302 | rp[i+rr*j] = 0; |
1302 | for (k=0; k<n; k++) { | 1303 | for (k=0; k<n; k++) { |
1303 | u = ta==0 ? ap[i*ac+k] : ap[k*ac+i]; | 1304 | rp[i+rr*j] += ap[ai*i+ak*k] * bp[bk*k+bj*j]; |
1304 | v = tb==0 ? bp[k*bc+j] : bp[j*bc+k]; | ||
1305 | rp[i*rc+j] += u*v; | ||
1306 | } | 1305 | } |
1307 | } | 1306 | } |
1308 | } | 1307 | } |
@@ -1622,65 +1621,44 @@ int chooseD(KIVEC(cond),KDVEC(lt),KDVEC(eq),KDVEC(gt),DVEC(r)) { | |||
1622 | 1621 | ||
1623 | //////////////////////// extract ///////////////////////////////// | 1622 | //////////////////////// extract ///////////////////////////////// |
1624 | 1623 | ||
1625 | #define EXTRACT_IMP \ | 1624 | #define EXTRACT_IMP \ |
1626 | /*REQUIRES((tm == 0 && jn==rr && mc==rc) || (jn==rr && mr==rc) ,BAD_SIZE); */ \ | 1625 | int i,j,si,sj,ni,nj,ai,aj; \ |
1627 | DEBUGMSG("extractRD") \ | 1626 | if (tm==0) { \ |
1628 | int k,i,s; \ | 1627 | ai=mc; aj=1; \ |
1629 | if (tm==0) { \ | 1628 | } else { \ |
1630 | if (mode==0) { \ | 1629 | ai=1, aj=mc; \ |
1631 | for (k=0; k<jp[1]-jp[0]+1; k++) { \ | 1630 | } \ |
1632 | s = k + jp[0]; \ | 1631 | ni = modei ? in : ip[1]-ip[0]+1; \ |
1633 | printf("%d\n",s); \ | 1632 | nj = modej ? jn : jp[1]-jp[0]+1; \ |
1634 | for (i=0; i<mc; i++) { \ | 1633 | \ |
1635 | rp[rc*k+i] = mp[mc*s+i]; \ | 1634 | for (i=0; i<ni; i++) { \ |
1636 | } \ | 1635 | si = modei ? ip[i] : i+ip[0]; \ |
1637 | } \ | 1636 | \ |
1638 | } else { \ | 1637 | for (j=0; j<nj; j++) { \ |
1639 | for (k=0;k<jn;k++) { \ | 1638 | sj = modej ? jp[j] : j+jp[0]; \ |
1640 | s = jp[k]; \ | 1639 | \ |
1641 | for (i=0; i<mc; i++) { \ | 1640 | rp[rc*i+j] = mp[ai*si+aj*sj]; \ |
1642 | rp[rc*k+i] = mp[mc*s+i]; \ | 1641 | } \ |
1643 | } \ | 1642 | } \ |
1644 | } \ | ||
1645 | } \ | ||
1646 | } else { \ | ||
1647 | if (mode==0) { \ | ||
1648 | for (k=0; k<jp[1]-jp[0]+1; k++) { \ | ||
1649 | s = k + jp[0]; \ | ||
1650 | printf("%d\n",s); \ | ||
1651 | for (i=0; i<mr; i++) { \ | ||
1652 | rp[rc*k+i] = mp[mc*i+s]; \ | ||
1653 | } \ | ||
1654 | } \ | ||
1655 | } else { \ | ||
1656 | for (k=0;k<jn;k++) { \ | ||
1657 | s = jp[k]; \ | ||
1658 | for (i=0; i<mr; i++) { \ | ||
1659 | rp[rc*k+i] = mp[mc*i+s]; \ | ||
1660 | } \ | ||
1661 | } \ | ||
1662 | } \ | ||
1663 | } \ | ||
1664 | OK | 1643 | OK |
1665 | 1644 | ||
1666 | 1645 | int extractD(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KDMAT(m), DMAT(r)) { | |
1667 | int extractRD(int mode, int tm, KIVEC(j), KDMAT(m), DMAT(r)) { | ||
1668 | EXTRACT_IMP | 1646 | EXTRACT_IMP |
1669 | } | 1647 | } |
1670 | 1648 | ||
1671 | int extractRF(int mode, int tm, KIVEC(j), KFMAT(m), FMAT(r)) { | 1649 | int extractF(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KFMAT(m), FMAT(r)) { |
1672 | EXTRACT_IMP | 1650 | EXTRACT_IMP |
1673 | } | 1651 | } |
1674 | 1652 | ||
1675 | int extractRC(int mode, int tm, KIVEC(j), KCMAT(m), CMAT(r)) { | 1653 | int extractC(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KCMAT(m), CMAT(r)) { |
1676 | EXTRACT_IMP | 1654 | EXTRACT_IMP |
1677 | } | 1655 | } |
1678 | 1656 | ||
1679 | int extractRQ(int mode, int tm, KIVEC(j), KQMAT(m), QMAT(r)) { | 1657 | int extractQ(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KQMAT(m), QMAT(r)) { |
1680 | EXTRACT_IMP | 1658 | EXTRACT_IMP |
1681 | } | 1659 | } |
1682 | 1660 | ||
1683 | int extractRI(int mode, int tm, KIVEC(j), KIMAT(m), IMAT(r)) { | 1661 | int extractI(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KIMAT(m), IMAT(r)) { |
1684 | EXTRACT_IMP | 1662 | EXTRACT_IMP |
1685 | } | 1663 | } |
1686 | 1664 | ||
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 1aee7d3..76d2204 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -265,32 +265,32 @@ class (Storable a) => Element a where | |||
265 | transdata = transdataP -- transdata' | 265 | transdata = transdataP -- transdata' |
266 | constantD :: a -> Int -> Vector a | 266 | constantD :: a -> Int -> Vector a |
267 | constantD = constantP -- constant' | 267 | constantD = constantP -- constant' |
268 | extractR :: Matrix a -> CInt -> Idxs -> Matrix a | 268 | extractR :: Matrix a -> CInt -> Idxs -> CInt -> Idxs -> Matrix a |
269 | 269 | ||
270 | instance Element Float where | 270 | instance Element Float where |
271 | transdata = transdataAux ctransF | 271 | transdata = transdataAux ctransF |
272 | constantD = constantAux cconstantF | 272 | constantD = constantAux cconstantF |
273 | extractR = extractAux c_extractRF | 273 | extractR = extractAux c_extractF |
274 | 274 | ||
275 | instance Element Double where | 275 | instance Element Double where |
276 | transdata = transdataAux ctransR | 276 | transdata = transdataAux ctransR |
277 | constantD = constantAux cconstantR | 277 | constantD = constantAux cconstantR |
278 | extractR = extractAux c_extractRD | 278 | extractR = extractAux c_extractD |
279 | 279 | ||
280 | instance Element (Complex Float) where | 280 | instance Element (Complex Float) where |
281 | transdata = transdataAux ctransQ | 281 | transdata = transdataAux ctransQ |
282 | constantD = constantAux cconstantQ | 282 | constantD = constantAux cconstantQ |
283 | extractR = extractAux c_extractRQ | 283 | extractR = extractAux c_extractQ |
284 | 284 | ||
285 | instance Element (Complex Double) where | 285 | instance Element (Complex Double) where |
286 | transdata = transdataAux ctransC | 286 | transdata = transdataAux ctransC |
287 | constantD = constantAux cconstantC | 287 | constantD = constantAux cconstantC |
288 | extractR = extractAux c_extractRC | 288 | extractR = extractAux c_extractC |
289 | 289 | ||
290 | instance Element (CInt) where | 290 | instance Element (CInt) where |
291 | transdata = transdataAux ctransI | 291 | transdata = transdataAux ctransI |
292 | constantD = constantAux cconstantI | 292 | constantD = constantAux cconstantI |
293 | extractR = extractAux c_extractRI | 293 | extractR = extractAux c_extractI |
294 | 294 | ||
295 | 295 | ||
296 | ------------------------------------------------------------------- | 296 | ------------------------------------------------------------------- |
@@ -443,26 +443,26 @@ isT Matrix{order = RowMajor} = 0 | |||
443 | tt x@Matrix{order = ColumnMajor} = trans x | 443 | tt x@Matrix{order = ColumnMajor} = trans x |
444 | tt x@Matrix{order = RowMajor} = x | 444 | tt x@Matrix{order = RowMajor} = x |
445 | 445 | ||
446 | --extractAux :: Matrix Double -> Idxs -> Matrix Double | 446 | |
447 | extractAux f m mode v = unsafePerformIO $ do | 447 | extractAux f m moder vr modec vc = unsafePerformIO $ do |
448 | let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) | 448 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
449 | | otherwise = dim v | 449 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
450 | r <- createMatrix RowMajor nr (cols m) | 450 | r <- createMatrix RowMajor nr nc |
451 | app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" | 451 | app4 (f moder modec (isT m)) vec vr vec vc mat (tt m) mat r "extractAux" |
452 | return r | 452 | return r |
453 | 453 | ||
454 | foreign import ccall unsafe "extractRD" c_extractRD | 454 | foreign import ccall unsafe "extractD" c_extractD |
455 | :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) | 455 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Double (CM Double (IO CInt)))) |
456 | 456 | ||
457 | foreign import ccall unsafe "extractRF" c_extractRF | 457 | foreign import ccall unsafe "extractF" c_extractF |
458 | :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) | 458 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Float (CM Float (IO CInt)))) |
459 | 459 | ||
460 | foreign import ccall unsafe "extractRC" c_extractRC | 460 | foreign import ccall unsafe "extractC" c_extractC |
461 | :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) | 461 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt)))) |
462 | 462 | ||
463 | foreign import ccall unsafe "extractRQ" c_extractRQ | 463 | foreign import ccall unsafe "extractQ" c_extractQ |
464 | :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) | 464 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt)))) |
465 | 465 | ||
466 | foreign import ccall unsafe "extractRI" c_extractRI | 466 | foreign import ccall unsafe "extractI" c_extractI |
467 | :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) | 467 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM CInt (CM CInt (IO CInt)))) |
468 | 468 | ||
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 00ec70c..353877a 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric ( | |||
39 | roundVector, | 39 | roundVector, |
40 | RealOf, ComplexOf, SingleOf, DoubleOf, | 40 | RealOf, ComplexOf, SingleOf, DoubleOf, |
41 | IndexOf, | 41 | IndexOf, |
42 | CInt, Extractor(..), (??),(¿¿), | 42 | CInt, Extractor(..), (??), |
43 | module Data.Complex | 43 | module Data.Complex |
44 | ) where | 44 | ) where |
45 | 45 | ||
@@ -69,27 +69,50 @@ type instance ArgOf Matrix a = a -> a -> a | |||
69 | 69 | ||
70 | -------------------------------------------------------------------------- | 70 | -------------------------------------------------------------------------- |
71 | 71 | ||
72 | data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int | 72 | data Extractor |
73 | = All | ||
74 | | Range Int Int | ||
75 | | At [Int] | ||
76 | | AtCyc [Int] | ||
77 | | Take Int | ||
78 | | Drop Int | ||
79 | deriving Show | ||
73 | 80 | ||
74 | idxs js = fromList (map fromIntegral js) :: Idxs | 81 | idxs js = fromList (map fromIntegral js) :: Idxs |
75 | 82 | ||
76 | infixl 9 ??, ¿¿ | 83 | infixl 9 ?? |
77 | (??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t | 84 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t |
78 | 85 | ||
79 | m ?? All = m | 86 | |
80 | m ?? Take 0 = (0><cols m) [] | 87 | extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m) |
81 | m ?? Take n | abs n >= rows m = m | 88 | |
82 | m ?? Drop 0 = m | 89 | m ?? e@(Range a b,_) | a < 0 || b >= rows m = extractError m e |
83 | m ?? Drop n | abs n >= rows m = (0><cols m) [] | 90 | m ?? e@(_,Range a b) | a < 0 || b >= cols m = extractError m e |
84 | m ?? Range a b | a > b = m ?? Take 0 | 91 | m ?? e@(At ps,_) | minimum ps < 0 || maximum ps >= rows m = extractError m e |
85 | m ?? Range a b | a < 0 || b >= cols m = error $ | 92 | m ?? e@(_,At ps) | minimum ps < 0 || maximum ps >= cols m = extractError m e |
86 | printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) | 93 | |
87 | m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ | 94 | m ?? (All,All) = m |
88 | printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) | 95 | |
89 | 96 | m ?? (Range a b,e) | a > b = m ?? (Take 0,e) | |
90 | m ?? er = extractR m mode js | 97 | m ?? (e,Range a b) | a > b = m ?? (e,Take 0) |
98 | |||
99 | m ?? (Take 0,e) = (0><cols m) [] ?? (All,e) | ||
100 | m ?? (e,Take 0) = (rows m><0) [] ?? (e,All) | ||
101 | |||
102 | m ?? (Take n,e) | abs n > rows m = m ?? (All,e) | ||
103 | m ?? (e,Take n) | abs n > cols m = m ?? (e,All) | ||
104 | |||
105 | m ?? (Drop 0,e) = m ?? (All,e) | ||
106 | m ?? (e,Drop 0) = m ?? (e,All) | ||
107 | |||
108 | m ?? (Drop n,e) | abs n > rows m = m ?? (Take 0,e) | ||
109 | m ?? (e,Drop n) | abs n > cols m = m ?? (e,Take 0) | ||
110 | |||
111 | |||
112 | m ?? (er,ec) = extractR m moder rs modec cs | ||
91 | where | 113 | where |
92 | (mode,js) = mkExt (rows m) er | 114 | (moder,rs) = mkExt (rows m) er |
115 | (modec,cs) = mkExt (cols m) ec | ||
93 | ran a b = (0, idxs [a,b]) | 116 | ran a b = (0, idxs [a,b]) |
94 | pos ks = (1, idxs ks) | 117 | pos ks = (1, idxs ks) |
95 | mkExt _ (At ks) = pos ks | 118 | mkExt _ (At ks) = pos ks |
@@ -104,13 +127,6 @@ m ?? er = extractR m mode js | |||
104 | | otherwise = mkExt n (Take (n+k)) | 127 | | otherwise = mkExt n (Take (n+k)) |
105 | 128 | ||
106 | 129 | ||
107 | m ¿¿ Range a b | a < 0 || b > cols m -1 = error $ | ||
108 | printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m) | ||
109 | |||
110 | m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $ | ||
111 | printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m) | ||
112 | m ¿¿ ec = trans (trans m ?? ec) | ||
113 | |||
114 | ------------------------------------------------------------------- | 130 | ------------------------------------------------------------------- |
115 | 131 | ||
116 | 132 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index e4485fc..c94350f 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs | |||
@@ -37,7 +37,7 @@ module Numeric.LinearAlgebra.Data( | |||
37 | fromList, toList, subVector, takesV, vjoin, | 37 | fromList, toList, subVector, takesV, vjoin, |
38 | flatten, reshape, asRow, asColumn, row, col, | 38 | flatten, reshape, asRow, asColumn, row, col, |
39 | fromRows, toRows, fromColumns, toColumns, fromLists, toLists, fromArray2D, | 39 | fromRows, toRows, fromColumns, toColumns, fromLists, toLists, fromArray2D, |
40 | Extractor(..), (??), (¿¿), | 40 | Extractor(..), (??), |
41 | takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud, | 41 | takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud, |
42 | 42 | ||
43 | -- * Block matrix | 43 | -- * Block matrix |
@@ -81,6 +81,6 @@ import Numeric.LinearAlgebra.Util hiding ((&),(#)) | |||
81 | import Data.Complex | 81 | import Data.Complex |
82 | import Numeric.Sparse | 82 | import Numeric.Sparse |
83 | import Data.Packed.Internal.Vector(Idxs) | 83 | import Data.Packed.Internal.Vector(Idxs) |
84 | import Data.Packed.Internal.Numeric(CInt,Extractor(..),(??),(¿¿)) | 84 | import Data.Packed.Internal.Numeric(CInt,Extractor(..),(??)) |
85 | 85 | ||
86 | 86 | ||