diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-05-24 12:45:23 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-05-24 12:45:23 +0200 |
commit | 8ede2ed162f3d00172ee3fa4835e3ee2184bcd99 (patch) | |
tree | cdb3025dc5e469603d32d4e200cc753d3502c6d8 /packages/base/src/Data | |
parent | b1b445697db31b1603a31747ca31151f97ee7263 (diff) |
joint extractor of rows and columns
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 44 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 66 |
2 files changed, 63 insertions, 47 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 1aee7d3..76d2204 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -265,32 +265,32 @@ class (Storable a) => Element a where | |||
265 | transdata = transdataP -- transdata' | 265 | transdata = transdataP -- transdata' |
266 | constantD :: a -> Int -> Vector a | 266 | constantD :: a -> Int -> Vector a |
267 | constantD = constantP -- constant' | 267 | constantD = constantP -- constant' |
268 | extractR :: Matrix a -> CInt -> Idxs -> Matrix a | 268 | extractR :: Matrix a -> CInt -> Idxs -> CInt -> Idxs -> Matrix a |
269 | 269 | ||
270 | instance Element Float where | 270 | instance Element Float where |
271 | transdata = transdataAux ctransF | 271 | transdata = transdataAux ctransF |
272 | constantD = constantAux cconstantF | 272 | constantD = constantAux cconstantF |
273 | extractR = extractAux c_extractRF | 273 | extractR = extractAux c_extractF |
274 | 274 | ||
275 | instance Element Double where | 275 | instance Element Double where |
276 | transdata = transdataAux ctransR | 276 | transdata = transdataAux ctransR |
277 | constantD = constantAux cconstantR | 277 | constantD = constantAux cconstantR |
278 | extractR = extractAux c_extractRD | 278 | extractR = extractAux c_extractD |
279 | 279 | ||
280 | instance Element (Complex Float) where | 280 | instance Element (Complex Float) where |
281 | transdata = transdataAux ctransQ | 281 | transdata = transdataAux ctransQ |
282 | constantD = constantAux cconstantQ | 282 | constantD = constantAux cconstantQ |
283 | extractR = extractAux c_extractRQ | 283 | extractR = extractAux c_extractQ |
284 | 284 | ||
285 | instance Element (Complex Double) where | 285 | instance Element (Complex Double) where |
286 | transdata = transdataAux ctransC | 286 | transdata = transdataAux ctransC |
287 | constantD = constantAux cconstantC | 287 | constantD = constantAux cconstantC |
288 | extractR = extractAux c_extractRC | 288 | extractR = extractAux c_extractC |
289 | 289 | ||
290 | instance Element (CInt) where | 290 | instance Element (CInt) where |
291 | transdata = transdataAux ctransI | 291 | transdata = transdataAux ctransI |
292 | constantD = constantAux cconstantI | 292 | constantD = constantAux cconstantI |
293 | extractR = extractAux c_extractRI | 293 | extractR = extractAux c_extractI |
294 | 294 | ||
295 | 295 | ||
296 | ------------------------------------------------------------------- | 296 | ------------------------------------------------------------------- |
@@ -443,26 +443,26 @@ isT Matrix{order = RowMajor} = 0 | |||
443 | tt x@Matrix{order = ColumnMajor} = trans x | 443 | tt x@Matrix{order = ColumnMajor} = trans x |
444 | tt x@Matrix{order = RowMajor} = x | 444 | tt x@Matrix{order = RowMajor} = x |
445 | 445 | ||
446 | --extractAux :: Matrix Double -> Idxs -> Matrix Double | 446 | |
447 | extractAux f m mode v = unsafePerformIO $ do | 447 | extractAux f m moder vr modec vc = unsafePerformIO $ do |
448 | let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) | 448 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
449 | | otherwise = dim v | 449 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
450 | r <- createMatrix RowMajor nr (cols m) | 450 | r <- createMatrix RowMajor nr nc |
451 | app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" | 451 | app4 (f moder modec (isT m)) vec vr vec vc mat (tt m) mat r "extractAux" |
452 | return r | 452 | return r |
453 | 453 | ||
454 | foreign import ccall unsafe "extractRD" c_extractRD | 454 | foreign import ccall unsafe "extractD" c_extractD |
455 | :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) | 455 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Double (CM Double (IO CInt)))) |
456 | 456 | ||
457 | foreign import ccall unsafe "extractRF" c_extractRF | 457 | foreign import ccall unsafe "extractF" c_extractF |
458 | :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) | 458 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Float (CM Float (IO CInt)))) |
459 | 459 | ||
460 | foreign import ccall unsafe "extractRC" c_extractRC | 460 | foreign import ccall unsafe "extractC" c_extractC |
461 | :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) | 461 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt)))) |
462 | 462 | ||
463 | foreign import ccall unsafe "extractRQ" c_extractRQ | 463 | foreign import ccall unsafe "extractQ" c_extractQ |
464 | :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) | 464 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt)))) |
465 | 465 | ||
466 | foreign import ccall unsafe "extractRI" c_extractRI | 466 | foreign import ccall unsafe "extractI" c_extractI |
467 | :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) | 467 | :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM CInt (CM CInt (IO CInt)))) |
468 | 468 | ||
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 00ec70c..353877a 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric ( | |||
39 | roundVector, | 39 | roundVector, |
40 | RealOf, ComplexOf, SingleOf, DoubleOf, | 40 | RealOf, ComplexOf, SingleOf, DoubleOf, |
41 | IndexOf, | 41 | IndexOf, |
42 | CInt, Extractor(..), (??),(¿¿), | 42 | CInt, Extractor(..), (??), |
43 | module Data.Complex | 43 | module Data.Complex |
44 | ) where | 44 | ) where |
45 | 45 | ||
@@ -69,27 +69,50 @@ type instance ArgOf Matrix a = a -> a -> a | |||
69 | 69 | ||
70 | -------------------------------------------------------------------------- | 70 | -------------------------------------------------------------------------- |
71 | 71 | ||
72 | data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int | 72 | data Extractor |
73 | = All | ||
74 | | Range Int Int | ||
75 | | At [Int] | ||
76 | | AtCyc [Int] | ||
77 | | Take Int | ||
78 | | Drop Int | ||
79 | deriving Show | ||
73 | 80 | ||
74 | idxs js = fromList (map fromIntegral js) :: Idxs | 81 | idxs js = fromList (map fromIntegral js) :: Idxs |
75 | 82 | ||
76 | infixl 9 ??, ¿¿ | 83 | infixl 9 ?? |
77 | (??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t | 84 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t |
78 | 85 | ||
79 | m ?? All = m | 86 | |
80 | m ?? Take 0 = (0><cols m) [] | 87 | extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m) |
81 | m ?? Take n | abs n >= rows m = m | 88 | |
82 | m ?? Drop 0 = m | 89 | m ?? e@(Range a b,_) | a < 0 || b >= rows m = extractError m e |
83 | m ?? Drop n | abs n >= rows m = (0><cols m) [] | 90 | m ?? e@(_,Range a b) | a < 0 || b >= cols m = extractError m e |
84 | m ?? Range a b | a > b = m ?? Take 0 | 91 | m ?? e@(At ps,_) | minimum ps < 0 || maximum ps >= rows m = extractError m e |
85 | m ?? Range a b | a < 0 || b >= cols m = error $ | 92 | m ?? e@(_,At ps) | minimum ps < 0 || maximum ps >= cols m = extractError m e |
86 | printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) | 93 | |
87 | m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ | 94 | m ?? (All,All) = m |
88 | printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) | 95 | |
89 | 96 | m ?? (Range a b,e) | a > b = m ?? (Take 0,e) | |
90 | m ?? er = extractR m mode js | 97 | m ?? (e,Range a b) | a > b = m ?? (e,Take 0) |
98 | |||
99 | m ?? (Take 0,e) = (0><cols m) [] ?? (All,e) | ||
100 | m ?? (e,Take 0) = (rows m><0) [] ?? (e,All) | ||
101 | |||
102 | m ?? (Take n,e) | abs n > rows m = m ?? (All,e) | ||
103 | m ?? (e,Take n) | abs n > cols m = m ?? (e,All) | ||
104 | |||
105 | m ?? (Drop 0,e) = m ?? (All,e) | ||
106 | m ?? (e,Drop 0) = m ?? (e,All) | ||
107 | |||
108 | m ?? (Drop n,e) | abs n > rows m = m ?? (Take 0,e) | ||
109 | m ?? (e,Drop n) | abs n > cols m = m ?? (e,Take 0) | ||
110 | |||
111 | |||
112 | m ?? (er,ec) = extractR m moder rs modec cs | ||
91 | where | 113 | where |
92 | (mode,js) = mkExt (rows m) er | 114 | (moder,rs) = mkExt (rows m) er |
115 | (modec,cs) = mkExt (cols m) ec | ||
93 | ran a b = (0, idxs [a,b]) | 116 | ran a b = (0, idxs [a,b]) |
94 | pos ks = (1, idxs ks) | 117 | pos ks = (1, idxs ks) |
95 | mkExt _ (At ks) = pos ks | 118 | mkExt _ (At ks) = pos ks |
@@ -104,13 +127,6 @@ m ?? er = extractR m mode js | |||
104 | | otherwise = mkExt n (Take (n+k)) | 127 | | otherwise = mkExt n (Take (n+k)) |
105 | 128 | ||
106 | 129 | ||
107 | m ¿¿ Range a b | a < 0 || b > cols m -1 = error $ | ||
108 | printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m) | ||
109 | |||
110 | m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $ | ||
111 | printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m) | ||
112 | m ¿¿ ec = trans (trans m ?? ec) | ||
113 | |||
114 | ------------------------------------------------------------------- | 130 | ------------------------------------------------------------------- |
115 | 131 | ||
116 | 132 | ||