diff options
Diffstat (limited to 'packages/base/src/Internal/Numeric.hs')
-rw-r--r-- | packages/base/src/Internal/Numeric.hs | 80 |
1 files changed, 64 insertions, 16 deletions
diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index fd0a217..4f7bb82 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs | |||
@@ -4,6 +4,7 @@ | |||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | 4 | {-# LANGUAGE MultiParamTypeClasses #-} |
5 | {-# LANGUAGE FunctionalDependencies #-} | 5 | {-# LANGUAGE FunctionalDependencies #-} |
6 | {-# LANGUAGE UndecidableInstances #-} | 6 | {-# LANGUAGE UndecidableInstances #-} |
7 | {-# LANGUAGE PatternSynonyms #-} | ||
7 | 8 | ||
8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} | 9 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} |
9 | 10 | ||
@@ -22,12 +23,18 @@ module Internal.Numeric where | |||
22 | import Internal.Vector | 23 | import Internal.Vector |
23 | import Internal.Matrix | 24 | import Internal.Matrix |
24 | import Internal.Element | 25 | import Internal.Element |
26 | import Internal.Extract (requires,pattern BAD_SIZE) | ||
25 | import Internal.ST as ST | 27 | import Internal.ST as ST |
26 | import Internal.Conversion | 28 | import Internal.Conversion |
27 | import Internal.Vectorized | 29 | import Internal.Vectorized |
28 | import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) | 30 | import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) |
31 | import Control.Monad | ||
32 | import Data.Function | ||
33 | import Data.Int | ||
29 | import Data.List.Split(chunksOf) | 34 | import Data.List.Split(chunksOf) |
30 | import qualified Data.Vector.Storable as V | 35 | import qualified Data.Vector.Storable as V |
36 | import Foreign.Ptr | ||
37 | import Foreign.Storable | ||
31 | 38 | ||
32 | -------------------------------------------------------------------------------- | 39 | -------------------------------------------------------------------------------- |
33 | 40 | ||
@@ -44,7 +51,7 @@ type instance ArgOf Matrix a = a -> a -> a | |||
44 | -------------------------------------------------------------------------------- | 51 | -------------------------------------------------------------------------------- |
45 | 52 | ||
46 | -- | Basic element-by-element functions for numeric containers | 53 | -- | Basic element-by-element functions for numeric containers |
47 | class Element e => Container c e | 54 | class Storable e => Container c e |
48 | where | 55 | where |
49 | conj' :: c e -> c e | 56 | conj' :: c e -> c e |
50 | size' :: c e -> IndexOf c | 57 | size' :: c e -> IndexOf c |
@@ -56,7 +63,7 @@ class Element e => Container c e | |||
56 | -- | element by element multiplication | 63 | -- | element by element multiplication |
57 | mul :: c e -> c e -> c e | 64 | mul :: c e -> c e -> c e |
58 | equal :: c e -> c e -> Bool | 65 | equal :: c e -> c e -> Bool |
59 | cmap' :: (Element b) => (e -> b) -> c e -> c b | 66 | cmap' :: (Storable b) => (e -> b) -> c e -> c b |
60 | konst' :: e -> IndexOf c -> c e | 67 | konst' :: e -> IndexOf c -> c e |
61 | build' :: IndexOf c -> (ArgOf c e) -> c e | 68 | build' :: IndexOf c -> (ArgOf c e) -> c e |
62 | atIndex' :: c e -> IndexOf c -> e | 69 | atIndex' :: c e -> IndexOf c -> e |
@@ -107,7 +114,7 @@ instance Container Vector I | |||
107 | mul = vectorZipI Mul | 114 | mul = vectorZipI Mul |
108 | equal = (==) | 115 | equal = (==) |
109 | scalar' = V.singleton | 116 | scalar' = V.singleton |
110 | konst' = constantD | 117 | konst' = constantAux |
111 | build' = buildV | 118 | build' = buildV |
112 | cmap' = mapVector | 119 | cmap' = mapVector |
113 | atIndex' = (@>) | 120 | atIndex' = (@>) |
@@ -146,7 +153,7 @@ instance Container Vector Z | |||
146 | mul = vectorZipL Mul | 153 | mul = vectorZipL Mul |
147 | equal = (==) | 154 | equal = (==) |
148 | scalar' = V.singleton | 155 | scalar' = V.singleton |
149 | konst' = constantD | 156 | konst' = constantAux |
150 | build' = buildV | 157 | build' = buildV |
151 | cmap' = mapVector | 158 | cmap' = mapVector |
152 | atIndex' = (@>) | 159 | atIndex' = (@>) |
@@ -186,7 +193,7 @@ instance Container Vector Float | |||
186 | mul = vectorZipF Mul | 193 | mul = vectorZipF Mul |
187 | equal = (==) | 194 | equal = (==) |
188 | scalar' = V.singleton | 195 | scalar' = V.singleton |
189 | konst' = constantD | 196 | konst' = constantAux |
190 | build' = buildV | 197 | build' = buildV |
191 | cmap' = mapVector | 198 | cmap' = mapVector |
192 | atIndex' = (@>) | 199 | atIndex' = (@>) |
@@ -223,7 +230,7 @@ instance Container Vector Double | |||
223 | mul = vectorZipR Mul | 230 | mul = vectorZipR Mul |
224 | equal = (==) | 231 | equal = (==) |
225 | scalar' = V.singleton | 232 | scalar' = V.singleton |
226 | konst' = constantD | 233 | konst' = constantAux |
227 | build' = buildV | 234 | build' = buildV |
228 | cmap' = mapVector | 235 | cmap' = mapVector |
229 | atIndex' = (@>) | 236 | atIndex' = (@>) |
@@ -260,7 +267,7 @@ instance Container Vector (Complex Double) | |||
260 | mul = vectorZipC Mul | 267 | mul = vectorZipC Mul |
261 | equal = (==) | 268 | equal = (==) |
262 | scalar' = V.singleton | 269 | scalar' = V.singleton |
263 | konst' = constantD | 270 | konst' = constantAux |
264 | build' = buildV | 271 | build' = buildV |
265 | cmap' = mapVector | 272 | cmap' = mapVector |
266 | atIndex' = (@>) | 273 | atIndex' = (@>) |
@@ -296,7 +303,7 @@ instance Container Vector (Complex Float) | |||
296 | mul = vectorZipQ Mul | 303 | mul = vectorZipQ Mul |
297 | equal = (==) | 304 | equal = (==) |
298 | scalar' = V.singleton | 305 | scalar' = V.singleton |
299 | konst' = constantD | 306 | konst' = constantAux |
300 | build' = buildV | 307 | build' = buildV |
301 | cmap' = mapVector | 308 | cmap' = mapVector |
302 | atIndex' = (@>) | 309 | atIndex' = (@>) |
@@ -323,7 +330,7 @@ instance Container Vector (Complex Float) | |||
323 | 330 | ||
324 | --------------------------------------------------------------- | 331 | --------------------------------------------------------------- |
325 | 332 | ||
326 | instance (Num a, Element a, Container Vector a) => Container Matrix a | 333 | instance (Num a, Storable a, Container Vector a) => Container Matrix a |
327 | where | 334 | where |
328 | conj' = liftMatrix conj' | 335 | conj' = liftMatrix conj' |
329 | size' = size | 336 | size' = size |
@@ -418,8 +425,8 @@ fromZ = fromZ' | |||
418 | toZ :: (Container c e) => c e -> c Z | 425 | toZ :: (Container c e) => c e -> c Z |
419 | toZ = toZ' | 426 | toZ = toZ' |
420 | 427 | ||
421 | -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) | 428 | -- | like 'fmap' (cannot implement instance Functor because of Storable class constraint) |
422 | cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b | 429 | cmap :: (Storable b, Container c e) => (e -> b) -> c e -> c b |
423 | cmap = cmap' | 430 | cmap = cmap' |
424 | 431 | ||
425 | -- | generic indexing function | 432 | -- | generic indexing function |
@@ -470,7 +477,7 @@ step | |||
470 | step = step' | 477 | step = step' |
471 | 478 | ||
472 | 479 | ||
473 | -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. | 480 | -- | Storable by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. |
474 | -- | 481 | -- |
475 | -- Arguments with any dimension = 1 are automatically expanded: | 482 | -- Arguments with any dimension = 1 are automatically expanded: |
476 | -- | 483 | -- |
@@ -598,7 +605,7 @@ instance Numeric Z | |||
598 | -------------------------------------------------------------------------------- | 605 | -------------------------------------------------------------------------------- |
599 | 606 | ||
600 | -- | Matrix product and related functions | 607 | -- | Matrix product and related functions |
601 | class (Num e, Element e) => Product e where | 608 | class (Num e, Storable e) => Product e where |
602 | -- | matrix product | 609 | -- | matrix product |
603 | multiply :: Matrix e -> Matrix e -> Matrix e | 610 | multiply :: Matrix e -> Matrix e -> Matrix e |
604 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | 611 | -- | sum of absolute value of elements (differs in complex case from @norm1@) |
@@ -823,12 +830,12 @@ buildV n f = fromList [f k | k <- ks] | |||
823 | -------------------------------------------------------- | 830 | -------------------------------------------------------- |
824 | 831 | ||
825 | -- | Creates a square matrix with a given diagonal. | 832 | -- | Creates a square matrix with a given diagonal. |
826 | diag :: (Num a, Element a) => Vector a -> Matrix a | 833 | diag :: (Num a, Storable a) => Vector a -> Matrix a |
827 | diag v = diagRect 0 v n n where n = dim v | 834 | diag v = diagRect 0 v n n where n = dim v |
828 | 835 | ||
829 | -- | creates the identity matrix of given dimension | 836 | -- | creates the identity matrix of given dimension |
830 | ident :: (Num a, Element a) => Int -> Matrix a | 837 | ident :: (Num a, Storable a) => Int -> Matrix a |
831 | ident n = diag (constantD 1 n) | 838 | ident n = diag (constantAux 1 n) |
832 | 839 | ||
833 | -------------------------------------------------------- | 840 | -------------------------------------------------------- |
834 | 841 | ||
@@ -943,3 +950,44 @@ class Testable t | |||
943 | 950 | ||
944 | -------------------------------------------------------------------------------- | 951 | -------------------------------------------------------------------------------- |
945 | 952 | ||
953 | compareV :: (Storable a, Ord a) => Vector a -> Vector a -> Vector Int32 | ||
954 | compareV = compareG compareStorable | ||
955 | |||
956 | compareStorable :: (Storable a, Ord a) => | ||
957 | Int32 -> Ptr a | ||
958 | -> Int32 -> Ptr a | ||
959 | -> Int32 -> Ptr Int32 | ||
960 | -> IO Int32 | ||
961 | compareStorable xn xp yn yp rn rp = do | ||
962 | requires (xn==yn && xn==rn) BAD_SIZE $ do | ||
963 | ($ 0) $ fix $ \kloop k -> when (k<xn) $ do | ||
964 | xk <- peekElemOff xp (fromIntegral k) | ||
965 | yk <- peekElemOff yp (fromIntegral k) | ||
966 | pokeElemOff rp (fromIntegral k) $ case compare xk yk of | ||
967 | LT -> -1 | ||
968 | GT -> 1 | ||
969 | EQ -> 0 | ||
970 | kloop (succ k) | ||
971 | return 0 | ||
972 | |||
973 | selectV :: Storable a => Vector Int32 -> Vector a -> Vector a -> Vector a -> Vector a | ||
974 | selectV = selectG selectStorable | ||
975 | |||
976 | selectStorable :: Storable a => | ||
977 | Int32 -> Ptr Int32 | ||
978 | -> Int32 -> Ptr a | ||
979 | -> Int32 -> Ptr a | ||
980 | -> Int32 -> Ptr a | ||
981 | -> Int32 -> Ptr a | ||
982 | -> IO Int32 | ||
983 | selectStorable condn condp ltn ltp eqn eqp gtn gtp rn rp = do | ||
984 | requires (condn==ltn && ltn==eqn && ltn==gtn && ltn==rn) BAD_SIZE $ do | ||
985 | ($ 0) $ fix $ \kloop k -> when (k<condn) $ do | ||
986 | condpk <- peekElemOff condp (fromIntegral k) | ||
987 | pokeElemOff rp (fromIntegral k) =<< case compare condpk 0 of | ||
988 | LT -> peekElemOff ltp (fromIntegral k) | ||
989 | GT -> peekElemOff gtp (fromIntegral k) | ||
990 | EQ -> peekElemOff eqp (fromIntegral k) | ||
991 | kloop (succ k) | ||
992 | return 0 | ||
993 | |||