summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Vectorized.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Vectorized.hs')
-rw-r--r--packages/base/src/Internal/Vectorized.hs36
1 files changed, 36 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index 2990173..32430c6 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO)
28import Control.Monad(when) 28import Control.Monad(when)
29 29
30infixr 1 # 30infixr 1 #
31(#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r
31a # b = applyRaw a b 32a # b = applyRaw a b
32{-# INLINE (#) #-} 33{-# INLINE (#) #-}
33 34
35(#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r
34a #! b = a # b # id 36a #! b = a # b # id
35{-# INLINE (#!) #-} 37{-# INLINE (#!) #-}
36 38
39fromei :: Enum a => a -> CInt
37fromei x = fromIntegral (fromEnum x) :: CInt 40fromei x = fromIntegral (fromEnum x) :: CInt
38 41
39data FunCodeV = Sin 42data FunCodeV = Sin
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ
100sumC :: Vector (Complex Double) -> Complex Double 103sumC :: Vector (Complex Double) -> Complex Double
101sumC = sumg c_sumC 104sumC = sumg c_sumC
102 105
106sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok)
107 , TransArray c
108 , Storable a
109 )
110 => I -> c -> a
103sumI m = sumg (c_sumI m) 111sumI m = sumg (c_sumI m)
104 112
113sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok)
114 , TransArray c
115 , Storable a
116 ) => Z -> c -> a
105sumL m = sumg (c_sumL m) 117sumL m = sumg (c_sumL m)
106 118
119sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
107sumg f x = unsafePerformIO $ do 120sumg f x = unsafePerformIO $ do
108 r <- createVector 1 121 r <- createVector 1
109 (x #! r) f #| "sum" 122 (x #! r) f #| "sum"
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI
140prodL :: Z-> Vector Z -> Z 153prodL :: Z-> Vector Z -> Z
141prodL = prodg . c_prodL 154prodL = prodg . c_prodL
142 155
156prodg :: (TransArray c, Storable a)
157 => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
143prodg f x = unsafePerformIO $ do 158prodg f x = unsafePerformIO $ do
144 r <- createVector 1 159 r <- createVector 1
145 (x #! r) f #| "prod" 160 (x #! r) f #| "prod"
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
155 170
156------------------------------------------------------------------ 171------------------------------------------------------------------
157 172
173toScalarAux :: (Enum a, TransArray c, Storable a1)
174 => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1
158toScalarAux fun code v = unsafePerformIO $ do 175toScalarAux fun code v = unsafePerformIO $ do
159 r <- createVector 1 176 r <- createVector 1
160 (v #! r) (fun (fromei code)) #|"toScalarAux" 177 (v #! r) (fun (fromei code)) #|"toScalarAux"
161 return (r @> 0) 178 return (r @> 0)
162 179
180
181vectorMapAux :: (Enum a, Storable t, Storable a1)
182 => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
183 -> a -> Vector t -> Vector a1
163vectorMapAux fun code v = unsafePerformIO $ do 184vectorMapAux fun code v = unsafePerformIO $ do
164 r <- createVector (dim v) 185 r <- createVector (dim v)
165 (v #! r) (fun (fromei code)) #|"vectorMapAux" 186 (v #! r) (fun (fromei code)) #|"vectorMapAux"
166 return r 187 return r
167 188
189vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1)
190 => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
191 -> a -> a2 -> Vector t -> Vector a1
168vectorMapValAux fun code val v = unsafePerformIO $ do 192vectorMapValAux fun code val v = unsafePerformIO $ do
169 r <- createVector (dim v) 193 r <- createVector (dim v)
170 pval <- newArray [val] 194 pval <- newArray [val]
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do
172 free pval 196 free pval
173 return r 197 return r
174 198
199vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1)
200 => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt))
201 -> a -> Vector t -> c -> Vector a1
175vectorZipAux fun code u v = unsafePerformIO $ do 202vectorZipAux fun code u v = unsafePerformIO $ do
176 r <- createVector (dim u) 203 r <- createVector (dim u)
177 (u # v #! r) (fun (fromei code)) #|"vectorZipAux" 204 (u # v #! r) (fun (fromei code)) #|"vectorZipAux"
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D
378 405
379-------------------------------------------------------------------------------- 406--------------------------------------------------------------------------------
380 407
408roundVector :: Vector Double -> Vector Double
381roundVector v = unsafePerformIO $ do 409roundVector v = unsafePerformIO $ do
382 r <- createVector (dim v) 410 r <- createVector (dim v)
383 (v #! r) c_round_vector #|"roundVector" 411 (v #! r) c_round_vector #|"roundVector"
@@ -433,6 +461,8 @@ long2intV :: Vector Z -> Vector I
433long2intV = tog c_long2int 461long2intV = tog c_long2int
434 462
435 463
464tog :: (Storable t, Storable a)
465 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
436tog f v = unsafePerformIO $ do 466tog f v = unsafePerformIO $ do
437 r <- createVector (dim v) 467 r <- createVector (dim v)
438 (v #! r) f #|"tog" 468 (v #! r) f #|"tog"
@@ -452,6 +482,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
452 482
453--------------------------------------------------------------- 483---------------------------------------------------------------
454 484
485stepg :: (Storable t, Storable a)
486 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
455stepg f v = unsafePerformIO $ do 487stepg f v = unsafePerformIO $ do
456 r <- createVector (dim v) 488 r <- createVector (dim v)
457 (v #! r) f #|"step" 489 (v #! r) f #|"step"
@@ -477,6 +509,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z
477 509
478-------------------------------------------------------------------------------- 510--------------------------------------------------------------------------------
479 511
512conjugateAux :: (Storable t, Storable a)
513 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
480conjugateAux fun x = unsafePerformIO $ do 514conjugateAux fun x = unsafePerformIO $ do
481 v <- createVector (dim x) 515 v <- createVector (dim x)
482 (x #! v) fun #|"conjugateAux" 516 (x #! v) fun #|"conjugateAux"
@@ -502,6 +536,8 @@ cloneVector v = do
502 536
503-------------------------------------------------------------------------------- 537--------------------------------------------------------------------------------
504 538
539constantAux :: (Storable a1, Storable a)
540 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
505constantAux fun x n = unsafePerformIO $ do 541constantAux fun x n = unsafePerformIO $ do
506 v <- createVector n 542 v <- createVector n
507 px <- newArray [x] 543 px <- newArray [x]