summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-05-22 12:00:42 +0200
committerAlberto Ruiz <aruiz@um.es>2015-05-22 12:00:42 +0200
commite635f3889aed9b4bf7ef02c98945e9065d114df3 (patch)
tree1ae46c334e55052ec1eaa3b3c20bf5becdec164c
parent0c8a545e67ebef2b3a4e376fac8f23a651dfbe6d (diff)
extraction modes
-rw-r--r--packages/base/src/C/lapack-aux.c56
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs20
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs53
3 files changed, 99 insertions, 30 deletions
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c
index d56d466..e76d31e 100644
--- a/packages/base/src/C/lapack-aux.c
+++ b/packages/base/src/C/lapack-aux.c
@@ -1623,44 +1623,64 @@ int chooseD(KIVEC(cond),KDVEC(lt),KDVEC(eq),KDVEC(gt),DVEC(r)) {
1623//////////////////////// extract ///////////////////////////////// 1623//////////////////////// extract /////////////////////////////////
1624 1624
1625#define EXTRACT_IMP \ 1625#define EXTRACT_IMP \
1626 REQUIRES((tm == 0 && jn==rr && mc==rc) || (jn==rr && mr==rc) ,BAD_SIZE); \ 1626 /*REQUIRES((tm == 0 && jn==rr && mc==rc) || (jn==rr && mr==rc) ,BAD_SIZE); */ \
1627 DEBUGMSG("extractRD") \ 1627 DEBUGMSG("extractRD") \
1628 int k,i,s; \ 1628 int k,i,s; \
1629 if (tm==0) { \ 1629 if (tm==0) { \
1630 for (k=0;k<jn;k++) { \ 1630 if (mode==0) { \
1631 s = jp[k]; \ 1631 for (k=0; k<jp[1]-jp[0]+1; k++) { \
1632 for (i=0; i<mc; i++) { \ 1632 s = k + jp[0]; \
1633 rp[rc*k+i] = mp[mc*s+i]; \ 1633 printf("%d\n",s); \
1634 } \ 1634 for (i=0; i<mc; i++) { \
1635 } \ 1635 rp[rc*k+i] = mp[mc*s+i]; \
1636 } \
1637 } \
1638 } else { \
1639 for (k=0;k<jn;k++) { \
1640 s = jp[k]; \
1641 for (i=0; i<mc; i++) { \
1642 rp[rc*k+i] = mp[mc*s+i]; \
1643 } \
1644 } \
1645 } \
1636 } else { \ 1646 } else { \
1637 for (k=0;k<jn;k++) { \ 1647 if (mode==0) { \
1638 s = jp[k]; \ 1648 for (k=0; k<jp[1]-jp[0]+1; k++) { \
1639 for (i=0; i<mr; i++) { \ 1649 s = k + jp[0]; \
1640 rp[rc*k+i] = mp[mc*i+s]; \ 1650 printf("%d\n",s); \
1641 } \ 1651 for (i=0; i<mr; i++) { \
1642 } \ 1652 rp[rc*k+i] = mp[mc*i+s]; \
1653 } \
1654 } \
1655 } else { \
1656 for (k=0;k<jn;k++) { \
1657 s = jp[k]; \
1658 for (i=0; i<mr; i++) { \
1659 rp[rc*k+i] = mp[mc*i+s]; \
1660 } \
1661 } \
1662 } \
1643 } \ 1663 } \
1644 OK 1664 OK
1645 1665
1646 1666
1647int extractRD(int tm, KIVEC(j), KDMAT(m), DMAT(r)) { 1667int extractRD(int mode, int tm, KIVEC(j), KDMAT(m), DMAT(r)) {
1648 EXTRACT_IMP 1668 EXTRACT_IMP
1649} 1669}
1650 1670
1651int extractRF(int tm, KIVEC(j), KFMAT(m), FMAT(r)) { 1671int extractRF(int mode, int tm, KIVEC(j), KFMAT(m), FMAT(r)) {
1652 EXTRACT_IMP 1672 EXTRACT_IMP
1653} 1673}
1654 1674
1655int extractRC(int tm, KIVEC(j), KCMAT(m), CMAT(r)) { 1675int extractRC(int mode, int tm, KIVEC(j), KCMAT(m), CMAT(r)) {
1656 EXTRACT_IMP 1676 EXTRACT_IMP
1657} 1677}
1658 1678
1659int extractRQ(int tm, KIVEC(j), KQMAT(m), QMAT(r)) { 1679int extractRQ(int mode, int tm, KIVEC(j), KQMAT(m), QMAT(r)) {
1660 EXTRACT_IMP 1680 EXTRACT_IMP
1661} 1681}
1662 1682
1663int extractRI(int tm, KIVEC(j), KIMAT(m), IMAT(r)) { 1683int extractRI(int mode, int tm, KIVEC(j), KIMAT(m), IMAT(r)) {
1664 EXTRACT_IMP 1684 EXTRACT_IMP
1665} 1685}
1666 1686
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
index be5fb03..1aee7d3 100644
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -265,7 +265,7 @@ class (Storable a) => Element a where
265 transdata = transdataP -- transdata' 265 transdata = transdataP -- transdata'
266 constantD :: a -> Int -> Vector a 266 constantD :: a -> Int -> Vector a
267 constantD = constantP -- constant' 267 constantD = constantP -- constant'
268 extractR :: Matrix a -> Idxs -> Matrix a 268 extractR :: Matrix a -> CInt -> Idxs -> Matrix a
269 269
270instance Element Float where 270instance Element Float where
271 transdata = transdataAux ctransF 271 transdata = transdataAux ctransF
@@ -444,23 +444,25 @@ tt x@Matrix{order = ColumnMajor} = trans x
444tt x@Matrix{order = RowMajor} = x 444tt x@Matrix{order = RowMajor} = x
445 445
446--extractAux :: Matrix Double -> Idxs -> Matrix Double 446--extractAux :: Matrix Double -> Idxs -> Matrix Double
447extractAux f m v = unsafePerformIO $ do 447extractAux f m mode v = unsafePerformIO $ do
448 r <- createMatrix RowMajor (dim v) (cols m) 448 let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1)
449 app3 (f (isT m)) vec v mat (tt m) mat r "extractAux" 449 | otherwise = dim v
450 r <- createMatrix RowMajor nr (cols m)
451 app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux"
450 return r 452 return r
451 453
452foreign import ccall unsafe "extractRD" c_extractRD 454foreign import ccall unsafe "extractRD" c_extractRD
453 :: CInt -> CIdxs (CM Double (CM Double (IO CInt))) 455 :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt)))
454 456
455foreign import ccall unsafe "extractRF" c_extractRF 457foreign import ccall unsafe "extractRF" c_extractRF
456 :: CInt -> CIdxs (CM Float (CM Float (IO CInt))) 458 :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt)))
457 459
458foreign import ccall unsafe "extractRC" c_extractRC 460foreign import ccall unsafe "extractRC" c_extractRC
459 :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) 461 :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt)))
460 462
461foreign import ccall unsafe "extractRQ" c_extractRQ 463foreign import ccall unsafe "extractRQ" c_extractRQ
462 :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) 464 :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt)))
463 465
464foreign import ccall unsafe "extractRI" c_extractRI 466foreign import ccall unsafe "extractRI" c_extractRI
465 :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) 467 :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt)))
466 468
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
index 9b6b55b..f1b4898 100644
--- a/packages/base/src/Data/Packed/Internal/Numeric.hs
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric (
39 roundVector, 39 roundVector,
40 RealOf, ComplexOf, SingleOf, DoubleOf, 40 RealOf, ComplexOf, SingleOf, DoubleOf,
41 IndexOf, 41 IndexOf,
42 CInt, 42 CInt, Extractor(..), (??),(¿¿),
43 module Data.Complex 43 module Data.Complex
44) where 44) where
45 45
@@ -53,6 +53,7 @@ import Data.Complex
53import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) 53import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
54import Data.Packed.Internal 54import Data.Packed.Internal
55import Foreign.C.Types(CInt) 55import Foreign.C.Types(CInt)
56import Text.Printf(printf)
56 57
57------------------------------------------------------------------- 58-------------------------------------------------------------------
58 59
@@ -66,8 +67,53 @@ type family ArgOf (c :: * -> *) a
66type instance ArgOf Vector a = a -> a 67type instance ArgOf Vector a = a -> a
67type instance ArgOf Matrix a = a -> a -> a 68type instance ArgOf Matrix a = a -> a -> a
68 69
70--------------------------------------------------------------------------
71
72data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int
73
74idxs js = fromList (map fromIntegral js) :: Idxs
75
76infixl 9 ??, ¿¿
77(??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t
78
79m ?? All = m
80m ?? Take 0 = (0><cols m) []
81m ?? Take n | abs n >= rows m = m
82m ?? Drop 0 = m
83m ?? Drop n | abs n >= rows m = (0><cols m) []
84m ?? Range a b | a > b = m ?? Take 0
85m ?? Range a b | a < 0 || b >= cols m = error $
86 printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m)
87m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $
88 printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m)
89
90m ?? er = extractR m mode js
91 where
92 (mode,js) = mkExt (rows m) er
93 ran a b = (0, idxs [a,b])
94 pos ks = (1, idxs ks)
95 mkExt _ (At ks) = pos ks
96 mkExt n (AtCyc ks) = pos (map (`mod` n) ks)
97 mkExt n All = ran 0 (n-1)
98 mkExt _ (Range mn mx) = ran mn mx
99 mkExt n (Take k)
100 | k >= 0 = ran 0 (k-1)
101 | otherwise = mkExt n (Drop (n+k))
102 mkExt n (Drop k)
103 | k >= 0 = ran k (n-1)
104 | otherwise = mkExt n (Take (n+k))
105
106
107m ¿¿ Range a b | a < 0 || b > cols m -1 = error $
108 printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m)
109
110m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $
111 printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m)
112m ¿¿ ec = trans (trans m ?? ec)
113
69------------------------------------------------------------------- 114-------------------------------------------------------------------
70 115
116
71-- | Basic element-by-element functions for numeric containers 117-- | Basic element-by-element functions for numeric containers
72class Element e => SContainer c e 118class Element e => SContainer c e
73 where 119 where
@@ -123,6 +169,7 @@ class (Complexable c, Fractional e, SContainer c e) => Container c e
123 -- element by element inverse tangent 169 -- element by element inverse tangent
124 arctan2' :: c e -> c e -> c e 170 arctan2' :: c e -> c e -> c e
125 171
172
126-------------------------------------------------------------------------- 173--------------------------------------------------------------------------
127 174
128instance SContainer Vector CInt 175instance SContainer Vector CInt
@@ -245,14 +292,14 @@ instance SContainer Vector (Complex Double)
245 accum' = accumV 292 accum' = accumV
246 cond' = undefined -- cannot match 293 cond' = undefined -- cannot match
247 294
248 295
249instance Container Vector (Complex Double) 296instance Container Vector (Complex Double)
250 where 297 where
251 scaleRecip = vectorMapValC Recip 298 scaleRecip = vectorMapValC Recip
252 divide = vectorZipC Div 299 divide = vectorZipC Div
253 arctan2' = vectorZipC ATan2 300 arctan2' = vectorZipC ATan2
254 conj' = conjugateC 301 conj' = conjugateC
255 302
256 303
257instance SContainer Vector (Complex Float) 304instance SContainer Vector (Complex Float)
258 where 305 where