diff options
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 20 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 53 |
2 files changed, 61 insertions, 12 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index be5fb03..1aee7d3 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -265,7 +265,7 @@ class (Storable a) => Element a where | |||
265 | transdata = transdataP -- transdata' | 265 | transdata = transdataP -- transdata' |
266 | constantD :: a -> Int -> Vector a | 266 | constantD :: a -> Int -> Vector a |
267 | constantD = constantP -- constant' | 267 | constantD = constantP -- constant' |
268 | extractR :: Matrix a -> Idxs -> Matrix a | 268 | extractR :: Matrix a -> CInt -> Idxs -> Matrix a |
269 | 269 | ||
270 | instance Element Float where | 270 | instance Element Float where |
271 | transdata = transdataAux ctransF | 271 | transdata = transdataAux ctransF |
@@ -444,23 +444,25 @@ tt x@Matrix{order = ColumnMajor} = trans x | |||
444 | tt x@Matrix{order = RowMajor} = x | 444 | tt x@Matrix{order = RowMajor} = x |
445 | 445 | ||
446 | --extractAux :: Matrix Double -> Idxs -> Matrix Double | 446 | --extractAux :: Matrix Double -> Idxs -> Matrix Double |
447 | extractAux f m v = unsafePerformIO $ do | 447 | extractAux f m mode v = unsafePerformIO $ do |
448 | r <- createMatrix RowMajor (dim v) (cols m) | 448 | let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) |
449 | app3 (f (isT m)) vec v mat (tt m) mat r "extractAux" | 449 | | otherwise = dim v |
450 | r <- createMatrix RowMajor nr (cols m) | ||
451 | app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" | ||
450 | return r | 452 | return r |
451 | 453 | ||
452 | foreign import ccall unsafe "extractRD" c_extractRD | 454 | foreign import ccall unsafe "extractRD" c_extractRD |
453 | :: CInt -> CIdxs (CM Double (CM Double (IO CInt))) | 455 | :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) |
454 | 456 | ||
455 | foreign import ccall unsafe "extractRF" c_extractRF | 457 | foreign import ccall unsafe "extractRF" c_extractRF |
456 | :: CInt -> CIdxs (CM Float (CM Float (IO CInt))) | 458 | :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) |
457 | 459 | ||
458 | foreign import ccall unsafe "extractRC" c_extractRC | 460 | foreign import ccall unsafe "extractRC" c_extractRC |
459 | :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) | 461 | :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) |
460 | 462 | ||
461 | foreign import ccall unsafe "extractRQ" c_extractRQ | 463 | foreign import ccall unsafe "extractRQ" c_extractRQ |
462 | :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) | 464 | :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) |
463 | 465 | ||
464 | foreign import ccall unsafe "extractRI" c_extractRI | 466 | foreign import ccall unsafe "extractRI" c_extractRI |
465 | :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) | 467 | :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) |
466 | 468 | ||
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 9b6b55b..f1b4898 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric ( | |||
39 | roundVector, | 39 | roundVector, |
40 | RealOf, ComplexOf, SingleOf, DoubleOf, | 40 | RealOf, ComplexOf, SingleOf, DoubleOf, |
41 | IndexOf, | 41 | IndexOf, |
42 | CInt, | 42 | CInt, Extractor(..), (??),(¿¿), |
43 | module Data.Complex | 43 | module Data.Complex |
44 | ) where | 44 | ) where |
45 | 45 | ||
@@ -53,6 +53,7 @@ import Data.Complex | |||
53 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) | 53 | import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) |
54 | import Data.Packed.Internal | 54 | import Data.Packed.Internal |
55 | import Foreign.C.Types(CInt) | 55 | import Foreign.C.Types(CInt) |
56 | import Text.Printf(printf) | ||
56 | 57 | ||
57 | ------------------------------------------------------------------- | 58 | ------------------------------------------------------------------- |
58 | 59 | ||
@@ -66,8 +67,53 @@ type family ArgOf (c :: * -> *) a | |||
66 | type instance ArgOf Vector a = a -> a | 67 | type instance ArgOf Vector a = a -> a |
67 | type instance ArgOf Matrix a = a -> a -> a | 68 | type instance ArgOf Matrix a = a -> a -> a |
68 | 69 | ||
70 | -------------------------------------------------------------------------- | ||
71 | |||
72 | data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int | ||
73 | |||
74 | idxs js = fromList (map fromIntegral js) :: Idxs | ||
75 | |||
76 | infixl 9 ??, ¿¿ | ||
77 | (??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t | ||
78 | |||
79 | m ?? All = m | ||
80 | m ?? Take 0 = (0><cols m) [] | ||
81 | m ?? Take n | abs n >= rows m = m | ||
82 | m ?? Drop 0 = m | ||
83 | m ?? Drop n | abs n >= rows m = (0><cols m) [] | ||
84 | m ?? Range a b | a > b = m ?? Take 0 | ||
85 | m ?? Range a b | a < 0 || b >= cols m = error $ | ||
86 | printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) | ||
87 | m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ | ||
88 | printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) | ||
89 | |||
90 | m ?? er = extractR m mode js | ||
91 | where | ||
92 | (mode,js) = mkExt (rows m) er | ||
93 | ran a b = (0, idxs [a,b]) | ||
94 | pos ks = (1, idxs ks) | ||
95 | mkExt _ (At ks) = pos ks | ||
96 | mkExt n (AtCyc ks) = pos (map (`mod` n) ks) | ||
97 | mkExt n All = ran 0 (n-1) | ||
98 | mkExt _ (Range mn mx) = ran mn mx | ||
99 | mkExt n (Take k) | ||
100 | | k >= 0 = ran 0 (k-1) | ||
101 | | otherwise = mkExt n (Drop (n+k)) | ||
102 | mkExt n (Drop k) | ||
103 | | k >= 0 = ran k (n-1) | ||
104 | | otherwise = mkExt n (Take (n+k)) | ||
105 | |||
106 | |||
107 | m ¿¿ Range a b | a < 0 || b > cols m -1 = error $ | ||
108 | printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m) | ||
109 | |||
110 | m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $ | ||
111 | printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m) | ||
112 | m ¿¿ ec = trans (trans m ?? ec) | ||
113 | |||
69 | ------------------------------------------------------------------- | 114 | ------------------------------------------------------------------- |
70 | 115 | ||
116 | |||
71 | -- | Basic element-by-element functions for numeric containers | 117 | -- | Basic element-by-element functions for numeric containers |
72 | class Element e => SContainer c e | 118 | class Element e => SContainer c e |
73 | where | 119 | where |
@@ -123,6 +169,7 @@ class (Complexable c, Fractional e, SContainer c e) => Container c e | |||
123 | -- element by element inverse tangent | 169 | -- element by element inverse tangent |
124 | arctan2' :: c e -> c e -> c e | 170 | arctan2' :: c e -> c e -> c e |
125 | 171 | ||
172 | |||
126 | -------------------------------------------------------------------------- | 173 | -------------------------------------------------------------------------- |
127 | 174 | ||
128 | instance SContainer Vector CInt | 175 | instance SContainer Vector CInt |
@@ -245,14 +292,14 @@ instance SContainer Vector (Complex Double) | |||
245 | accum' = accumV | 292 | accum' = accumV |
246 | cond' = undefined -- cannot match | 293 | cond' = undefined -- cannot match |
247 | 294 | ||
248 | 295 | ||
249 | instance Container Vector (Complex Double) | 296 | instance Container Vector (Complex Double) |
250 | where | 297 | where |
251 | scaleRecip = vectorMapValC Recip | 298 | scaleRecip = vectorMapValC Recip |
252 | divide = vectorZipC Div | 299 | divide = vectorZipC Div |
253 | arctan2' = vectorZipC ATan2 | 300 | arctan2' = vectorZipC ATan2 |
254 | conj' = conjugateC | 301 | conj' = conjugateC |
255 | 302 | ||
256 | 303 | ||
257 | instance SContainer Vector (Complex Float) | 304 | instance SContainer Vector (Complex Float) |
258 | where | 305 | where |