diff options
Diffstat (limited to 'packages/base/src/Data/Packed')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Matrix.hs | 53 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 44 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Numeric.hs | 11 |
3 files changed, 104 insertions, 4 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 1679ea6..82a9d8f 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs | |||
@@ -268,6 +268,9 @@ class (Storable a) => Element a where | |||
268 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a | 268 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a |
269 | sortI :: Ord a => Vector a -> Vector CInt | 269 | sortI :: Ord a => Vector a -> Vector CInt |
270 | sortV :: Ord a => Vector a -> Vector a | 270 | sortV :: Ord a => Vector a -> Vector a |
271 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | ||
272 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | ||
273 | |||
271 | 274 | ||
272 | instance Element Float where | 275 | instance Element Float where |
273 | transdata = transdataAux ctransF | 276 | transdata = transdataAux ctransF |
@@ -275,6 +278,9 @@ instance Element Float where | |||
275 | extractR = extractAux c_extractF | 278 | extractR = extractAux c_extractF |
276 | sortI = sortIdxF | 279 | sortI = sortIdxF |
277 | sortV = sortValF | 280 | sortV = sortValF |
281 | compareV = compareF | ||
282 | selectV = selectF | ||
283 | |||
278 | 284 | ||
279 | instance Element Double where | 285 | instance Element Double where |
280 | transdata = transdataAux ctransR | 286 | transdata = transdataAux ctransR |
@@ -282,6 +288,9 @@ instance Element Double where | |||
282 | extractR = extractAux c_extractD | 288 | extractR = extractAux c_extractD |
283 | sortI = sortIdxD | 289 | sortI = sortIdxD |
284 | sortV = sortValD | 290 | sortV = sortValD |
291 | compareV = compareD | ||
292 | selectV = selectD | ||
293 | |||
285 | 294 | ||
286 | instance Element (Complex Float) where | 295 | instance Element (Complex Float) where |
287 | transdata = transdataAux ctransQ | 296 | transdata = transdataAux ctransQ |
@@ -289,6 +298,9 @@ instance Element (Complex Float) where | |||
289 | extractR = extractAux c_extractQ | 298 | extractR = extractAux c_extractQ |
290 | sortI = undefined | 299 | sortI = undefined |
291 | sortV = undefined | 300 | sortV = undefined |
301 | compareV = undefined | ||
302 | selectV = selectQ | ||
303 | |||
292 | 304 | ||
293 | instance Element (Complex Double) where | 305 | instance Element (Complex Double) where |
294 | transdata = transdataAux ctransC | 306 | transdata = transdataAux ctransC |
@@ -296,6 +308,9 @@ instance Element (Complex Double) where | |||
296 | extractR = extractAux c_extractC | 308 | extractR = extractAux c_extractC |
297 | sortI = undefined | 309 | sortI = undefined |
298 | sortV = undefined | 310 | sortV = undefined |
311 | compareV = undefined | ||
312 | selectV = selectC | ||
313 | |||
299 | 314 | ||
300 | instance Element (CInt) where | 315 | instance Element (CInt) where |
301 | transdata = transdataAux ctransI | 316 | transdata = transdataAux ctransI |
@@ -303,6 +318,9 @@ instance Element (CInt) where | |||
303 | extractR = extractAux c_extractI | 318 | extractR = extractAux c_extractI |
304 | sortI = sortIdxI | 319 | sortI = sortIdxI |
305 | sortV = sortValI | 320 | sortV = sortValI |
321 | compareV = compareI | ||
322 | selectV = selectI | ||
323 | |||
306 | 324 | ||
307 | ------------------------------------------------------------------- | 325 | ------------------------------------------------------------------- |
308 | 326 | ||
@@ -502,3 +520,38 @@ foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO | |||
502 | 520 | ||
503 | -------------------------------------------------------------------------------- | 521 | -------------------------------------------------------------------------------- |
504 | 522 | ||
523 | compareG f u v = unsafePerformIO $ do | ||
524 | r <- createVector (dim v) | ||
525 | app3 f vec u vec v vec r "compareG" | ||
526 | return r | ||
527 | |||
528 | compareD = compareG c_compareD | ||
529 | compareF = compareG c_compareF | ||
530 | compareI = compareG c_compareI | ||
531 | |||
532 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | ||
533 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) | ||
534 | foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) | ||
535 | |||
536 | -------------------------------------------------------------------------------- | ||
537 | |||
538 | selectG f c u v w = unsafePerformIO $ do | ||
539 | r <- createVector (dim v) | ||
540 | app5 f vec c vec u vec v vec w vec r "selectG" | ||
541 | return r | ||
542 | |||
543 | selectD = selectG c_selectD | ||
544 | selectF = selectG c_selectF | ||
545 | selectI = selectG c_selectI | ||
546 | selectC = selectG c_selectC | ||
547 | selectQ = selectG c_selectQ | ||
548 | |||
549 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | ||
550 | |||
551 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double | ||
552 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float | ||
553 | foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt | ||
554 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) | ||
555 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) | ||
556 | |||
557 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 51bee5c..a241c48 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -36,7 +36,7 @@ module Data.Packed.Internal.Numeric ( | |||
36 | Convert(..), | 36 | Convert(..), |
37 | Complexable(), | 37 | Complexable(), |
38 | RealElement(), | 38 | RealElement(), |
39 | roundVector, fromInt, | 39 | roundVector, fromInt, toInt, |
40 | RealOf, ComplexOf, SingleOf, DoubleOf, | 40 | RealOf, ComplexOf, SingleOf, DoubleOf, |
41 | IndexOf, | 41 | IndexOf, |
42 | I, Extractor(..), (??), range, idxs, | 42 | I, Extractor(..), (??), range, idxs, |
@@ -171,6 +171,8 @@ class Element e => Container c e | |||
171 | -> c e -- ^ e | 171 | -> c e -- ^ e |
172 | -> c e -- ^ g | 172 | -> c e -- ^ g |
173 | -> c e -- ^ result | 173 | -> c e -- ^ result |
174 | ccompare' :: Ord e => c e -> c e -> c I | ||
175 | cselect' :: c I -> c e -> c e -> c e -> c e | ||
174 | find' :: (e -> Bool) -> c e -> [IndexOf c] | 176 | find' :: (e -> Bool) -> c e -> [IndexOf c] |
175 | assoc' :: IndexOf c -- ^ size | 177 | assoc' :: IndexOf c -- ^ size |
176 | -> e -- ^ default value | 178 | -> e -- ^ default value |
@@ -192,6 +194,7 @@ class Element e => Container c e | |||
192 | arctan2' :: Fractional e => c e -> c e -> c e | 194 | arctan2' :: Fractional e => c e -> c e -> c e |
193 | cmod' :: Integral e => e -> c e -> c e | 195 | cmod' :: Integral e => e -> c e -> c e |
194 | fromInt' :: c I -> c e | 196 | fromInt' :: c I -> c e |
197 | toInt' :: c e -> c I | ||
195 | 198 | ||
196 | 199 | ||
197 | -------------------------------------------------------------------------- | 200 | -------------------------------------------------------------------------- |
@@ -222,6 +225,8 @@ instance Container Vector I | |||
222 | assoc' = assocV | 225 | assoc' = assocV |
223 | accum' = accumV | 226 | accum' = accumV |
224 | cond' = condV condI | 227 | cond' = condV condI |
228 | ccompare' = compareCV compareV | ||
229 | cselect' = selectCV selectV | ||
225 | scaleRecip = undefined -- cannot match | 230 | scaleRecip = undefined -- cannot match |
226 | divide = undefined | 231 | divide = undefined |
227 | arctan2' = undefined | 232 | arctan2' = undefined |
@@ -229,6 +234,7 @@ instance Container Vector I | |||
229 | | m /= 0 = vectorMapValI ModVS m x | 234 | | m /= 0 = vectorMapValI ModVS m x |
230 | | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x) | 235 | | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x) |
231 | fromInt' = id | 236 | fromInt' = id |
237 | toInt' = id | ||
232 | 238 | ||
233 | instance Container Vector Float | 239 | instance Container Vector Float |
234 | where | 240 | where |
@@ -256,11 +262,14 @@ instance Container Vector Float | |||
256 | assoc' = assocV | 262 | assoc' = assocV |
257 | accum' = accumV | 263 | accum' = accumV |
258 | cond' = condV condF | 264 | cond' = condV condF |
265 | ccompare' = compareCV compareV | ||
266 | cselect' = selectCV selectV | ||
259 | scaleRecip = vectorMapValF Recip | 267 | scaleRecip = vectorMapValF Recip |
260 | divide = vectorZipF Div | 268 | divide = vectorZipF Div |
261 | arctan2' = vectorZipF ATan2 | 269 | arctan2' = vectorZipF ATan2 |
262 | cmod' = undefined | 270 | cmod' = undefined |
263 | fromInt' = int2floatV | 271 | fromInt' = int2floatV |
272 | toInt' = float2IntV | ||
264 | 273 | ||
265 | 274 | ||
266 | 275 | ||
@@ -290,11 +299,14 @@ instance Container Vector Double | |||
290 | assoc' = assocV | 299 | assoc' = assocV |
291 | accum' = accumV | 300 | accum' = accumV |
292 | cond' = condV condD | 301 | cond' = condV condD |
302 | ccompare' = compareCV compareV | ||
303 | cselect' = selectCV selectV | ||
293 | scaleRecip = vectorMapValR Recip | 304 | scaleRecip = vectorMapValR Recip |
294 | divide = vectorZipR Div | 305 | divide = vectorZipR Div |
295 | arctan2' = vectorZipR ATan2 | 306 | arctan2' = vectorZipR ATan2 |
296 | cmod' = undefined | 307 | cmod' = undefined |
297 | fromInt' = int2DoubleV | 308 | fromInt' = int2DoubleV |
309 | toInt' = double2IntV | ||
298 | 310 | ||
299 | 311 | ||
300 | instance Container Vector (Complex Double) | 312 | instance Container Vector (Complex Double) |
@@ -323,11 +335,14 @@ instance Container Vector (Complex Double) | |||
323 | assoc' = assocV | 335 | assoc' = assocV |
324 | accum' = accumV | 336 | accum' = accumV |
325 | cond' = undefined -- cannot match | 337 | cond' = undefined -- cannot match |
338 | ccompare' = undefined | ||
339 | cselect' = selectCV selectV | ||
326 | scaleRecip = vectorMapValC Recip | 340 | scaleRecip = vectorMapValC Recip |
327 | divide = vectorZipC Div | 341 | divide = vectorZipC Div |
328 | arctan2' = vectorZipC ATan2 | 342 | arctan2' = vectorZipC ATan2 |
329 | cmod' = undefined | 343 | cmod' = undefined |
330 | fromInt' = complex . int2DoubleV | 344 | fromInt' = complex . int2DoubleV |
345 | toInt' = toInt' . fst . fromComplex | ||
331 | 346 | ||
332 | instance Container Vector (Complex Float) | 347 | instance Container Vector (Complex Float) |
333 | where | 348 | where |
@@ -355,11 +370,14 @@ instance Container Vector (Complex Float) | |||
355 | assoc' = assocV | 370 | assoc' = assocV |
356 | accum' = accumV | 371 | accum' = accumV |
357 | cond' = undefined -- cannot match | 372 | cond' = undefined -- cannot match |
373 | ccompare' = undefined | ||
374 | cselect' = selectCV selectV | ||
358 | scaleRecip = vectorMapValQ Recip | 375 | scaleRecip = vectorMapValQ Recip |
359 | divide = vectorZipQ Div | 376 | divide = vectorZipQ Div |
360 | arctan2' = vectorZipQ ATan2 | 377 | arctan2' = vectorZipQ ATan2 |
361 | cmod' = undefined | 378 | cmod' = undefined |
362 | fromInt' = complex . int2floatV | 379 | fromInt' = complex . int2floatV |
380 | toInt' = toInt' . fst . fromComplex | ||
363 | 381 | ||
364 | --------------------------------------------------------------- | 382 | --------------------------------------------------------------- |
365 | 383 | ||
@@ -391,6 +409,8 @@ instance (Num a, Element a, Container Vector a) => Container Matrix a | |||
391 | assoc' = assocM | 409 | assoc' = assocM |
392 | accum' = accumM | 410 | accum' = accumM |
393 | cond' = condM | 411 | cond' = condM |
412 | ccompare' = compareM | ||
413 | cselect' = selectM | ||
394 | scaleRecip x = liftMatrix (scaleRecip x) | 414 | scaleRecip x = liftMatrix (scaleRecip x) |
395 | divide = liftMatrix2 divide | 415 | divide = liftMatrix2 divide |
396 | arctan2' = liftMatrix2 arctan2' | 416 | arctan2' = liftMatrix2 arctan2' |
@@ -398,6 +418,7 @@ instance (Num a, Element a, Container Vector a) => Container Matrix a | |||
398 | | m /= 0 = liftMatrix (cmod' m) x | 418 | | m /= 0 = liftMatrix (cmod' m) x |
399 | | otherwise = error $ "cmod 0 on matrix "++shSize x | 419 | | otherwise = error $ "cmod 0 on matrix "++shSize x |
400 | fromInt' = liftMatrix fromInt' | 420 | fromInt' = liftMatrix fromInt' |
421 | toInt' = liftMatrix toInt' | ||
401 | 422 | ||
402 | 423 | ||
403 | emptyErrorV msg f v = | 424 | emptyErrorV msg f v = |
@@ -448,6 +469,9 @@ cmod m = cmod' (fromIntegral m) | |||
448 | fromInt :: (Container c e) => c I -> c e | 469 | fromInt :: (Container c e) => c I -> c e |
449 | fromInt = fromInt' | 470 | fromInt = fromInt' |
450 | 471 | ||
472 | toInt :: (Container c e) => c e -> c I | ||
473 | toInt = toInt' | ||
474 | |||
451 | 475 | ||
452 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) | 476 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) |
453 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b | 477 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b |
@@ -852,6 +876,24 @@ condV f a b l e t = f a' b' l' e' t' | |||
852 | where | 876 | where |
853 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] | 877 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] |
854 | 878 | ||
879 | compareM a b = matrixFromVector RowMajor (rows a'') (cols a'') $ ccompare' a' b' | ||
880 | where | ||
881 | args@(a'':_) = conformMs [a,b] | ||
882 | [a', b'] = map flatten args | ||
883 | |||
884 | compareCV f a b = f a' b' | ||
885 | where | ||
886 | [a', b'] = conformVs [a,b] | ||
887 | |||
888 | selectM c l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cselect' (toInt c') l' e' t' | ||
889 | where | ||
890 | args@(a'':_) = conformMs [fromInt c,l,e,t] | ||
891 | [c', l', e', t'] = map flatten args | ||
892 | |||
893 | selectCV f c l e t = f (toInt c') l' e' t' | ||
894 | where | ||
895 | [c', l', e', t'] = conformVs [fromInt c,l,e,t] | ||
896 | |||
855 | -------------------------------------------------------------------------------- | 897 | -------------------------------------------------------------------------------- |
856 | 898 | ||
857 | class Transposable m mt | m -> mt, mt -> m | 899 | class Transposable m mt | m -> mt, mt -> m |
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs index cb449a9..906bc83 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Data/Packed/Numeric.hs | |||
@@ -31,12 +31,12 @@ module Data.Packed.Numeric ( | |||
31 | diag, ident, | 31 | diag, ident, |
32 | ctrans, | 32 | ctrans, |
33 | -- * Generic operations | 33 | -- * Generic operations |
34 | Container(..), Numeric, | 34 | Container(..), Numeric, Extractor(..), (??), range, idxs, I, |
35 | -- add, mul, sub, divide, equal, scaleRecip, addConstant, | 35 | -- add, mul, sub, divide, equal, scaleRecip, addConstant, |
36 | scalar, conj, scale, arctan2, cmap, cmod, | 36 | scalar, conj, scale, arctan2, cmap, cmod, |
37 | atIndex, minIndex, maxIndex, minElement, maxElement, | 37 | atIndex, minIndex, maxIndex, minElement, maxElement, |
38 | sumElements, prodElements, | 38 | sumElements, prodElements, |
39 | step, cond, find, assoc, accum, | 39 | step, cond, find, assoc, accum, ccompare, cselect, |
40 | Transposable(..), Linear(..), | 40 | Transposable(..), Linear(..), |
41 | -- * Matrix product | 41 | -- * Matrix product |
42 | Product(..), udot, dot, (<·>), (#>), (<#), app, | 42 | Product(..), udot, dot, (<·>), (#>), (<#), app, |
@@ -58,7 +58,7 @@ module Data.Packed.Numeric ( | |||
58 | Complexable(), | 58 | Complexable(), |
59 | RealElement(), | 59 | RealElement(), |
60 | RealOf, ComplexOf, SingleOf, DoubleOf, | 60 | RealOf, ComplexOf, SingleOf, DoubleOf, |
61 | roundVector, | 61 | roundVector,fromInt,toInt, |
62 | IndexOf, | 62 | IndexOf, |
63 | module Data.Complex, | 63 | module Data.Complex, |
64 | -- * IO | 64 | -- * IO |
@@ -309,4 +309,9 @@ sortVector = sortV | |||
309 | sortIndex :: (Ord t, Element t) => Vector t -> Vector I | 309 | sortIndex :: (Ord t, Element t) => Vector t -> Vector I |
310 | sortIndex = sortI | 310 | sortIndex = sortI |
311 | 311 | ||
312 | ccompare :: (Ord t, Container c t) => c t -> c t -> c I | ||
313 | ccompare = ccompare' | ||
314 | |||
315 | cselect :: (Container c t) => c I -> c t -> c t -> c t -> c t | ||
316 | cselect = cselect' | ||
312 | 317 | ||