From 1a68793247b8845cefad4d157e4f4d25b1731b42 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 30 Mar 2018 12:48:20 +0100 Subject: Implement CI --- packages/base/src/Internal/Vectorized.hs | 36 ++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'packages/base/src/Internal/Vectorized.hs') diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index a410bb2..c00c324 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs @@ -28,12 +28,15 @@ import System.IO.Unsafe(unsafePerformIO) import Control.Monad(when) infixr 1 # +(#) :: TransArray c => c -> (b -> IO r) -> TransRaw c b -> IO r a # b = applyRaw a b {-# INLINE (#) #-} +(#!) :: (TransArray c, TransArray c1) => c1 -> c -> TransRaw c1 (TransRaw c (IO r)) -> IO r a #! b = a # b # id {-# INLINE (#!) #-} +fromei :: Enum a => a -> CInt fromei x = fromIntegral (fromEnum x) :: CInt data FunCodeV = Sin @@ -100,10 +103,20 @@ sumQ = sumg c_sumQ sumC :: Vector (Complex Double) -> Complex Double sumC = sumg c_sumC +sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) + , TransArray c + , Storable a + ) + => I -> c -> a sumI m = sumg (c_sumI m) +sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) + , TransArray c + , Storable a + ) => Z -> c -> a sumL m = sumg (c_sumL m) +sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a sumg f x = unsafePerformIO $ do r <- createVector 1 (x #! r) f #| "sum" @@ -140,6 +153,8 @@ prodI = prodg . c_prodI prodL :: Z-> Vector Z -> Z prodL = prodg . c_prodL +prodg :: (TransArray c, Storable a) + => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a prodg f x = unsafePerformIO $ do r <- createVector 1 (x #! r) f #| "prod" @@ -155,16 +170,25 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z ------------------------------------------------------------------ +toScalarAux :: (Enum a, TransArray c, Storable a1) + => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 toScalarAux fun code v = unsafePerformIO $ do r <- createVector 1 (v #! r) (fun (fromei code)) #|"toScalarAux" return (r @> 0) + +vectorMapAux :: (Enum a, Storable t, Storable a1) + => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) + -> a -> Vector t -> Vector a1 vectorMapAux fun code v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) (fun (fromei code)) #|"vectorMapAux" return r +vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) + => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) + -> a -> a2 -> Vector t -> Vector a1 vectorMapValAux fun code val v = unsafePerformIO $ do r <- createVector (dim v) pval <- newArray [val] @@ -172,6 +196,9 @@ vectorMapValAux fun code val v = unsafePerformIO $ do free pval return r +vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) + => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) + -> a -> Vector t -> c -> Vector a1 vectorZipAux fun code u v = unsafePerformIO $ do r <- createVector (dim u) (u # v #! r) (fun (fromei code)) #|"vectorZipAux" @@ -378,6 +405,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D -------------------------------------------------------------------------------- +roundVector :: Vector Double -> Vector Double roundVector v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) c_round_vector #|"roundVector" @@ -432,6 +460,8 @@ long2intV :: Vector Z -> Vector I long2intV = tog c_long2int +tog :: (Storable t, Storable a) + => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a tog f v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) f #|"tog" @@ -451,6 +481,8 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok --------------------------------------------------------------- +stepg :: (Storable t, Storable a) + => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a stepg f v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) f #|"step" @@ -476,6 +508,8 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z -------------------------------------------------------------------------------- +conjugateAux :: (Storable t, Storable a) + => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a conjugateAux fun x = unsafePerformIO $ do v <- createVector (dim x) (x #! v) fun #|"conjugateAux" @@ -501,6 +535,8 @@ cloneVector v = do -------------------------------------------------------------------------------- +constantAux :: (Storable a1, Storable a) + => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a constantAux fun x n = unsafePerformIO $ do v <- createVector n px <- newArray [x] -- cgit v1.2.3