diff options
Diffstat (limited to 'packages/base/src/Internal/Vectorized.hs')
-rw-r--r-- | packages/base/src/Internal/Vectorized.hs | 34 |
1 files changed, 18 insertions, 16 deletions
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index 03bcf90..a410bb2 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -27,10 +27,13 @@ import Foreign.C.String | |||
27 | import System.IO.Unsafe(unsafePerformIO) | 27 | import System.IO.Unsafe(unsafePerformIO) |
28 | import Control.Monad(when) | 28 | import Control.Monad(when) |
29 | 29 | ||
30 | infixl 1 # | 30 | infixr 1 # |
31 | a # b = applyRaw a b | 31 | a # b = applyRaw a b |
32 | {-# INLINE (#) #-} | 32 | {-# INLINE (#) #-} |
33 | 33 | ||
34 | a #! b = a # b # id | ||
35 | {-# INLINE (#!) #-} | ||
36 | |||
34 | fromei x = fromIntegral (fromEnum x) :: CInt | 37 | fromei x = fromIntegral (fromEnum x) :: CInt |
35 | 38 | ||
36 | data FunCodeV = Sin | 39 | data FunCodeV = Sin |
@@ -103,7 +106,7 @@ sumL m = sumg (c_sumL m) | |||
103 | 106 | ||
104 | sumg f x = unsafePerformIO $ do | 107 | sumg f x = unsafePerformIO $ do |
105 | r <- createVector 1 | 108 | r <- createVector 1 |
106 | f # x # r #| "sum" | 109 | (x #! r) f #| "sum" |
107 | return $ r @> 0 | 110 | return $ r @> 0 |
108 | 111 | ||
109 | type TVV t = t :> t :> Ok | 112 | type TVV t = t :> t :> Ok |
@@ -139,7 +142,7 @@ prodL = prodg . c_prodL | |||
139 | 142 | ||
140 | prodg f x = unsafePerformIO $ do | 143 | prodg f x = unsafePerformIO $ do |
141 | r <- createVector 1 | 144 | r <- createVector 1 |
142 | f # x # r #| "prod" | 145 | (x #! r) f #| "prod" |
143 | return $ r @> 0 | 146 | return $ r @> 0 |
144 | 147 | ||
145 | 148 | ||
@@ -154,24 +157,24 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | |||
154 | 157 | ||
155 | toScalarAux fun code v = unsafePerformIO $ do | 158 | toScalarAux fun code v = unsafePerformIO $ do |
156 | r <- createVector 1 | 159 | r <- createVector 1 |
157 | fun (fromei code) # v # r #|"toScalarAux" | 160 | (v #! r) (fun (fromei code)) #|"toScalarAux" |
158 | return (r @> 0) | 161 | return (r @> 0) |
159 | 162 | ||
160 | vectorMapAux fun code v = unsafePerformIO $ do | 163 | vectorMapAux fun code v = unsafePerformIO $ do |
161 | r <- createVector (dim v) | 164 | r <- createVector (dim v) |
162 | fun (fromei code) # v # r #|"vectorMapAux" | 165 | (v #! r) (fun (fromei code)) #|"vectorMapAux" |
163 | return r | 166 | return r |
164 | 167 | ||
165 | vectorMapValAux fun code val v = unsafePerformIO $ do | 168 | vectorMapValAux fun code val v = unsafePerformIO $ do |
166 | r <- createVector (dim v) | 169 | r <- createVector (dim v) |
167 | pval <- newArray [val] | 170 | pval <- newArray [val] |
168 | fun (fromei code) pval # v # r #|"vectorMapValAux" | 171 | (v #! r) (fun (fromei code) pval) #|"vectorMapValAux" |
169 | free pval | 172 | free pval |
170 | return r | 173 | return r |
171 | 174 | ||
172 | vectorZipAux fun code u v = unsafePerformIO $ do | 175 | vectorZipAux fun code u v = unsafePerformIO $ do |
173 | r <- createVector (dim u) | 176 | r <- createVector (dim u) |
174 | fun (fromei code) # u # v # r #|"vectorZipAux" | 177 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" |
175 | return r | 178 | return r |
176 | 179 | ||
177 | --------------------------------------------------------------------- | 180 | --------------------------------------------------------------------- |
@@ -368,7 +371,7 @@ randomVector :: Seed | |||
368 | -> Vector Double | 371 | -> Vector Double |
369 | randomVector seed dist n = unsafePerformIO $ do | 372 | randomVector seed dist n = unsafePerformIO $ do |
370 | r <- createVector n | 373 | r <- createVector n |
371 | c_random_vector (fi seed) ((fi.fromEnum) dist) # r #|"randomVector" | 374 | (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector" |
372 | return r | 375 | return r |
373 | 376 | ||
374 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok | 377 | foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok |
@@ -377,7 +380,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D | |||
377 | 380 | ||
378 | roundVector v = unsafePerformIO $ do | 381 | roundVector v = unsafePerformIO $ do |
379 | r <- createVector (dim v) | 382 | r <- createVector (dim v) |
380 | c_round_vector # v # r #|"roundVector" | 383 | (v #! r) c_round_vector #|"roundVector" |
381 | return r | 384 | return r |
382 | 385 | ||
383 | foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double | 386 | foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double |
@@ -391,7 +394,7 @@ foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double | |||
391 | range :: Int -> Vector I | 394 | range :: Int -> Vector I |
392 | range n = unsafePerformIO $ do | 395 | range n = unsafePerformIO $ do |
393 | r <- createVector n | 396 | r <- createVector n |
394 | c_range_vector # r #|"range" | 397 | (r # id) c_range_vector #|"range" |
395 | return r | 398 | return r |
396 | 399 | ||
397 | foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok | 400 | foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok |
@@ -431,7 +434,7 @@ long2intV = tog c_long2int | |||
431 | 434 | ||
432 | tog f v = unsafePerformIO $ do | 435 | tog f v = unsafePerformIO $ do |
433 | r <- createVector (dim v) | 436 | r <- createVector (dim v) |
434 | f # v # r #|"tog" | 437 | (v #! r) f #|"tog" |
435 | return r | 438 | return r |
436 | 439 | ||
437 | foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok | 440 | foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok |
@@ -450,7 +453,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | |||
450 | 453 | ||
451 | stepg f v = unsafePerformIO $ do | 454 | stepg f v = unsafePerformIO $ do |
452 | r <- createVector (dim v) | 455 | r <- createVector (dim v) |
453 | f # v # r #|"step" | 456 | (v #! r) f #|"step" |
454 | return r | 457 | return r |
455 | 458 | ||
456 | stepD :: Vector Double -> Vector Double | 459 | stepD :: Vector Double -> Vector Double |
@@ -475,7 +478,7 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z | |||
475 | 478 | ||
476 | conjugateAux fun x = unsafePerformIO $ do | 479 | conjugateAux fun x = unsafePerformIO $ do |
477 | v <- createVector (dim x) | 480 | v <- createVector (dim x) |
478 | fun # x # v #|"conjugateAux" | 481 | (x #! v) fun #|"conjugateAux" |
479 | return v | 482 | return v |
480 | 483 | ||
481 | conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) | 484 | conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) |
@@ -493,7 +496,7 @@ cloneVector v = do | |||
493 | let n = dim v | 496 | let n = dim v |
494 | r <- createVector n | 497 | r <- createVector n |
495 | let f _ s _ d = copyArray d s n >> return 0 | 498 | let f _ s _ d = copyArray d s n >> return 0 |
496 | f # v # r #|"cloneVector" | 499 | (v #! r) f #|"cloneVector" |
497 | return r | 500 | return r |
498 | 501 | ||
499 | -------------------------------------------------------------------------------- | 502 | -------------------------------------------------------------------------------- |
@@ -501,7 +504,7 @@ cloneVector v = do | |||
501 | constantAux fun x n = unsafePerformIO $ do | 504 | constantAux fun x n = unsafePerformIO $ do |
502 | v <- createVector n | 505 | v <- createVector n |
503 | px <- newArray [x] | 506 | px <- newArray [x] |
504 | fun px # v #|"constantAux" | 507 | (v # id) (fun px) #|"constantAux" |
505 | free px | 508 | free px |
506 | return v | 509 | return v |
507 | 510 | ||
@@ -515,4 +518,3 @@ foreign import ccall unsafe "constantI" cconstantI :: TConst CInt | |||
515 | foreign import ccall unsafe "constantL" cconstantL :: TConst Z | 518 | foreign import ccall unsafe "constantL" cconstantL :: TConst Z |
516 | 519 | ||
517 | ---------------------------------------------------------------------- | 520 | ---------------------------------------------------------------------- |
518 | |||