{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Internal.Specialized where

import Control.Monad
import Data.Complex
import Data.Int
import Data.Typeable
import Data.Vector.Storable (Vector)
import Foreign.C.Types (CInt(..))
import Foreign.Marshal.Alloc (free, malloc)
import Foreign.Marshal.Array (copyArray, newArray)
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe

import Internal.Devel
import Internal.Vector (createVector)

-- | Constraint synonym for an element type that is both 'Storable' and
-- 'Typeable'.
type Element t = (Storable t, Typeable t)

-- | The supported matrix elements.
--
-- A value of @Specialized a@ is runtime evidence that @a@ is one of six
-- concrete types. Each constructor carries the corresponding type-equality
-- proof, so pattern matching on 'Refl' refines @a@ to that concrete type.
data Specialized a
  = SpFloat   !(a :~: Float)
  | SpDouble  !(a :~: Double)
  | SpCFloat  !(a :~: Complex Float)
  | SpCDouble !(a :~: Complex Double)
  | SpInt32   !(a :~: Int32)
  | SpInt64   !(a :~: Int64)

-- | Identify which supported element type the argument's type is.
--
-- Only the argument's type matters; the value itself is never inspected.
-- Candidates are tried in the order Double, Int64, Float, Int32,
-- Complex Double, Complex Float. The first match wins. The result is
-- 'Nothing' for any other type.
specialize :: Typeable a => a -> Maybe (Specialized a)
specialize probe = msum candidates
  where
    candidates =
      [ SpDouble  <$> witness probe
      , SpInt64   <$> witness probe
      , SpFloat   <$> witness probe
      , SpInt32   <$> witness probe
      , SpCDouble <$> witness probe
      , SpCFloat  <$> witness probe
      -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x
      ]

    -- Ask 'eqT' whether the argument's type equals the target type demanded
    -- by the surrounding constructor; the dummy argument only fixes @s@.
    witness :: (Typeable s, Typeable t) => s -> Maybe (s :~: t)
    witness _ = eqT
-- | Create a vector with 'createVector' @n@ and fill it using the C
-- @constant*@ routine that matches the runtime type of @x@. Presumably every
-- entry becomes @x@; the C side is not visible here.
--
-- The element type is resolved with 'specialize'. Matching 'Refl' fixes the
-- concrete type, so the correctly typed foreign routine can be passed to
-- 'constantAux'.
--
-- Calls 'error', naming the offending type, when @x@ is not 'Float',
-- 'Double', @'Complex' 'Float'@, @'Complex' 'Double'@, 'Int32' or 'Int64'.
constantD :: Typeable a => a -> Int -> Vector a
constantD x = case specialize x of
    Nothing               -> error ("constantD: unsupported element type " ++ show (typeOf x))
    Just (SpDouble  Refl) -> constantAux cconstantR x
    Just (SpInt64   Refl) -> constantAux cconstantL x
    Just (SpFloat   Refl) -> constantAux cconstantF x
    Just (SpInt32   Refl) -> constantAux cconstantI x
    Just (SpCDouble Refl) -> constantAux cconstantC x
    Just (SpCFloat  Refl) -> constantAux cconstantQ x

-- Intended signatures for operations not yet ported to the 'Typeable'
-- dispatch scheme. They stay commented out because 'Matrix' and
-- 'MatrixOrder' are not imported by this module.
{-
extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
setRect  :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO ()
sortI    :: (Typeable a , Ord a ) => Vector a -> Vector CInt
sortV    :: (Typeable a , Ord a ) => Vector a -> Vector a
compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt
selectV  :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
remapM   :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
rowOp    :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
gemm     :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
-}

-- Unimplemented placeholders. Forcing any of them raises an 'error' that
-- names the specific function. The previous shared tuple binding reported
-- the same "todo Element" message for all ten names.
extractR = error "todo Element: extractR"
setRect  = error "todo Element: setRect"
sortI    = error "todo Element: sortI"
sortV    = error "todo Element: sortV"
compareV = error "todo Element: compareV"
selectV  = error "todo Element: selectV"
remapM   = error "todo Element: remapM"
rowOp    = error "todo Element: rowOp"
gemm     = error "todo Element: gemm"
reorderV = error "todo Element: reorderV"

-- | Run a foreign fill routine on a freshly created vector.
--
-- The steps are:
--
-- 1. Allocate the result with 'createVector' @n@.
-- 2. Copy @x@ into a temporary one-element buffer from 'newArray'.
-- 3. Hand @fun@ (already applied to that buffer) to 'applyRaw' on the vector,
--    and pass the returned action through '#|' labelled @"constantAux"@.
-- 4. Release the temporary buffer with 'free'.
--
-- 'unsafePerformIO' assumes the foreign routine's only effect is writing
-- into the freshly created vector, so the result depends only on @fun@, @x@
-- and @n@. TODO confirm against the C sources.
--
-- NOTE(review): if the checked call throws (presumably '#|' raises on a
-- nonzero status; confirm in "Internal.Devel"), @px@ is never freed.
-- Making this exception-safe needs 'Control.Exception.bracket' or
-- 'Foreign.Marshal.Utils.with', and this module does not import either yet.
constantAux :: (Storable a1, Storable a) => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
constantAux fun x n = unsafePerformIO $ do
    v  <- createVector n
    px <- newArray [x]
    applyRaw v id (fun px) #| "constantAux"
    free px
    return v

-- | Shape of the foreign @constant*@ routines.
--
-- The arguments are a pointer to the fill value, a 'CInt' (presumably the
-- output length; confirm in C), and a pointer to the output storage.
-- 'constantAux' passes the 'CInt' result through '#|'.
type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt

-- Type suffixes: F = Float, R = Double, Q = Complex Float,
-- C = Complex Double, I = Int32, L = Int64.
foreign import ccall unsafe "constantF" cconstantF :: TConst Float
foreign import ccall unsafe "constantR" cconstantR :: TConst Double
foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
foreign import ccall unsafe "constantL" cconstantL :: TConst Int64