summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Numeric.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Numeric.hs')
-rw-r--r--packages/base/src/Internal/Numeric.hs80
1 files changed, 64 insertions, 16 deletions
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs
index fd0a217..4f7bb82 100644
--- a/packages/base/src/Internal/Numeric.hs
+++ b/packages/base/src/Internal/Numeric.hs
@@ -4,6 +4,7 @@
4{-# LANGUAGE MultiParamTypeClasses #-} 4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-} 5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 6{-# LANGUAGE UndecidableInstances #-}
7{-# LANGUAGE PatternSynonyms #-}
7 8
8{-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
9 10
@@ -22,12 +23,18 @@ module Internal.Numeric where
22import Internal.Vector 23import Internal.Vector
23import Internal.Matrix 24import Internal.Matrix
24import Internal.Element 25import Internal.Element
26import Internal.Extract (requires,pattern BAD_SIZE)
25import Internal.ST as ST 27import Internal.ST as ST
26import Internal.Conversion 28import Internal.Conversion
27import Internal.Vectorized 29import Internal.Vectorized
28import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) 30import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL)
31import Control.Monad
32import Data.Function
33import Data.Int
29import Data.List.Split(chunksOf) 34import Data.List.Split(chunksOf)
30import qualified Data.Vector.Storable as V 35import qualified Data.Vector.Storable as V
36import Foreign.Ptr
37import Foreign.Storable
31 38
32-------------------------------------------------------------------------------- 39--------------------------------------------------------------------------------
33 40
@@ -44,7 +51,7 @@ type instance ArgOf Matrix a = a -> a -> a
44-------------------------------------------------------------------------------- 51--------------------------------------------------------------------------------
45 52
46-- | Basic element-by-element functions for numeric containers 53-- | Basic element-by-element functions for numeric containers
47class Element e => Container c e 54class Storable e => Container c e
48 where 55 where
49 conj' :: c e -> c e 56 conj' :: c e -> c e
50 size' :: c e -> IndexOf c 57 size' :: c e -> IndexOf c
@@ -56,7 +63,7 @@ class Element e => Container c e
56 -- | element by element multiplication 63 -- | element by element multiplication
57 mul :: c e -> c e -> c e 64 mul :: c e -> c e -> c e
58 equal :: c e -> c e -> Bool 65 equal :: c e -> c e -> Bool
59 cmap' :: (Element b) => (e -> b) -> c e -> c b 66 cmap' :: (Storable b) => (e -> b) -> c e -> c b
60 konst' :: e -> IndexOf c -> c e 67 konst' :: e -> IndexOf c -> c e
61 build' :: IndexOf c -> (ArgOf c e) -> c e 68 build' :: IndexOf c -> (ArgOf c e) -> c e
62 atIndex' :: c e -> IndexOf c -> e 69 atIndex' :: c e -> IndexOf c -> e
@@ -107,7 +114,7 @@ instance Container Vector I
107 mul = vectorZipI Mul 114 mul = vectorZipI Mul
108 equal = (==) 115 equal = (==)
109 scalar' = V.singleton 116 scalar' = V.singleton
110 konst' = constantD 117 konst' = constantAux
111 build' = buildV 118 build' = buildV
112 cmap' = mapVector 119 cmap' = mapVector
113 atIndex' = (@>) 120 atIndex' = (@>)
@@ -146,7 +153,7 @@ instance Container Vector Z
146 mul = vectorZipL Mul 153 mul = vectorZipL Mul
147 equal = (==) 154 equal = (==)
148 scalar' = V.singleton 155 scalar' = V.singleton
149 konst' = constantD 156 konst' = constantAux
150 build' = buildV 157 build' = buildV
151 cmap' = mapVector 158 cmap' = mapVector
152 atIndex' = (@>) 159 atIndex' = (@>)
@@ -186,7 +193,7 @@ instance Container Vector Float
186 mul = vectorZipF Mul 193 mul = vectorZipF Mul
187 equal = (==) 194 equal = (==)
188 scalar' = V.singleton 195 scalar' = V.singleton
189 konst' = constantD 196 konst' = constantAux
190 build' = buildV 197 build' = buildV
191 cmap' = mapVector 198 cmap' = mapVector
192 atIndex' = (@>) 199 atIndex' = (@>)
@@ -223,7 +230,7 @@ instance Container Vector Double
223 mul = vectorZipR Mul 230 mul = vectorZipR Mul
224 equal = (==) 231 equal = (==)
225 scalar' = V.singleton 232 scalar' = V.singleton
226 konst' = constantD 233 konst' = constantAux
227 build' = buildV 234 build' = buildV
228 cmap' = mapVector 235 cmap' = mapVector
229 atIndex' = (@>) 236 atIndex' = (@>)
@@ -260,7 +267,7 @@ instance Container Vector (Complex Double)
260 mul = vectorZipC Mul 267 mul = vectorZipC Mul
261 equal = (==) 268 equal = (==)
262 scalar' = V.singleton 269 scalar' = V.singleton
263 konst' = constantD 270 konst' = constantAux
264 build' = buildV 271 build' = buildV
265 cmap' = mapVector 272 cmap' = mapVector
266 atIndex' = (@>) 273 atIndex' = (@>)
@@ -296,7 +303,7 @@ instance Container Vector (Complex Float)
296 mul = vectorZipQ Mul 303 mul = vectorZipQ Mul
297 equal = (==) 304 equal = (==)
298 scalar' = V.singleton 305 scalar' = V.singleton
299 konst' = constantD 306 konst' = constantAux
300 build' = buildV 307 build' = buildV
301 cmap' = mapVector 308 cmap' = mapVector
302 atIndex' = (@>) 309 atIndex' = (@>)
@@ -323,7 +330,7 @@ instance Container Vector (Complex Float)
323 330
324--------------------------------------------------------------- 331---------------------------------------------------------------
325 332
326instance (Num a, Element a, Container Vector a) => Container Matrix a 333instance (Num a, Storable a, Container Vector a) => Container Matrix a
327 where 334 where
328 conj' = liftMatrix conj' 335 conj' = liftMatrix conj'
329 size' = size 336 size' = size
@@ -418,8 +425,8 @@ fromZ = fromZ'
418toZ :: (Container c e) => c e -> c Z 425toZ :: (Container c e) => c e -> c Z
419toZ = toZ' 426toZ = toZ'
420 427
421-- | like 'fmap' (cannot implement instance Functor because of Element class constraint) 428-- | like 'fmap' (cannot implement instance Functor because of Storable class constraint)
422cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b 429cmap :: (Storable b, Container c e) => (e -> b) -> c e -> c b
423cmap = cmap' 430cmap = cmap'
424 431
425-- | generic indexing function 432-- | generic indexing function
@@ -470,7 +477,7 @@ step
470step = step' 477step = step'
471 478
472 479
473-- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. 480-- | Storable by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
474-- 481--
475-- Arguments with any dimension = 1 are automatically expanded: 482-- Arguments with any dimension = 1 are automatically expanded:
476-- 483--
@@ -598,7 +605,7 @@ instance Numeric Z
598-------------------------------------------------------------------------------- 605--------------------------------------------------------------------------------
599 606
600-- | Matrix product and related functions 607-- | Matrix product and related functions
601class (Num e, Element e) => Product e where 608class (Num e, Storable e) => Product e where
602 -- | matrix product 609 -- | matrix product
603 multiply :: Matrix e -> Matrix e -> Matrix e 610 multiply :: Matrix e -> Matrix e -> Matrix e
604 -- | sum of absolute value of elements (differs in complex case from @norm1@) 611 -- | sum of absolute value of elements (differs in complex case from @norm1@)
@@ -823,12 +830,12 @@ buildV n f = fromList [f k | k <- ks]
823-------------------------------------------------------- 830--------------------------------------------------------
824 831
825-- | Creates a square matrix with a given diagonal. 832-- | Creates a square matrix with a given diagonal.
826diag :: (Num a, Element a) => Vector a -> Matrix a 833diag :: (Num a, Storable a) => Vector a -> Matrix a
827diag v = diagRect 0 v n n where n = dim v 834diag v = diagRect 0 v n n where n = dim v
828 835
829-- | creates the identity matrix of given dimension 836-- | creates the identity matrix of given dimension
830ident :: (Num a, Element a) => Int -> Matrix a 837ident :: (Num a, Storable a) => Int -> Matrix a
831ident n = diag (constantD 1 n) 838ident n = diag (constantAux 1 n)
832 839
833-------------------------------------------------------- 840--------------------------------------------------------
834 841
@@ -943,3 +950,44 @@ class Testable t
943 950
944-------------------------------------------------------------------------------- 951--------------------------------------------------------------------------------
945 952
953compareV :: (Storable a, Ord a) => Vector a -> Vector a -> Vector Int32
954compareV = compareG compareStorable
955
956compareStorable :: (Storable a, Ord a) =>
957 Int32 -> Ptr a
958 -> Int32 -> Ptr a
959 -> Int32 -> Ptr Int32
960 -> IO Int32
961compareStorable xn xp yn yp rn rp = do
962 requires (xn==yn && xn==rn) BAD_SIZE $ do
963 ($ 0) $ fix $ \kloop k -> when (k<xn) $ do
964 xk <- peekElemOff xp (fromIntegral k)
965 yk <- peekElemOff yp (fromIntegral k)
966 pokeElemOff rp (fromIntegral k) $ case compare xk yk of
967 LT -> -1
968 GT -> 1
969 EQ -> 0
970 kloop (succ k)
971 return 0
972
973selectV :: Storable a => Vector Int32 -> Vector a -> Vector a -> Vector a -> Vector a
974selectV = selectG selectStorable
975
976selectStorable :: Storable a =>
977 Int32 -> Ptr Int32
978 -> Int32 -> Ptr a
979 -> Int32 -> Ptr a
980 -> Int32 -> Ptr a
981 -> Int32 -> Ptr a
982 -> IO Int32
983selectStorable condn condp ltn ltp eqn eqp gtn gtp rn rp = do
984 requires (condn==ltn && ltn==eqn && ltn==gtn && ltn==rn) BAD_SIZE $ do
985 ($ 0) $ fix $ \kloop k -> when (k<condn) $ do
986 condpk <- peekElemOff condp (fromIntegral k)
987 pokeElemOff rp (fromIntegral k) =<< case compare condpk 0 of
988 LT -> peekElemOff ltp (fromIntegral k)
989 GT -> peekElemOff gtp (fromIntegral k)
990 EQ -> peekElemOff eqp (fromIntegral k)
991 kloop (succ k)
992 return 0
993