From 45c2c85c3fbf3173df685bf3143af50e8569e40d Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Fri, 9 Aug 2019 19:27:40 -0400 Subject: Internal.Specialized --- packages/base/src/Internal/Specialized.hs | 89 +++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 packages/base/src/Internal/Specialized.hs diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs new file mode 100644 index 0000000..ff89024 --- /dev/null +++ b/packages/base/src/Internal/Specialized.hs @@ -0,0 +1,89 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +module Internal.Specialized where + +import Control.Monad +import Data.Complex +import Data.Int +import Data.Typeable +import Foreign.Marshal.Alloc(free,malloc) +import Foreign.Marshal.Array(newArray,copyArray) +import Foreign.Ptr +import Foreign.Storable +import Foreign.C.Types (CInt(..)) +import Data.Vector.Storable (Vector) +import System.IO.Unsafe + +import Internal.Vector (createVector) +import Internal.Devel + +type Element t = (Storable t, Typeable t) + +data Specialized a + = SpFloat !(a :~: Float) + | SpDouble !(a :~: Double) + | SpCFloat !(a :~: Complex Float) + | SpCDouble !(a :~: Complex Double) + | SpInt32 !(a :~: Int32) + | SpInt64 !(a :~: Int64) + +specialize :: Typeable a => a -> Maybe (Specialized a) +specialize x = foldr1 mplus + [ SpDouble <$> cst x + , SpInt64 <$> cst x + , SpFloat <$> cst x + , SpInt32 <$> cst x + , SpCDouble <$> cst x + , SpCFloat <$> cst x + -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x + ] + where + cst :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) + cst _ = eqT + +-- | Supported matrix elements. +constantD :: Typeable a => a -> Int -> Vector a +constantD x = case specialize x of + Nothing -> error "constantD" + Just (SpDouble Refl) -> constantAux cconstantR x + Just (SpInt64 Refl) -> constantAux cconstantL x + Just (SpFloat Refl) -> constantAux cconstantF x + Just (SpInt32 Refl) -> constantAux cconstantI x + Just (SpCDouble Refl) -> constantAux cconstantC x + Just (SpCFloat Refl) -> constantAux cconstantQ x + +{- +extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) +setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO () +sortI :: (Typeable a , Ord a ) => Vector a -> Vector CInt +sortV :: (Typeable a , Ord a ) => Vector a -> Vector a +compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt +selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a +remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a +rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () +gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () +reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation +-} + + +( extractR , setRect , sortI , sortV , compareV , selectV , remapM , rowOp , gemm , reorderV ) + = error "todo Element" + +constantAux :: (Storable a1, Storable a) + => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a +constantAux fun x n = unsafePerformIO $ do + v <- createVector n + px <- newArray [x] + (applyRaw v id) (fun px) #|"constantAux" + free px + return v + +type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt +foreign import ccall unsafe "constantF" cconstantF :: TConst Float +foreign import ccall unsafe "constantR" cconstantR :: TConst Double +foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) +foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) +foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 +foreign import ccall unsafe "constantL" cconstantL :: TConst Int64 -- cgit v1.2.3