summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Vectorized.hs
diff options
context:
space:
mode:
authoridontgetoutmuch <dominic@steinitz.org>2018-04-01 10:25:10 -0700
committerGitHub <noreply@github.com>2018-04-01 10:25:10 -0700
commit7ff1249249cb1b238e3e5ffafdaa9180c86242b9 (patch)
tree903fa549ea33ae9eff9fb527375f87548948d41d /packages/base/src/Internal/Vectorized.hs
parent1bdd94324c6e7cd9d9361c4136998d4c14ea3857 (diff)
parent63ceed4a69563c08e54269a4a9d3e5d27cb7c489 (diff)
Merge branch 'master' into master
Diffstat (limited to 'packages/base/src/Internal/Vectorized.hs')
-rw-r--r--packages/base/src/Internal/Vectorized.hs36
1 files changed, 36 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index a410bb2..c00c324 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO)
28import Control.Monad(when) 28import Control.Monad(when)
29 29
30infixr 1 # 30infixr 1 #
31(#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r
31a # b = applyRaw a b 32a # b = applyRaw a b
32{-# INLINE (#) #-} 33{-# INLINE (#) #-}
33 34
35(#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r
34a #! b = a # b # id 36a #! b = a # b # id
35{-# INLINE (#!) #-} 37{-# INLINE (#!) #-}
36 38
39fromei :: Enum a => a -> CInt
37fromei x = fromIntegral (fromEnum x) :: CInt 40fromei x = fromIntegral (fromEnum x) :: CInt
38 41
39data FunCodeV = Sin 42data FunCodeV = Sin
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ
100sumC :: Vector (Complex Double) -> Complex Double 103sumC :: Vector (Complex Double) -> Complex Double
101sumC = sumg c_sumC 104sumC = sumg c_sumC
102 105
106sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok)
107 , TransArray c
108 , Storable a
109 )
110 => I -> c -> a
103sumI m = sumg (c_sumI m) 111sumI m = sumg (c_sumI m)
104 112
113sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok)
114 , TransArray c
115 , Storable a
116 ) => Z -> c -> a
105sumL m = sumg (c_sumL m) 117sumL m = sumg (c_sumL m)
106 118
119sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
107sumg f x = unsafePerformIO $ do 120sumg f x = unsafePerformIO $ do
108 r <- createVector 1 121 r <- createVector 1
109 (x #! r) f #| "sum" 122 (x #! r) f #| "sum"
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI
140prodL :: Z-> Vector Z -> Z 153prodL :: Z-> Vector Z -> Z
141prodL = prodg . c_prodL 154prodL = prodg . c_prodL
142 155
156prodg :: (TransArray c, Storable a)
157 => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a
143prodg f x = unsafePerformIO $ do 158prodg f x = unsafePerformIO $ do
144 r <- createVector 1 159 r <- createVector 1
145 (x #! r) f #| "prod" 160 (x #! r) f #| "prod"
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
155 170
156------------------------------------------------------------------ 171------------------------------------------------------------------
157 172
173toScalarAux :: (Enum a, TransArray c, Storable a1)
174 => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1
158toScalarAux fun code v = unsafePerformIO $ do 175toScalarAux fun code v = unsafePerformIO $ do
159 r <- createVector 1 176 r <- createVector 1
160 (v #! r) (fun (fromei code)) #|"toScalarAux" 177 (v #! r) (fun (fromei code)) #|"toScalarAux"
161 return (r @> 0) 178 return (r @> 0)
162 179
180
181vectorMapAux :: (Enum a, Storable t, Storable a1)
182 => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
183 -> a -> Vector t -> Vector a1
163vectorMapAux fun code v = unsafePerformIO $ do 184vectorMapAux fun code v = unsafePerformIO $ do
164 r <- createVector (dim v) 185 r <- createVector (dim v)
165 (v #! r) (fun (fromei code)) #|"vectorMapAux" 186 (v #! r) (fun (fromei code)) #|"vectorMapAux"
166 return r 187 return r
167 188
189vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1)
190 => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)
191 -> a -> a2 -> Vector t -> Vector a1
168vectorMapValAux fun code val v = unsafePerformIO $ do 192vectorMapValAux fun code val v = unsafePerformIO $ do
169 r <- createVector (dim v) 193 r <- createVector (dim v)
170 pval <- newArray [val] 194 pval <- newArray [val]
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do
172 free pval 196 free pval
173 return r 197 return r
174 198
199vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1)
200 => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt))
201 -> a -> Vector t -> c -> Vector a1
175vectorZipAux fun code u v = unsafePerformIO $ do 202vectorZipAux fun code u v = unsafePerformIO $ do
176 r <- createVector (dim u) 203 r <- createVector (dim u)
177 (u # v #! r) (fun (fromei code)) #|"vectorZipAux" 204 (u # v #! r) (fun (fromei code)) #|"vectorZipAux"
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D
378 405
379-------------------------------------------------------------------------------- 406--------------------------------------------------------------------------------
380 407
408roundVector :: Vector Double -> Vector Double
381roundVector v = unsafePerformIO $ do 409roundVector v = unsafePerformIO $ do
382 r <- createVector (dim v) 410 r <- createVector (dim v)
383 (v #! r) c_round_vector #|"roundVector" 411 (v #! r) c_round_vector #|"roundVector"
@@ -432,6 +460,8 @@ long2intV :: Vector Z -> Vector I
432long2intV = tog c_long2int 460long2intV = tog c_long2int
433 461
434 462
463tog :: (Storable t, Storable a)
464 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
435tog f v = unsafePerformIO $ do 465tog f v = unsafePerformIO $ do
436 r <- createVector (dim v) 466 r <- createVector (dim v)
437 (v #! r) f #|"tog" 467 (v #! r) f #|"tog"
@@ -451,6 +481,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
451 481
452--------------------------------------------------------------- 482---------------------------------------------------------------
453 483
484stepg :: (Storable t, Storable a)
485 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
454stepg f v = unsafePerformIO $ do 486stepg f v = unsafePerformIO $ do
455 r <- createVector (dim v) 487 r <- createVector (dim v)
456 (v #! r) f #|"step" 488 (v #! r) f #|"step"
@@ -476,6 +508,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z
476 508
477-------------------------------------------------------------------------------- 509--------------------------------------------------------------------------------
478 510
511conjugateAux :: (Storable t, Storable a)
512 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
479conjugateAux fun x = unsafePerformIO $ do 513conjugateAux fun x = unsafePerformIO $ do
480 v <- createVector (dim x) 514 v <- createVector (dim x)
481 (x #! v) fun #|"conjugateAux" 515 (x #! v) fun #|"conjugateAux"
@@ -501,6 +535,8 @@ cloneVector v = do
501 535
502-------------------------------------------------------------------------------- 536--------------------------------------------------------------------------------
503 537
538constantAux :: (Storable a1, Storable a)
539 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
504constantAux fun x n = unsafePerformIO $ do 540constantAux fun x n = unsafePerformIO $ do
505 v <- createVector n 541 v <- createVector n
506 px <- newArray [x] 542 px <- newArray [x]