diff options
author | idontgetoutmuch <dominic@steinitz.org> | 2018-04-01 10:25:10 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-01 10:25:10 -0700 |
commit | 7ff1249249cb1b238e3e5ffafdaa9180c86242b9 (patch) | |
tree | 903fa549ea33ae9eff9fb527375f87548948d41d /packages/base/src/Internal/Vectorized.hs | |
parent | 1bdd94324c6e7cd9d9361c4136998d4c14ea3857 (diff) | |
parent | 63ceed4a69563c08e54269a4a9d3e5d27cb7c489 (diff) |
Merge branch 'master' into master
Diffstat (limited to 'packages/base/src/Internal/Vectorized.hs')
-rw-r--r-- | packages/base/src/Internal/Vectorized.hs | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index a410bb2..c00c324 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs | |||
@@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO) | |||
28 | import Control.Monad(when) | 28 | import Control.Monad(when) |
29 | 29 | ||
30 | infixr 1 # | 30 | infixr 1 # |
31 | (#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r | ||
31 | a # b = applyRaw a b | 32 | a # b = applyRaw a b |
32 | {-# INLINE (#) #-} | 33 | {-# INLINE (#) #-} |
33 | 34 | ||
35 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r | ||
34 | a #! b = a # b # id | 36 | a #! b = a # b # id |
35 | {-# INLINE (#!) #-} | 37 | {-# INLINE (#!) #-} |
36 | 38 | ||
39 | fromei :: Enum a => a -> CInt | ||
37 | fromei x = fromIntegral (fromEnum x) :: CInt | 40 | fromei x = fromIntegral (fromEnum x) :: CInt |
38 | 41 | ||
39 | data FunCodeV = Sin | 42 | data FunCodeV = Sin |
@@ -100,10 +103,20 @@ sumQ = sumg c_sumQ | |||
100 | sumC :: Vector (Complex Double) -> Complex Double | 103 | sumC :: Vector (Complex Double) -> Complex Double |
101 | sumC = sumg c_sumC | 104 | sumC = sumg c_sumC |
102 | 105 | ||
106 | sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) | ||
107 | , TransArray c | ||
108 | , Storable a | ||
109 | ) | ||
110 | => I -> c -> a | ||
103 | sumI m = sumg (c_sumI m) | 111 | sumI m = sumg (c_sumI m) |
104 | 112 | ||
113 | sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) | ||
114 | , TransArray c | ||
115 | , Storable a | ||
116 | ) => Z -> c -> a | ||
105 | sumL m = sumg (c_sumL m) | 117 | sumL m = sumg (c_sumL m) |
106 | 118 | ||
119 | sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
107 | sumg f x = unsafePerformIO $ do | 120 | sumg f x = unsafePerformIO $ do |
108 | r <- createVector 1 | 121 | r <- createVector 1 |
109 | (x #! r) f #| "sum" | 122 | (x #! r) f #| "sum" |
@@ -140,6 +153,8 @@ prodI = prodg . c_prodI | |||
140 | prodL :: Z-> Vector Z -> Z | 153 | prodL :: Z-> Vector Z -> Z |
141 | prodL = prodg . c_prodL | 154 | prodL = prodg . c_prodL |
142 | 155 | ||
156 | prodg :: (TransArray c, Storable a) | ||
157 | => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a | ||
143 | prodg f x = unsafePerformIO $ do | 158 | prodg f x = unsafePerformIO $ do |
144 | r <- createVector 1 | 159 | r <- createVector 1 |
145 | (x #! r) f #| "prod" | 160 | (x #! r) f #| "prod" |
@@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z | |||
155 | 170 | ||
156 | ------------------------------------------------------------------ | 171 | ------------------------------------------------------------------ |
157 | 172 | ||
173 | toScalarAux :: (Enum a, TransArray c, Storable a1) | ||
174 | => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 | ||
158 | toScalarAux fun code v = unsafePerformIO $ do | 175 | toScalarAux fun code v = unsafePerformIO $ do |
159 | r <- createVector 1 | 176 | r <- createVector 1 |
160 | (v #! r) (fun (fromei code)) #|"toScalarAux" | 177 | (v #! r) (fun (fromei code)) #|"toScalarAux" |
161 | return (r @> 0) | 178 | return (r @> 0) |
162 | 179 | ||
180 | |||
181 | vectorMapAux :: (Enum a, Storable t, Storable a1) | ||
182 | => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
183 | -> a -> Vector t -> Vector a1 | ||
163 | vectorMapAux fun code v = unsafePerformIO $ do | 184 | vectorMapAux fun code v = unsafePerformIO $ do |
164 | r <- createVector (dim v) | 185 | r <- createVector (dim v) |
165 | (v #! r) (fun (fromei code)) #|"vectorMapAux" | 186 | (v #! r) (fun (fromei code)) #|"vectorMapAux" |
166 | return r | 187 | return r |
167 | 188 | ||
189 | vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) | ||
190 | => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) | ||
191 | -> a -> a2 -> Vector t -> Vector a1 | ||
168 | vectorMapValAux fun code val v = unsafePerformIO $ do | 192 | vectorMapValAux fun code val v = unsafePerformIO $ do |
169 | r <- createVector (dim v) | 193 | r <- createVector (dim v) |
170 | pval <- newArray [val] | 194 | pval <- newArray [val] |
@@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do | |||
172 | free pval | 196 | free pval |
173 | return r | 197 | return r |
174 | 198 | ||
199 | vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) | ||
200 | => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) | ||
201 | -> a -> Vector t -> c -> Vector a1 | ||
175 | vectorZipAux fun code u v = unsafePerformIO $ do | 202 | vectorZipAux fun code u v = unsafePerformIO $ do |
176 | r <- createVector (dim u) | 203 | r <- createVector (dim u) |
177 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" | 204 | (u # v #! r) (fun (fromei code)) #|"vectorZipAux" |
@@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D | |||
378 | 405 | ||
379 | -------------------------------------------------------------------------------- | 406 | -------------------------------------------------------------------------------- |
380 | 407 | ||
408 | roundVector :: Vector Double -> Vector Double | ||
381 | roundVector v = unsafePerformIO $ do | 409 | roundVector v = unsafePerformIO $ do |
382 | r <- createVector (dim v) | 410 | r <- createVector (dim v) |
383 | (v #! r) c_round_vector #|"roundVector" | 411 | (v #! r) c_round_vector #|"roundVector" |
@@ -432,6 +460,8 @@ long2intV :: Vector Z -> Vector I | |||
432 | long2intV = tog c_long2int | 460 | long2intV = tog c_long2int |
433 | 461 | ||
434 | 462 | ||
463 | tog :: (Storable t, Storable a) | ||
464 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
435 | tog f v = unsafePerformIO $ do | 465 | tog f v = unsafePerformIO $ do |
436 | r <- createVector (dim v) | 466 | r <- createVector (dim v) |
437 | (v #! r) f #|"tog" | 467 | (v #! r) f #|"tog" |
@@ -451,6 +481,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok | |||
451 | 481 | ||
452 | --------------------------------------------------------------- | 482 | --------------------------------------------------------------- |
453 | 483 | ||
484 | stepg :: (Storable t, Storable a) | ||
485 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
454 | stepg f v = unsafePerformIO $ do | 486 | stepg f v = unsafePerformIO $ do |
455 | r <- createVector (dim v) | 487 | r <- createVector (dim v) |
456 | (v #! r) f #|"step" | 488 | (v #! r) f #|"step" |
@@ -476,6 +508,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z | |||
476 | 508 | ||
477 | -------------------------------------------------------------------------------- | 509 | -------------------------------------------------------------------------------- |
478 | 510 | ||
511 | conjugateAux :: (Storable t, Storable a) | ||
512 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
479 | conjugateAux fun x = unsafePerformIO $ do | 513 | conjugateAux fun x = unsafePerformIO $ do |
480 | v <- createVector (dim x) | 514 | v <- createVector (dim x) |
481 | (x #! v) fun #|"conjugateAux" | 515 | (x #! v) fun #|"conjugateAux" |
@@ -501,6 +535,8 @@ cloneVector v = do | |||
501 | 535 | ||
502 | -------------------------------------------------------------------------------- | 536 | -------------------------------------------------------------------------------- |
503 | 537 | ||
538 | constantAux :: (Storable a1, Storable a) | ||
539 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a | ||
504 | constantAux fun x n = unsafePerformIO $ do | 540 | constantAux fun x n = unsafePerformIO $ do |
505 | v <- createVector n | 541 | v <- createVector n |
506 | px <- newArray [x] | 542 | px <- newArray [x] |