summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Vectorized.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Vectorized.hs')
-rw-r--r--packages/base/src/Internal/Vectorized.hs34
1 files changed, 18 insertions, 16 deletions
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index 03bcf90..a410bb2 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -27,10 +27,13 @@ import Foreign.C.String
27import System.IO.Unsafe(unsafePerformIO) 27import System.IO.Unsafe(unsafePerformIO)
28import Control.Monad(when) 28import Control.Monad(when)
29 29
30infixl 1 # 30infixr 1 #
31a # b = applyRaw a b 31a # b = applyRaw a b
32{-# INLINE (#) #-} 32{-# INLINE (#) #-}
33 33
34a #! b = a # b # id
35{-# INLINE (#!) #-}
36
34fromei x = fromIntegral (fromEnum x) :: CInt 37fromei x = fromIntegral (fromEnum x) :: CInt
35 38
36data FunCodeV = Sin 39data FunCodeV = Sin
@@ -103,7 +106,7 @@ sumL m = sumg (c_sumL m)
103 106
104sumg f x = unsafePerformIO $ do 107sumg f x = unsafePerformIO $ do
105 r <- createVector 1 108 r <- createVector 1
106 f # x # r #| "sum" 109 (x #! r) f #| "sum"
107 return $ r @> 0 110 return $ r @> 0
108 111
109type TVV t = t :> t :> Ok 112type TVV t = t :> t :> Ok
@@ -139,7 +142,7 @@ prodL = prodg . c_prodL
139 142
140prodg f x = unsafePerformIO $ do 143prodg f x = unsafePerformIO $ do
141 r <- createVector 1 144 r <- createVector 1
142 f # x # r #| "prod" 145 (x #! r) f #| "prod"
143 return $ r @> 0 146 return $ r @> 0
144 147
145 148
@@ -154,24 +157,24 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
154 157
155toScalarAux fun code v = unsafePerformIO $ do 158toScalarAux fun code v = unsafePerformIO $ do
156 r <- createVector 1 159 r <- createVector 1
157 fun (fromei code) # v # r #|"toScalarAux" 160 (v #! r) (fun (fromei code)) #|"toScalarAux"
158 return (r @> 0) 161 return (r @> 0)
159 162
160vectorMapAux fun code v = unsafePerformIO $ do 163vectorMapAux fun code v = unsafePerformIO $ do
161 r <- createVector (dim v) 164 r <- createVector (dim v)
162 fun (fromei code) # v # r #|"vectorMapAux" 165 (v #! r) (fun (fromei code)) #|"vectorMapAux"
163 return r 166 return r
164 167
165vectorMapValAux fun code val v = unsafePerformIO $ do 168vectorMapValAux fun code val v = unsafePerformIO $ do
166 r <- createVector (dim v) 169 r <- createVector (dim v)
167 pval <- newArray [val] 170 pval <- newArray [val]
168 fun (fromei code) pval # v # r #|"vectorMapValAux" 171 (v #! r) (fun (fromei code) pval) #|"vectorMapValAux"
169 free pval 172 free pval
170 return r 173 return r
171 174
172vectorZipAux fun code u v = unsafePerformIO $ do 175vectorZipAux fun code u v = unsafePerformIO $ do
173 r <- createVector (dim u) 176 r <- createVector (dim u)
174 fun (fromei code) # u # v # r #|"vectorZipAux" 177 (u # v #! r) (fun (fromei code)) #|"vectorZipAux"
175 return r 178 return r
176 179
177--------------------------------------------------------------------- 180---------------------------------------------------------------------
@@ -368,7 +371,7 @@ randomVector :: Seed
368 -> Vector Double 371 -> Vector Double
369randomVector seed dist n = unsafePerformIO $ do 372randomVector seed dist n = unsafePerformIO $ do
370 r <- createVector n 373 r <- createVector n
371 c_random_vector (fi seed) ((fi.fromEnum) dist) # r #|"randomVector" 374 (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector"
372 return r 375 return r
373 376
374foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok 377foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok
@@ -377,7 +380,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D
377 380
378roundVector v = unsafePerformIO $ do 381roundVector v = unsafePerformIO $ do
379 r <- createVector (dim v) 382 r <- createVector (dim v)
380 c_round_vector # v # r #|"roundVector" 383 (v #! r) c_round_vector #|"roundVector"
381 return r 384 return r
382 385
383foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double 386foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double
@@ -391,7 +394,7 @@ foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double
391range :: Int -> Vector I 394range :: Int -> Vector I
392range n = unsafePerformIO $ do 395range n = unsafePerformIO $ do
393 r <- createVector n 396 r <- createVector n
394 c_range_vector # r #|"range" 397 (r # id) c_range_vector #|"range"
395 return r 398 return r
396 399
397foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok 400foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok
@@ -431,7 +434,7 @@ long2intV = tog c_long2int
431 434
432tog f v = unsafePerformIO $ do 435tog f v = unsafePerformIO $ do
433 r <- createVector (dim v) 436 r <- createVector (dim v)
434 f # v # r #|"tog" 437 (v #! r) f #|"tog"
435 return r 438 return r
436 439
437foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok 440foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok
@@ -450,7 +453,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
450 453
451stepg f v = unsafePerformIO $ do 454stepg f v = unsafePerformIO $ do
452 r <- createVector (dim v) 455 r <- createVector (dim v)
453 f # v # r #|"step" 456 (v #! r) f #|"step"
454 return r 457 return r
455 458
456stepD :: Vector Double -> Vector Double 459stepD :: Vector Double -> Vector Double
@@ -475,7 +478,7 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z
475 478
476conjugateAux fun x = unsafePerformIO $ do 479conjugateAux fun x = unsafePerformIO $ do
477 v <- createVector (dim x) 480 v <- createVector (dim x)
478 fun # x # v #|"conjugateAux" 481 (x #! v) fun #|"conjugateAux"
479 return v 482 return v
480 483
481conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) 484conjugateQ :: Vector (Complex Float) -> Vector (Complex Float)
@@ -493,7 +496,7 @@ cloneVector v = do
493 let n = dim v 496 let n = dim v
494 r <- createVector n 497 r <- createVector n
495 let f _ s _ d = copyArray d s n >> return 0 498 let f _ s _ d = copyArray d s n >> return 0
496 f # v # r #|"cloneVector" 499 (v #! r) f #|"cloneVector"
497 return r 500 return r
498 501
499-------------------------------------------------------------------------------- 502--------------------------------------------------------------------------------
@@ -501,7 +504,7 @@ cloneVector v = do
501constantAux fun x n = unsafePerformIO $ do 504constantAux fun x n = unsafePerformIO $ do
502 v <- createVector n 505 v <- createVector n
503 px <- newArray [x] 506 px <- newArray [x]
504 fun px # v #|"constantAux" 507 (v # id) (fun px) #|"constantAux"
505 free px 508 free px
506 return v 509 return v
507 510
@@ -515,4 +518,3 @@ foreign import ccall unsafe "constantI" cconstantI :: TConst CInt
515foreign import ccall unsafe "constantL" cconstantL :: TConst Z 518foreign import ccall unsafe "constantL" cconstantL :: TConst Z
516 519
517---------------------------------------------------------------------- 520----------------------------------------------------------------------
518