diff options
-rw-r--r-- | packages/base/src/C/lapack-aux.c | 56 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 20 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 53 |
3 files changed, 99 insertions, 30 deletions
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c index d56d466..e76d31e 100644 --- a/packages/base/src/C/lapack-aux.c +++ b/packages/base/src/C/lapack-aux.c | |||
@@ -1623,44 +1623,64 @@ int chooseD(KIVEC(cond),KDVEC(lt),KDVEC(eq),KDVEC(gt),DVEC(r)) { | |||
1623 | //////////////////////// extract ///////////////////////////////// | 1623 | //////////////////////// extract ///////////////////////////////// |
1624 | 1624 | ||
1625 | #define EXTRACT_IMP \ | 1625 | #define EXTRACT_IMP \ |
1626 | REQUIRES((tm == 0 && jn==rr && mc==rc) || (jn==rr && mr==rc) ,BAD_SIZE); \ | 1626 | /*REQUIRES((tm == 0 && jn==rr && mc==rc) || (jn==rr && mr==rc) ,BAD_SIZE); */ \ |
1627 | DEBUGMSG("extractRD") \ | 1627 | DEBUGMSG("extractRD") \ |
1628 | int k,i,s; \ | 1628 | int k,i,s; \ |
1629 | if (tm==0) { \ | 1629 | if (tm==0) { \ |
1630 | for (k=0;k<jn;k++) { \ | 1630 | if (mode==0) { \ |
1631 | s = jp[k]; \ | 1631 | for (k=0; k<jp[1]-jp[0]+1; k++) { \ |
1632 | for (i=0; i<mc; i++) { \ | 1632 | s = k + jp[0]; \ |
1633 | rp[rc*k+i] = mp[mc*s+i]; \ | 1633 | printf("%d\n",s); \ |
1634 | } \ | 1634 | for (i=0; i<mc; i++) { \ |
1635 | } \ | 1635 | rp[rc*k+i] = mp[mc*s+i]; \ |
1636 | } \ | ||
1637 | } \ | ||
1638 | } else { \ | ||
1639 | for (k=0;k<jn;k++) { \ | ||
1640 | s = jp[k]; \ | ||
1641 | for (i=0; i<mc; i++) { \ | ||
1642 | rp[rc*k+i] = mp[mc*s+i]; \ | ||
1643 | } \ | ||
1644 | } \ | ||
1645 | } \ | ||
1636 | } else { \ | 1646 | } else { \ |
1637 | for (k=0;k<jn;k++) { \ | 1647 | if (mode==0) { \ |
1638 | s = jp[k]; \ | 1648 | for (k=0; k<jp[1]-jp[0]+1; k++) { \ |
1639 | for (i=0; i<mr; i++) { \ | 1649 | s = k + jp[0]; \ |
1640 | rp[rc*k+i] = mp[mc*i+s]; \ | 1650 | printf("%d\n",s); \ |
1641 | } \ | 1651 | for (i=0; i<mr; i++) { \ |
1642 | } \ | 1652 | rp[rc*k+i] = mp[mc*i+s]; \ |
1653 | } \ | ||
1654 | } \ | ||
1655 | } else { \ | ||
1656 | for (k=0;k<jn;k++) { \ | ||
1657 | s = jp[k]; \ | ||
1658 | for (i=0; i<mr; i++) { \ | ||
1659 | rp[rc*k+i] = mp[mc*i+s]; \ | ||
1660 | } \ | ||
1661 | } \ | ||
1662 | } \ | ||
1643 | } \ | 1663 | } \ |
1644 | OK | 1664 | OK |
1645 | 1665 | ||
1646 | 1666 | ||
1647 | int extractRD(int tm, KIVEC(j), KDMAT(m), DMAT(r)) { | 1667 | int extractRD(int mode, int tm, KIVEC(j), KDMAT(m), DMAT(r)) { |
1648 | EXTRACT_IMP | 1668 | EXTRACT_IMP |
1649 | } | 1669 | } |
1650 | 1670 | ||
1651 | int extractRF(int tm, KIVEC(j), KFMAT(m), FMAT(r)) { | 1671 | int extractRF(int mode, int tm, KIVEC(j), KFMAT(m), FMAT(r)) { |
1652 | EXTRACT_IMP | 1672 | EXTRACT_IMP |
1653 | } | 1673 | } |
1654 | 1674 | ||
1655 | int extractRC(int tm, KIVEC(j), KCMAT(m), CMAT(r)) { | 1675 | int extractRC(int mode, int tm, KIVEC(j), KCMAT(m), CMAT(r)) { |
1656 | EXTRACT_IMP | 1676 | EXTRACT_IMP |
1657 | } | 1677 | } |
1658 | 1678 | ||
1659 | int extractRQ(int tm, KIVEC(j), KQMAT(m), QMAT(r)) { | 1679 | int extractRQ(int mode, int tm, KIVEC(j), KQMAT(m), QMAT(r)) { |
1660 | EXTRACT_IMP | 1680 | EXTRACT_IMP |
1661 | } | 1681 | } |
1662 | 1682 | ||
1663 | int extractRI(int tm, KIVEC(j), KIMAT(m), IMAT(r)) { | 1683 | int extractRI(int mode, int tm, KIVEC(j), KIMAT(m), IMAT(r)) { |
1664 | EXTRACT_IMP | 1684 | EXTRACT_IMP |
1665 | } | 1685 | } |
1666 | 1686 | ||
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index be5fb03..1aee7d3 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -265,7 +265,7 @@ class (Storable a) => Element a where | |||
265 | transdata = transdataP -- transdata' | 265 | transdata = transdataP -- transdata' |
266 | constantD :: a -> Int -> Vector a | 266 | constantD :: a -> Int -> Vector a |
267 | constantD = constantP -- constant' | 267 | constantD = constantP -- constant' |
268 | extractR :: Matrix a -> Idxs -> Matrix a | 268 | extractR :: Matrix a -> CInt -> Idxs -> Matrix a |
269 | 269 | ||
270 | instance Element Float where | 270 | instance Element Float where |
271 | transdata = transdataAux ctransF | 271 | transdata = transdataAux ctransF |
@@ -444,23 +444,25 @@ tt x@Matrix{order = ColumnMajor} = trans x | |||
444 | tt x@Matrix{order = RowMajor} = x | 444 | tt x@Matrix{order = RowMajor} = x |
445 | 445 | ||
446 | --extractAux :: Matrix Double -> Idxs -> Matrix Double | 446 | --extractAux :: Matrix Double -> Idxs -> Matrix Double |
447 | extractAux f m v = unsafePerformIO $ do | 447 | extractAux f m mode v = unsafePerformIO $ do |
448 | r <- createMatrix RowMajor (dim v) (cols m) | 448 | let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) |
449 | app3 (f (isT m)) vec v mat (tt m) mat r "extractAux" | 449 | | otherwise = dim v |
450 | r <- createMatrix RowMajor nr (cols m) | ||
451 | app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" | ||
450 | return r | 452 | return r |
451 | 453 | ||
452 | foreign import ccall unsafe "extractRD" c_extractRD | 454 | foreign import ccall unsafe "extractRD" c_extractRD |
453 | :: CInt -> CIdxs (CM Double (CM Double (IO CInt))) | 455 | :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) |
454 | 456 | ||
455 | foreign import ccall unsafe "extractRF" c_extractRF | 457 | foreign import ccall unsafe "extractRF" c_extractRF |
456 | :: CInt -> CIdxs (CM Float (CM Float (IO CInt))) | 458 | :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) |
457 | 459 | ||
458 | foreign import ccall unsafe "extractRC" c_extractRC | 460 | foreign import ccall unsafe "extractRC" c_extractRC |
459 | :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) | 461 | :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) |
460 | 462 | ||
461 | foreign import ccall unsafe "extractRQ" c_extractRQ | 463 | foreign import ccall unsafe "extractRQ" c_extractRQ |
462 | :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) | 464 | :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) |
463 | 465 | ||
464 | foreign import ccall unsafe "extractRI" c_extractRI | 466 | foreign import ccall unsafe "extractRI" c_extractRI |
465 | :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) | 467 | :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) |
466 | 468 | ||
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 9b6b55b..f1b4898 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric ( | |||
39 | roundVector, | 39 | roundVector, |
40 | RealOf, ComplexOf, SingleOf, DoubleOf, | 40 | RealOf, ComplexOf, SingleOf, DoubleOf, |
41 | IndexOf, | 41 | IndexOf, |
42 | CInt, | 42 | CInt, Extractor(..), (??),(¿¿), |
43 | module Data.Complex | 43 | module Data.Complex |
44 | ) where | 44 | ) where |
45 | 45 | ||
@@ -53,6 +53,7 @@ import Data.Complex | |||
53 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) | 53 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) |
54 | import Data.Packed.Internal | 54 | import Data.Packed.Internal |
55 | import Foreign.C.Types(CInt) | 55 | import Foreign.C.Types(CInt) |
56 | import Text.Printf(printf) | ||
56 | 57 | ||
57 | ------------------------------------------------------------------- | 58 | ------------------------------------------------------------------- |
58 | 59 | ||
@@ -66,8 +67,53 @@ type family ArgOf (c :: * -> *) a | |||
66 | type instance ArgOf Vector a = a -> a | 67 | type instance ArgOf Vector a = a -> a |
67 | type instance ArgOf Matrix a = a -> a -> a | 68 | type instance ArgOf Matrix a = a -> a -> a |
68 | 69 | ||
70 | -------------------------------------------------------------------------- | ||
71 | |||
72 | data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int | ||
73 | |||
74 | idxs js = fromList (map fromIntegral js) :: Idxs | ||
75 | |||
76 | infixl 9 ??, ¿¿ | ||
77 | (??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t | ||
78 | |||
79 | m ?? All = m | ||
80 | m ?? Take 0 = (0><cols m) [] | ||
81 | m ?? Take n | abs n >= rows m = m | ||
82 | m ?? Drop 0 = m | ||
83 | m ?? Drop n | abs n >= rows m = (0><cols m) [] | ||
84 | m ?? Range a b | a > b = m ?? Take 0 | ||
85 | m ?? Range a b | a < 0 || b >= cols m = error $ | ||
86 | printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) | ||
87 | m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ | ||
88 | printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) | ||
89 | |||
90 | m ?? er = extractR m mode js | ||
91 | where | ||
92 | (mode,js) = mkExt (rows m) er | ||
93 | ran a b = (0, idxs [a,b]) | ||
94 | pos ks = (1, idxs ks) | ||
95 | mkExt _ (At ks) = pos ks | ||
96 | mkExt n (AtCyc ks) = pos (map (`mod` n) ks) | ||
97 | mkExt n All = ran 0 (n-1) | ||
98 | mkExt _ (Range mn mx) = ran mn mx | ||
99 | mkExt n (Take k) | ||
100 | | k >= 0 = ran 0 (k-1) | ||
101 | | otherwise = mkExt n (Drop (n+k)) | ||
102 | mkExt n (Drop k) | ||
103 | | k >= 0 = ran k (n-1) | ||
104 | | otherwise = mkExt n (Take (n+k)) | ||
105 | |||
106 | |||
107 | m ¿¿ Range a b | a < 0 || b > cols m -1 = error $ | ||
108 | printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m) | ||
109 | |||
110 | m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $ | ||
111 | printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m) | ||
112 | m ¿¿ ec = trans (trans m ?? ec) | ||
113 | |||
69 | ------------------------------------------------------------------- | 114 | ------------------------------------------------------------------- |
70 | 115 | ||
116 | |||
71 | -- | Basic element-by-element functions for numeric containers | 117 | -- | Basic element-by-element functions for numeric containers |
72 | class Element e => SContainer c e | 118 | class Element e => SContainer c e |
73 | where | 119 | where |
@@ -123,6 +169,7 @@ class (Complexable c, Fractional e, SContainer c e) => Container c e | |||
123 | -- element by element inverse tangent | 169 | -- element by element inverse tangent |
124 | arctan2' :: c e -> c e -> c e | 170 | arctan2' :: c e -> c e -> c e |
125 | 171 | ||
172 | |||
126 | -------------------------------------------------------------------------- | 173 | -------------------------------------------------------------------------- |
127 | 174 | ||
128 | instance SContainer Vector CInt | 175 | instance SContainer Vector CInt |
@@ -245,14 +292,14 @@ instance SContainer Vector (Complex Double) | |||
245 | accum' = accumV | 292 | accum' = accumV |
246 | cond' = undefined -- cannot match | 293 | cond' = undefined -- cannot match |
247 | 294 | ||
248 | 295 | ||
249 | instance Container Vector (Complex Double) | 296 | instance Container Vector (Complex Double) |
250 | where | 297 | where |
251 | scaleRecip = vectorMapValC Recip | 298 | scaleRecip = vectorMapValC Recip |
252 | divide = vectorZipC Div | 299 | divide = vectorZipC Div |
253 | arctan2' = vectorZipC ATan2 | 300 | arctan2' = vectorZipC ATan2 |
254 | conj' = conjugateC | 301 | conj' = conjugateC |
255 | 302 | ||
256 | 303 | ||
257 | instance SContainer Vector (Complex Float) | 304 | instance SContainer Vector (Complex Float) |
258 | where | 305 | where |