summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-05-27 09:10:22 +0200
committerAlberto Ruiz <aruiz@um.es>2015-05-27 09:10:22 +0200
commitc5795a191ded450987a30302c1d1fa4a265350ff (patch)
treee3f9f754de966189dab7e4cbdddf96a4750cace8
parentf3a044a6219bd098fe5d55ef427b9ae6fe360cb9 (diff)
ccompare, cselect, toInt
-rw-r--r--packages/base/src/C/lapack-aux.c7
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs53
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs44
-rw-r--r--packages/base/src/Data/Packed/Numeric.hs11
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs7
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util.hs1
6 files changed, 115 insertions, 8 deletions
diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c
index af515ca..77381cc 100644
--- a/packages/base/src/C/lapack-aux.c
+++ b/packages/base/src/C/lapack-aux.c
@@ -1619,6 +1619,13 @@ int chooseI(KIVEC(cond),KIVEC(lt),KIVEC(eq),KIVEC(gt),IVEC(r)) {
1619 CHOOSE_IMP 1619 CHOOSE_IMP
1620} 1620}
1621 1621
1622int chooseC(KIVEC(cond),KCVEC(lt),KCVEC(eq),KCVEC(gt),CVEC(r)) {
1623 CHOOSE_IMP
1624}
1625
1626int chooseQ(KIVEC(cond),KQVEC(lt),KQVEC(eq),KQVEC(gt),QVEC(r)) {
1627 CHOOSE_IMP
1628}
1622 1629
1623//////////////////////// extract ///////////////////////////////// 1630//////////////////////// extract /////////////////////////////////
1624 1631
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
index 1679ea6..82a9d8f 100644
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -268,6 +268,9 @@ class (Storable a) => Element a where
268 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a 268 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a
269 sortI :: Ord a => Vector a -> Vector CInt 269 sortI :: Ord a => Vector a -> Vector CInt
270 sortV :: Ord a => Vector a -> Vector a 270 sortV :: Ord a => Vector a -> Vector a
271 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
272 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
273
271 274
272instance Element Float where 275instance Element Float where
273 transdata = transdataAux ctransF 276 transdata = transdataAux ctransF
@@ -275,6 +278,9 @@ instance Element Float where
275 extractR = extractAux c_extractF 278 extractR = extractAux c_extractF
276 sortI = sortIdxF 279 sortI = sortIdxF
277 sortV = sortValF 280 sortV = sortValF
281 compareV = compareF
282 selectV = selectF
283
278 284
279instance Element Double where 285instance Element Double where
280 transdata = transdataAux ctransR 286 transdata = transdataAux ctransR
@@ -282,6 +288,9 @@ instance Element Double where
282 extractR = extractAux c_extractD 288 extractR = extractAux c_extractD
283 sortI = sortIdxD 289 sortI = sortIdxD
284 sortV = sortValD 290 sortV = sortValD
291 compareV = compareD
292 selectV = selectD
293
285 294
286instance Element (Complex Float) where 295instance Element (Complex Float) where
287 transdata = transdataAux ctransQ 296 transdata = transdataAux ctransQ
@@ -289,6 +298,9 @@ instance Element (Complex Float) where
289 extractR = extractAux c_extractQ 298 extractR = extractAux c_extractQ
290 sortI = undefined 299 sortI = undefined
291 sortV = undefined 300 sortV = undefined
301 compareV = undefined
302 selectV = selectQ
303
292 304
293instance Element (Complex Double) where 305instance Element (Complex Double) where
294 transdata = transdataAux ctransC 306 transdata = transdataAux ctransC
@@ -296,6 +308,9 @@ instance Element (Complex Double) where
296 extractR = extractAux c_extractC 308 extractR = extractAux c_extractC
297 sortI = undefined 309 sortI = undefined
298 sortV = undefined 310 sortV = undefined
311 compareV = undefined
312 selectV = selectC
313
299 314
300instance Element (CInt) where 315instance Element (CInt) where
301 transdata = transdataAux ctransI 316 transdata = transdataAux ctransI
@@ -303,6 +318,9 @@ instance Element (CInt) where
303 extractR = extractAux c_extractI 318 extractR = extractAux c_extractI
304 sortI = sortIdxI 319 sortI = sortIdxI
305 sortV = sortValI 320 sortV = sortValI
321 compareV = compareI
322 selectV = selectI
323
306 324
307------------------------------------------------------------------- 325-------------------------------------------------------------------
308 326
@@ -502,3 +520,38 @@ foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO
502 520
503-------------------------------------------------------------------------------- 521--------------------------------------------------------------------------------
504 522
523compareG f u v = unsafePerformIO $ do
524 r <- createVector (dim v)
525 app3 f vec u vec v vec r "compareG"
526 return r
527
528compareD = compareG c_compareD
529compareF = compareG c_compareF
530compareI = compareG c_compareI
531
532foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
533foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
534foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
535
536--------------------------------------------------------------------------------
537
538selectG f c u v w = unsafePerformIO $ do
539 r <- createVector (dim v)
540 app5 f vec c vec u vec v vec w vec r "selectG"
541 return r
542
543selectD = selectG c_selectD
544selectF = selectG c_selectF
545selectI = selectG c_selectI
546selectC = selectG c_selectC
547selectQ = selectG c_selectQ
548
549type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
550
551foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
552foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
553foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
554foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
555foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
556
557
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
index 51bee5c..a241c48 100644
--- a/packages/base/src/Data/Packed/Internal/Numeric.hs
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -36,7 +36,7 @@ module Data.Packed.Internal.Numeric (
36 Convert(..), 36 Convert(..),
37 Complexable(), 37 Complexable(),
38 RealElement(), 38 RealElement(),
39 roundVector, fromInt, 39 roundVector, fromInt, toInt,
40 RealOf, ComplexOf, SingleOf, DoubleOf, 40 RealOf, ComplexOf, SingleOf, DoubleOf,
41 IndexOf, 41 IndexOf,
42 I, Extractor(..), (??), range, idxs, 42 I, Extractor(..), (??), range, idxs,
@@ -171,6 +171,8 @@ class Element e => Container c e
171 -> c e -- ^ e 171 -> c e -- ^ e
172 -> c e -- ^ g 172 -> c e -- ^ g
173 -> c e -- ^ result 173 -> c e -- ^ result
174 ccompare' :: Ord e => c e -> c e -> c I
175 cselect' :: c I -> c e -> c e -> c e -> c e
174 find' :: (e -> Bool) -> c e -> [IndexOf c] 176 find' :: (e -> Bool) -> c e -> [IndexOf c]
175 assoc' :: IndexOf c -- ^ size 177 assoc' :: IndexOf c -- ^ size
176 -> e -- ^ default value 178 -> e -- ^ default value
@@ -192,6 +194,7 @@ class Element e => Container c e
192 arctan2' :: Fractional e => c e -> c e -> c e 194 arctan2' :: Fractional e => c e -> c e -> c e
193 cmod' :: Integral e => e -> c e -> c e 195 cmod' :: Integral e => e -> c e -> c e
194 fromInt' :: c I -> c e 196 fromInt' :: c I -> c e
197 toInt' :: c e -> c I
195 198
196 199
197-------------------------------------------------------------------------- 200--------------------------------------------------------------------------
@@ -222,6 +225,8 @@ instance Container Vector I
222 assoc' = assocV 225 assoc' = assocV
223 accum' = accumV 226 accum' = accumV
224 cond' = condV condI 227 cond' = condV condI
228 ccompare' = compareCV compareV
229 cselect' = selectCV selectV
225 scaleRecip = undefined -- cannot match 230 scaleRecip = undefined -- cannot match
226 divide = undefined 231 divide = undefined
227 arctan2' = undefined 232 arctan2' = undefined
@@ -229,6 +234,7 @@ instance Container Vector I
229 | m /= 0 = vectorMapValI ModVS m x 234 | m /= 0 = vectorMapValI ModVS m x
230 | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x) 235 | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x)
231 fromInt' = id 236 fromInt' = id
237 toInt' = id
232 238
233instance Container Vector Float 239instance Container Vector Float
234 where 240 where
@@ -256,11 +262,14 @@ instance Container Vector Float
256 assoc' = assocV 262 assoc' = assocV
257 accum' = accumV 263 accum' = accumV
258 cond' = condV condF 264 cond' = condV condF
265 ccompare' = compareCV compareV
266 cselect' = selectCV selectV
259 scaleRecip = vectorMapValF Recip 267 scaleRecip = vectorMapValF Recip
260 divide = vectorZipF Div 268 divide = vectorZipF Div
261 arctan2' = vectorZipF ATan2 269 arctan2' = vectorZipF ATan2
262 cmod' = undefined 270 cmod' = undefined
263 fromInt' = int2floatV 271 fromInt' = int2floatV
272 toInt' = float2IntV
264 273
265 274
266 275
@@ -290,11 +299,14 @@ instance Container Vector Double
290 assoc' = assocV 299 assoc' = assocV
291 accum' = accumV 300 accum' = accumV
292 cond' = condV condD 301 cond' = condV condD
302 ccompare' = compareCV compareV
303 cselect' = selectCV selectV
293 scaleRecip = vectorMapValR Recip 304 scaleRecip = vectorMapValR Recip
294 divide = vectorZipR Div 305 divide = vectorZipR Div
295 arctan2' = vectorZipR ATan2 306 arctan2' = vectorZipR ATan2
296 cmod' = undefined 307 cmod' = undefined
297 fromInt' = int2DoubleV 308 fromInt' = int2DoubleV
309 toInt' = double2IntV
298 310
299 311
300instance Container Vector (Complex Double) 312instance Container Vector (Complex Double)
@@ -323,11 +335,14 @@ instance Container Vector (Complex Double)
323 assoc' = assocV 335 assoc' = assocV
324 accum' = accumV 336 accum' = accumV
325 cond' = undefined -- cannot match 337 cond' = undefined -- cannot match
338 ccompare' = undefined
339 cselect' = selectCV selectV
326 scaleRecip = vectorMapValC Recip 340 scaleRecip = vectorMapValC Recip
327 divide = vectorZipC Div 341 divide = vectorZipC Div
328 arctan2' = vectorZipC ATan2 342 arctan2' = vectorZipC ATan2
329 cmod' = undefined 343 cmod' = undefined
330 fromInt' = complex . int2DoubleV 344 fromInt' = complex . int2DoubleV
345 toInt' = toInt' . fst . fromComplex
331 346
332instance Container Vector (Complex Float) 347instance Container Vector (Complex Float)
333 where 348 where
@@ -355,11 +370,14 @@ instance Container Vector (Complex Float)
355 assoc' = assocV 370 assoc' = assocV
356 accum' = accumV 371 accum' = accumV
357 cond' = undefined -- cannot match 372 cond' = undefined -- cannot match
373 ccompare' = undefined
374 cselect' = selectCV selectV
358 scaleRecip = vectorMapValQ Recip 375 scaleRecip = vectorMapValQ Recip
359 divide = vectorZipQ Div 376 divide = vectorZipQ Div
360 arctan2' = vectorZipQ ATan2 377 arctan2' = vectorZipQ ATan2
361 cmod' = undefined 378 cmod' = undefined
362 fromInt' = complex . int2floatV 379 fromInt' = complex . int2floatV
380 toInt' = toInt' . fst . fromComplex
363 381
364--------------------------------------------------------------- 382---------------------------------------------------------------
365 383
@@ -391,6 +409,8 @@ instance (Num a, Element a, Container Vector a) => Container Matrix a
391 assoc' = assocM 409 assoc' = assocM
392 accum' = accumM 410 accum' = accumM
393 cond' = condM 411 cond' = condM
412 ccompare' = compareM
413 cselect' = selectM
394 scaleRecip x = liftMatrix (scaleRecip x) 414 scaleRecip x = liftMatrix (scaleRecip x)
395 divide = liftMatrix2 divide 415 divide = liftMatrix2 divide
396 arctan2' = liftMatrix2 arctan2' 416 arctan2' = liftMatrix2 arctan2'
@@ -398,6 +418,7 @@ instance (Num a, Element a, Container Vector a) => Container Matrix a
398 | m /= 0 = liftMatrix (cmod' m) x 418 | m /= 0 = liftMatrix (cmod' m) x
399 | otherwise = error $ "cmod 0 on matrix "++shSize x 419 | otherwise = error $ "cmod 0 on matrix "++shSize x
400 fromInt' = liftMatrix fromInt' 420 fromInt' = liftMatrix fromInt'
421 toInt' = liftMatrix toInt'
401 422
402 423
403emptyErrorV msg f v = 424emptyErrorV msg f v =
@@ -448,6 +469,9 @@ cmod m = cmod' (fromIntegral m)
448fromInt :: (Container c e) => c I -> c e 469fromInt :: (Container c e) => c I -> c e
449fromInt = fromInt' 470fromInt = fromInt'
450 471
472toInt :: (Container c e) => c e -> c I
473toInt = toInt'
474
451 475
452-- | like 'fmap' (cannot implement instance Functor because of Element class constraint) 476-- | like 'fmap' (cannot implement instance Functor because of Element class constraint)
453cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b 477cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b
@@ -852,6 +876,24 @@ condV f a b l e t = f a' b' l' e' t'
852 where 876 where
853 [a', b', l', e', t'] = conformVs [a,b,l,e,t] 877 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
854 878
879compareM a b = matrixFromVector RowMajor (rows a'') (cols a'') $ ccompare' a' b'
880 where
881 args@(a'':_) = conformMs [a,b]
882 [a', b'] = map flatten args
883
884compareCV f a b = f a' b'
885 where
886 [a', b'] = conformVs [a,b]
887
888selectM c l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cselect' (toInt c') l' e' t'
889 where
890 args@(a'':_) = conformMs [fromInt c,l,e,t]
891 [c', l', e', t'] = map flatten args
892
893selectCV f c l e t = f (toInt c') l' e' t'
894 where
895 [c', l', e', t'] = conformVs [fromInt c,l,e,t]
896
855-------------------------------------------------------------------------------- 897--------------------------------------------------------------------------------
856 898
857class Transposable m mt | m -> mt, mt -> m 899class Transposable m mt | m -> mt, mt -> m
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs
index cb449a9..906bc83 100644
--- a/packages/base/src/Data/Packed/Numeric.hs
+++ b/packages/base/src/Data/Packed/Numeric.hs
@@ -31,12 +31,12 @@ module Data.Packed.Numeric (
31 diag, ident, 31 diag, ident,
32 ctrans, 32 ctrans,
33 -- * Generic operations 33 -- * Generic operations
34 Container(..), Numeric, 34 Container(..), Numeric, Extractor(..), (??), range, idxs, I,
35 -- add, mul, sub, divide, equal, scaleRecip, addConstant, 35 -- add, mul, sub, divide, equal, scaleRecip, addConstant,
36 scalar, conj, scale, arctan2, cmap, cmod, 36 scalar, conj, scale, arctan2, cmap, cmod,
37 atIndex, minIndex, maxIndex, minElement, maxElement, 37 atIndex, minIndex, maxIndex, minElement, maxElement,
38 sumElements, prodElements, 38 sumElements, prodElements,
39 step, cond, find, assoc, accum, 39 step, cond, find, assoc, accum, ccompare, cselect,
40 Transposable(..), Linear(..), 40 Transposable(..), Linear(..),
41 -- * Matrix product 41 -- * Matrix product
42 Product(..), udot, dot, (<·>), (#>), (<#), app, 42 Product(..), udot, dot, (<·>), (#>), (<#), app,
@@ -58,7 +58,7 @@ module Data.Packed.Numeric (
58 Complexable(), 58 Complexable(),
59 RealElement(), 59 RealElement(),
60 RealOf, ComplexOf, SingleOf, DoubleOf, 60 RealOf, ComplexOf, SingleOf, DoubleOf,
61 roundVector, 61 roundVector,fromInt,toInt,
62 IndexOf, 62 IndexOf,
63 module Data.Complex, 63 module Data.Complex,
64 -- * IO 64 -- * IO
@@ -309,4 +309,9 @@ sortVector = sortV
309sortIndex :: (Ord t, Element t) => Vector t -> Vector I 309sortIndex :: (Ord t, Element t) => Vector t -> Vector I
310sortIndex = sortI 310sortIndex = sortI
311 311
312ccompare :: (Ord t, Container c t) => c t -> c t -> c I
313ccompare = ccompare'
314
315cselect :: (Container c t) => c I -> c t -> c t -> c t -> c t
316cselect = cselect'
312 317
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 2aac2e4..79dd06b 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -59,7 +59,9 @@ module Numeric.LinearAlgebra.Data(
59 fromBlocks, (|||), (===), diagBlock, repmat, toBlocks, toBlocksEvery, 59 fromBlocks, (|||), (===), diagBlock, repmat, toBlocks, toBlocksEvery,
60 60
61 -- * Mapping functions 61 -- * Mapping functions
62 conj, cmap, cmod, step, cond, 62 conj, cmap, cmod,
63
64 step, cond, ccompare, cselect,
63 65
64 -- * Find elements 66 -- * Find elements
65 find, maxIndex, minIndex, maxElement, minElement, 67 find, maxIndex, minIndex, maxElement, minElement,
@@ -78,7 +80,7 @@ module Numeric.LinearAlgebra.Data(
78-- * Element conversion 80-- * Element conversion
79 Convert(..), 81 Convert(..),
80 roundVector, 82 roundVector,
81 fromInt, 83 fromInt,toInt,
82 -- * Misc 84 -- * Misc
83 arctan2, 85 arctan2,
84 separable, 86 separable,
@@ -95,6 +97,5 @@ import Data.Packed.Numeric
95import Numeric.LinearAlgebra.Util hiding ((&),(#)) 97import Numeric.LinearAlgebra.Util hiding ((&),(#))
96import Data.Complex 98import Data.Complex
97import Numeric.Sparse 99import Numeric.Sparse
98import Data.Packed.Internal.Numeric(I,Extractor(..),(??),fromInt,range,idxs)
99 100
100 101
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs
index eadd2a2..779630f 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs
@@ -66,7 +66,6 @@ import Control.Monad(when)
66import Text.Printf 66import Text.Printf
67import Data.List.Split(splitOn) 67import Data.List.Split(splitOn)
68import Data.List(intercalate) 68import Data.List(intercalate)
69import Data.Packed.Internal.Numeric(I)
70 69
71type ℝ = Double 70type ℝ = Double
72type ℕ = Int 71type ℕ = Int