summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-05-24 12:45:23 +0200
committerAlberto Ruiz <aruiz@um.es>2015-05-24 12:45:23 +0200
commit8ede2ed162f3d00172ee3fa4835e3ee2184bcd99 (patch)
treecdb3025dc5e469603d32d4e200cc753d3502c6d8 /packages/base/src/Data/Packed
parentb1b445697db31b1603a31747ca31151f97ee7263 (diff)
joint extractor of rows and columns
Diffstat (limited to 'packages/base/src/Data/Packed')
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs44
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs66
2 files changed, 63 insertions, 47 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
index 1aee7d3..76d2204 100644
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -265,32 +265,32 @@ class (Storable a) => Element a where
265 transdata = transdataP -- transdata' 265 transdata = transdataP -- transdata'
266 constantD :: a -> Int -> Vector a 266 constantD :: a -> Int -> Vector a
267 constantD = constantP -- constant' 267 constantD = constantP -- constant'
268 extractR :: Matrix a -> CInt -> Idxs -> Matrix a 268 extractR :: Matrix a -> CInt -> Idxs -> CInt -> Idxs -> Matrix a
269 269
270instance Element Float where 270instance Element Float where
271 transdata = transdataAux ctransF 271 transdata = transdataAux ctransF
272 constantD = constantAux cconstantF 272 constantD = constantAux cconstantF
273 extractR = extractAux c_extractRF 273 extractR = extractAux c_extractF
274 274
275instance Element Double where 275instance Element Double where
276 transdata = transdataAux ctransR 276 transdata = transdataAux ctransR
277 constantD = constantAux cconstantR 277 constantD = constantAux cconstantR
278 extractR = extractAux c_extractRD 278 extractR = extractAux c_extractD
279 279
280instance Element (Complex Float) where 280instance Element (Complex Float) where
281 transdata = transdataAux ctransQ 281 transdata = transdataAux ctransQ
282 constantD = constantAux cconstantQ 282 constantD = constantAux cconstantQ
283 extractR = extractAux c_extractRQ 283 extractR = extractAux c_extractQ
284 284
285instance Element (Complex Double) where 285instance Element (Complex Double) where
286 transdata = transdataAux ctransC 286 transdata = transdataAux ctransC
287 constantD = constantAux cconstantC 287 constantD = constantAux cconstantC
288 extractR = extractAux c_extractRC 288 extractR = extractAux c_extractC
289 289
290instance Element (CInt) where 290instance Element (CInt) where
291 transdata = transdataAux ctransI 291 transdata = transdataAux ctransI
292 constantD = constantAux cconstantI 292 constantD = constantAux cconstantI
293 extractR = extractAux c_extractRI 293 extractR = extractAux c_extractI
294 294
295 295
296------------------------------------------------------------------- 296-------------------------------------------------------------------
@@ -443,26 +443,26 @@ isT Matrix{order = RowMajor} = 0
443tt x@Matrix{order = ColumnMajor} = trans x 443tt x@Matrix{order = ColumnMajor} = trans x
444tt x@Matrix{order = RowMajor} = x 444tt x@Matrix{order = RowMajor} = x
445 445
446--extractAux :: Matrix Double -> Idxs -> Matrix Double 446
447extractAux f m mode v = unsafePerformIO $ do 447extractAux f m moder vr modec vc = unsafePerformIO $ do
448 let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) 448 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
449 | otherwise = dim v 449 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
450 r <- createMatrix RowMajor nr (cols m) 450 r <- createMatrix RowMajor nr nc
451 app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" 451 app4 (f moder modec (isT m)) vec vr vec vc mat (tt m) mat r "extractAux"
452 return r 452 return r
453 453
454foreign import ccall unsafe "extractRD" c_extractRD 454foreign import ccall unsafe "extractD" c_extractD
455 :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) 455 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Double (CM Double (IO CInt))))
456 456
457foreign import ccall unsafe "extractRF" c_extractRF 457foreign import ccall unsafe "extractF" c_extractF
458 :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) 458 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Float (CM Float (IO CInt))))
459 459
460foreign import ccall unsafe "extractRC" c_extractRC 460foreign import ccall unsafe "extractC" c_extractC
461 :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) 461 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))))
462 462
463foreign import ccall unsafe "extractRQ" c_extractRQ 463foreign import ccall unsafe "extractQ" c_extractQ
464 :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) 464 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))))
465 465
466foreign import ccall unsafe "extractRI" c_extractRI 466foreign import ccall unsafe "extractI" c_extractI
467 :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) 467 :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM CInt (CM CInt (IO CInt))))
468 468
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
index 00ec70c..353877a 100644
--- a/packages/base/src/Data/Packed/Internal/Numeric.hs
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric (
39 roundVector, 39 roundVector,
40 RealOf, ComplexOf, SingleOf, DoubleOf, 40 RealOf, ComplexOf, SingleOf, DoubleOf,
41 IndexOf, 41 IndexOf,
42 CInt, Extractor(..), (??),(¿¿), 42 CInt, Extractor(..), (??),
43 module Data.Complex 43 module Data.Complex
44) where 44) where
45 45
@@ -69,27 +69,50 @@ type instance ArgOf Matrix a = a -> a -> a
69 69
70-------------------------------------------------------------------------- 70--------------------------------------------------------------------------
71 71
72data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int 72data Extractor
73 = All
74 | Range Int Int
75 | At [Int]
76 | AtCyc [Int]
77 | Take Int
78 | Drop Int
79 deriving Show
73 80
74idxs js = fromList (map fromIntegral js) :: Idxs 81idxs js = fromList (map fromIntegral js) :: Idxs
75 82
76infixl 9 ??, ¿¿ 83infixl 9 ??
77(??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t 84(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t
78 85
79m ?? All = m 86
80m ?? Take 0 = (0><cols m) [] 87extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m)
81m ?? Take n | abs n >= rows m = m 88
82m ?? Drop 0 = m 89m ?? e@(Range a b,_) | a < 0 || b >= rows m = extractError m e
83m ?? Drop n | abs n >= rows m = (0><cols m) [] 90m ?? e@(_,Range a b) | a < 0 || b >= cols m = extractError m e
84m ?? Range a b | a > b = m ?? Take 0 91m ?? e@(At ps,_) | minimum ps < 0 || maximum ps >= rows m = extractError m e
85m ?? Range a b | a < 0 || b >= cols m = error $ 92m ?? e@(_,At ps) | minimum ps < 0 || maximum ps >= cols m = extractError m e
86 printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) 93
87m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ 94m ?? (All,All) = m
88 printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) 95
89 96m ?? (Range a b,e) | a > b = m ?? (Take 0,e)
90m ?? er = extractR m mode js 97m ?? (e,Range a b) | a > b = m ?? (e,Take 0)
98
99m ?? (Take 0,e) = (0><cols m) [] ?? (All,e)
100m ?? (e,Take 0) = (rows m><0) [] ?? (e,All)
101
102m ?? (Take n,e) | abs n > rows m = m ?? (All,e)
103m ?? (e,Take n) | abs n > cols m = m ?? (e,All)
104
105m ?? (Drop 0,e) = m ?? (All,e)
106m ?? (e,Drop 0) = m ?? (e,All)
107
108m ?? (Drop n,e) | abs n > rows m = m ?? (Take 0,e)
109m ?? (e,Drop n) | abs n > cols m = m ?? (e,Take 0)
110
111
112m ?? (er,ec) = extractR m moder rs modec cs
91 where 113 where
92 (mode,js) = mkExt (rows m) er 114 (moder,rs) = mkExt (rows m) er
115 (modec,cs) = mkExt (cols m) ec
93 ran a b = (0, idxs [a,b]) 116 ran a b = (0, idxs [a,b])
94 pos ks = (1, idxs ks) 117 pos ks = (1, idxs ks)
95 mkExt _ (At ks) = pos ks 118 mkExt _ (At ks) = pos ks
@@ -104,13 +127,6 @@ m ?? er = extractR m mode js
104 | otherwise = mkExt n (Take (n+k)) 127 | otherwise = mkExt n (Take (n+k))
105 128
106 129
107m ¿¿ Range a b | a < 0 || b > cols m -1 = error $
108 printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m)
109
110m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $
111 printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m)
112m ¿¿ ec = trans (trans m ?? ec)
113
114------------------------------------------------------------------- 130-------------------------------------------------------------------
115 131
116 132