{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}

-- | Element-type dispatch for low-level matrix and vector primitives.
--
-- The dispatching operations in this module inspect the element type at
-- run time (via "Type.Reflection") and forward to a foreign C routine
-- written for that type. Integer element types, including the modular
-- wrapper 'Mod', reach the C routines through a primitive integer
-- representation.
module Internal.Specialized
    ( Mod(..)
    , f2i
    , i2f
    , f2iM
    , MatrixOrder(..)
    , Matrix(..)
    , createMatrix
    , matrixFromVector
    , cols
    , rows
    , size
    , shSize
    , shDim
    , constantD
    , extractR
    , setRect
    , sortI
    , sortV
    , compareV
    , selectV
    , remapM
    , rowOp
    , gemm
    , reorderV
    , specialize
    ) where

import Control.Monad
import Control.DeepSeq ( NFData(..) )
import Data.Coerce
import Data.Complex
import Data.Functor
import Data.Int
import Data.Maybe
import Data.Typeable (eqT,Proxy)
import Type.Reflection
import Foreign.Marshal.Alloc(free,malloc)
import Foreign.Marshal.Array(newArray,copyArray)
import Foreign.ForeignPtr(castForeignPtr)
import Foreign.Ptr
import Foreign.Storable
import Foreign.C.Types (CInt(..))
import System.IO.Unsafe
#if MIN_VERSION_base(4,11,0)
import GHC.TypeLits hiding (Mod)
#else
import GHC.TypeLits
#endif
import Internal.Vector -- (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr,(@>))
import Internal.Devel

-- | Test whether the type behind a proxy equals another 'Typeable' type,
-- returning an equality witness when it does. Only the type of the
-- argument matters; its value is never inspected.
eqp :: (Typeable a, Typeable b) => proxy a -> Maybe (a :~: b)
eqp = const eqT

-- | 'eqp' specialised to 'Int32'.
ep32 :: (Typeable a) => proxy a -> Maybe (a :~: Int32)
ep32 = eqp

-- | 'eqp' specialised to 'Int64'.
ep64 :: (Typeable a) => proxy a -> Maybe (a :~: Int64)
ep64 = eqp

-- | 'eqp' specialised to 'CInt'.
epint :: (Typeable a) => proxy a -> Maybe (a :~: CInt)
epint = eqp

-- | Constraint synonym bundling 'Storable' and 'Typeable'.
type Element t = (Storable t, Typeable t)

-- | Wrapper with a phantom type-level natural @n@ (the modulus) for statically checked modular arithmetic.
newtype Mod (n :: Nat) t = Mod { unMod :: t }
    deriving (Storable)

-- Forcing a modular value forces its representative.
instance (NFData t) => NFData (Mod n t) where
    rnf = rnf . unMod

-- | Reinterpret a matrix of raw integers as a matrix of modular values.
-- Only the element type changes; the buffer is shared via 'i2f'.
i2fM :: Storable t => Matrix t -> Matrix (Mod n t)
i2fM mat = mat { xdat = i2f (xdat mat) }

-- | Reinterpret a vector of raw integers as a vector of modular values.
-- The result reuses the input's foreign pointer, offset and length.
i2f :: Storable t => Vector t -> Vector (Mod n t)
i2f vec =
    let (fptr, off, len) = unsafeToForeignPtr vec
    in  unsafeFromForeignPtr (castForeignPtr fptr) off len

-- | Inverse of 'i2f': view modular values as their raw integer
-- representatives, sharing the same buffer.
f2i :: Storable t => Vector (Mod n t) -> Vector t
f2i vec =
    let (fptr, off, len) = unsafeToForeignPtr vec
    in  unsafeFromForeignPtr (castForeignPtr fptr) off len

-- | Matrix version of 'f2i'.
f2iM :: Storable t => Matrix (Mod n t) -> Matrix t
f2iM mat = mat { xdat = f2i (xdat mat) }

-- | Dictionary describing how an element type @a@ maps onto a primitive
-- integer type @t@ that the C routines operate on.
data IntegralRep t a = IntegralRep
    { i2rep   :: Vector t -> Vector a  -- ^ view a primitive vector as @a@
    , i2repM  :: Matrix t -> Matrix a  -- ^ matrix version of 'i2rep'
    , rep2i   :: Vector a -> Vector t  -- ^ view an @a@ vector as primitive
    , rep2iM  :: Matrix a -> Matrix t  -- ^ matrix version of 'rep2i'
    , rep2one :: a -> t                -- ^ convert a single element
    , modulo  :: Maybe t               -- ^ modulus, present only for 'Mod'
    }

-- The fields are functions and cannot be shown; print a fixed placeholder
-- so that 'Specialized' can derive 'Show'.
instance Show (IntegralRep t a) where
    show = const "IntegralRep"

-- | Representation for a type that already is the primitive integer type.
idint :: Storable t => IntegralRep t t
idint = IntegralRep
    { i2rep   = id
    , i2repM  = id
    , rep2i   = id
    , rep2iM  = id
    , rep2one = id
    , modulo  = Nothing
    }

-- | Representation for a type that is 'Coercible' to the primitive type
-- (e.g. a newtype around it); every conversion is a zero-cost 'coerce'.
coerceint :: Coercible t a => IntegralRep t a
coerceint = IntegralRep
    { i2rep   = coerce
    , i2repM  = coerce
    , rep2i   = coerce
    , rep2iM  = coerce
    , rep2one = coerce
    , modulo  = Nothing
    }

-- | Representation for @'Mod' n t@, carrying the modulus @n@ recovered
-- from its type representation.
modint :: forall t n. (Read t, Storable t) => TypeRep n -> IntegralRep t (Mod n t)
modint natRep = IntegralRep
    { i2rep   = i2f
    , i2repM  = i2fM
    , rep2i   = f2i
    , rep2iM  = f2iM
    , rep2one = unMod
    , modulo  = Just modulus
    }
  where
    -- XXX: Hack to get nat value from Type.Reflection: this presumably relies
    -- on the shown TypeRep of a type-level literal being the bare number.
    -- n = fromIntegral . natVal $ (undefined :: Proxy n)
    modulus = read (show natRep)

-- | Re-target a representation to a coercible primitive type.
coercerep :: Coercible s t => IntegralRep s a -> IntegralRep t a
coercerep = coerce

-- | The 'TypeRep' of whatever type a proxy carries; the argument is unused.
typeRepOf :: Typeable a => proxy a -> TypeRep a
typeRepOf _ = typeRep

-- | Witness of which supported element representation a type @a@ has.
data Specialized a
    = SpFloat !(a :~: Float)
    | SpDouble !(a :~: Double)
    | SpCFloat !(a :~: Complex Float)
    | SpCDouble !(a :~: Complex Double)
    | SpInt32 !(IntegralRep Int32 a)
    | SpInt64 !(IntegralRep Int64 a)
    deriving Show

-- | Classify the element type of @x@. Only the type is inspected, never the
-- value. Recognised: 'Double', 'Float', their 'Complex' versions, 'Int32',
-- 'Int64', 'CInt', and @'Mod' n@ over 'Int32', 'CInt' or 'Int64'. Any other
-- type yields 'Nothing'.
specialize :: forall m a. Typeable a => m a -> Maybe (Specialized a)
specialize x = msum
    [ SpDouble <$> eqp x
    , ep64 x <&> \Refl -> SpInt64 idint
    , SpFloat <$> eqp x
    , ep32 x <&> \Refl -> SpInt32 idint
    , SpCDouble <$> eqp x
    , SpCFloat <$> eqp x
    , epint x <&> \Refl -> SpInt32 coerceint
    , modular (typeRepOf x)
    ]
  where
    -- Recognise @Mod n t@ by splitting the type application and comparing
    -- the head with 'Mod' and the argument with each integer type.
    modular :: TypeRep a -> Maybe (Specialized a)
    modular (App (App modtyp n) inttyp) = do
        HRefl <- eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp
        msum
            [ eqTypeRep (typeRep :: TypeRep Int32) inttyp <&> \HRefl -> SpInt32 (modint n)
            , eqTypeRep (typeRep :: TypeRep CInt) inttyp <&> \HRefl -> SpInt32 (coercerep (modint n))
            , eqTypeRep (typeRep :: TypeRep Int64) inttyp <&> \HRefl -> SpInt64 (modint n)
            ]
    modular _ = Nothing

-- | @constantD x n@ builds a length-@n@ vector with the element-specific C
-- @constant*@ routine, passing @x@ as the value (presumably every entry is
-- set to @x@). Integer-like types go through their primitive
-- representation. Calls 'error' for element types 'specialize' rejects.
constantD :: forall a. Typeable a => a -> Int -> Vector a
constantD x = maybe (error "constantD") fill (specialize (const x))
  where
    fill :: Specialized a -> Int -> Vector a
    fill (SpDouble Refl)  = constantAux cconstantR x
    fill (SpInt64 r)      = i2rep r . constantAux cconstantL (rep2one r x)
    fill (SpFloat Refl)   = constantAux cconstantF x
    fill (SpInt32 r)      = i2rep r . constantAux cconstantI (rep2one r x)
    fill (SpCDouble Refl) = constantAux cconstantC x
    fill (SpCFloat Refl)  = constantAux cconstantQ x

-- | Storage order used when laying out a matrix over a flat buffer.
data MatrixOrder
    = RowMajor
    | ColumnMajor
    deriving (Show, Eq)

-- | Matrix representation suitable for BLAS\/LAPACK computations.
data Matrix t = Matrix
    { irows :: {-# UNPACK #-} !Int        -- ^ number of rows
    , icols :: {-# UNPACK #-} !Int        -- ^ number of columns
    , xRow  :: {-# UNPACK #-} !Int        -- ^ element stride between consecutive rows
    , xCol  :: {-# UNPACK #-} !Int        -- ^ element stride between consecutive columns
    , xdat  :: {-# UNPACK #-} !(Vector t) -- ^ flat element buffer
    }

-- | Allocate a new @r x c@ matrix with the given storage order, backed by a
-- fresh buffer of @r * c@ elements obtained from 'createVector'.
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix ord r c = do
    p <- createVector (r*c)
    return (matrixFromVector ord r c p)

-- | View a flat vector as an @r x c@ matrix with the given storage order.
--
-- A single-row (@r == 1@) or single-column (@c == 1@) request takes its
-- length from the vector and ignores the other dimension argument.
-- Otherwise @r * c@ must equal the vector length, or 'error' is raised.
matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
matrixFromVector o r c v
    | r * c == dim v = m
    | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
  where
    m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
      | otherwise     = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }

-- | Render a matrix's dimensions as @(RxC)@.
shSize :: Matrix t -> [Char]
shSize = shDim . size

-- | Render a pair of dimensions as @(RxC)@.
shDim :: (Show a, Show a1) => (a1, a) -> [Char]
shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"

-- | Dimensions as @(rows, columns)@.
size :: Matrix t -> (Int, Int)
size m = (irows m, icols m)
{-# INLINE size #-}

-- The dispatchers below all follow one pattern. They classify the element
-- type with 'specialize' and call the matching C routine. Integer-like types
-- (Int32, CInt, Int64 and their Mod wrappers) are converted to and from
-- their primitive representation through the 'IntegralRep' dictionary.
-- An unsupported element type results in 'error' named after the function.

-- | Extract a sub-matrix of @m@ into a new matrix with storage order @ord@.
-- Each axis is selected by a mode flag and an index vector. When the flag
-- is 0, the result extent along that axis is v[1] - v[0] + 1 (presumably an
-- inclusive range); otherwise it is the length of the index vector.
extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
extractR ord m = fromMaybe (\_ _ _ _ -> error "extractR") $ specialize m <&> \case
    SpDouble Refl -> extractAux c_extractD ord m
    SpInt64 r -> \mi is mj js -> i2repM r <$> extractAux c_extractL ord (rep2iM r m) mi is mj js
    SpFloat Refl -> extractAux c_extractF ord m
    SpInt32 r -> \mi is mj js -> i2repM r <$> extractAux (coerce c_extractI) ord (rep2iM r m) mi is mj js
    SpCDouble Refl -> extractAux c_extractC ord m
    SpCFloat Refl -> extractAux c_extractQ ord m

-- | Run the C setRect routine with offsets @i@, @j@ on @m@ and @x@ (both
-- passed to C; presumably copies @x@ into @m@ at that offset, in place).
setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO ()
setRect i j m x = fromMaybe (error "setRect") $ specialize m <&> \case
    SpDouble Refl -> setRectAux c_setRectD i j m x
    SpInt64 r -> setRectAux c_setRectL i j (rep2iM r m) (rep2iM r x)
    SpFloat Refl -> setRectAux c_setRectF i j m x
    SpInt32 r -> setRectAux (coerce c_setRectI) i j (rep2iM r m) (rep2iM r x)
    SpCDouble Refl -> setRectAux c_setRectC i j m x
    SpCFloat Refl -> setRectAux c_setRectQ i j m x

-- | Index vector produced by the C sort_index routine (presumably the
-- permutation that sorts @v@).
sortI :: (Typeable a , Ord a) => Vector a -> Vector CInt
sortI v = maybe (error "sortI") ($ v) $ specialize v <&> \case
    SpDouble Refl -> sortIdxD
    SpInt64 r -> sortIdxL . rep2i r
    SpFloat Refl -> sortIdxF
    SpInt32 r -> coerce sortIdxI . rep2i r
    -- Unreachable in practice: Complex has no Ord instance.
    SpCDouble Refl -> error "sortI: Complex elements have no Ord instance"
    SpCFloat Refl -> error "sortI: Complex elements have no Ord instance"

-- | Values of @v@ as returned by the C sort_values routine (presumably
-- sorted).
sortV :: (Typeable a , Ord a ) => Vector a -> Vector a
sortV v = maybe (error "sortV") ($ v) $ specialize v <&> \case
    SpDouble Refl -> sortValD
    SpInt64 r -> i2rep r . sortValL . rep2i r
    SpFloat Refl -> sortValF
    SpInt32 r -> i2rep r . coerce sortValI . rep2i r
    -- Unreachable in practice: Complex has no Ord instance.
    SpCDouble Refl -> error "sortV: Complex elements have no Ord instance"
    SpCFloat Refl -> error "sortV: Complex elements have no Ord instance"

-- | Element-wise comparison codes from the C compare routine. The result
-- has the length of @v@.
compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt
compareV u v = fromMaybe (error "compareV") $ specialize u <&> \case
    SpDouble Refl -> compareD u v
    SpInt64 r -> compareL (rep2i r u) (rep2i r v)
    SpFloat Refl -> compareF u v
    SpInt32 r -> coerce compareI (rep2i r u) (rep2i r v)
    -- Unreachable in practice: Complex has no Ord instance.
    SpCDouble Refl -> error "compareV: Complex elements have no Ord instance"
    SpCFloat Refl -> error "compareV: Complex elements have no Ord instance"

-- | Element-wise choice among @l@, @e@ and @g@, driven by the code vector
-- @c@ (C choose routine). The result has the length of @e@.
selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
selectV c l e g = fromMaybe (error "selectV") $ specialize l <&> \case
    SpDouble Refl -> selectD c l e g
    SpInt64 r -> i2rep r (selectL c (rep2i r l) (rep2i r e) (rep2i r g))
    SpFloat Refl -> selectF c l e g
    SpInt32 r -> i2rep r (coerce selectI c (rep2i r l) (rep2i r e) (rep2i r g))
    SpCDouble Refl -> selectC c l e g
    SpCFloat Refl -> selectQ c l e g

-- | Build a row-major matrix shaped like @i@ from @m@, using the index
-- matrices @i@ and @j@ (C remap routine).
remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
remapM i j m = fromMaybe (error "remapM") $ specialize m <&> \case
    SpDouble Refl -> remapD i j m
    SpInt64 r -> i2repM r (remapL i j (rep2iM r m))
    SpFloat Refl -> remapF i j m
    SpInt32 r -> i2repM r (coerce remapI i j (rep2iM r m))
    SpCDouble Refl -> remapC i j m
    SpCFloat Refl -> remapQ i j m

-- | Run the C row operation with code @c@ and scalar @a@ over
-- i1, i2, j1, j2 of @x@ (presumably row and column ranges, updated in place).
-- Modular integer types use the rowop_mod kernels with their modulus.
rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
rowOp c a i1 i2 j1 j2 x = fromMaybe (error "rowOp") $ specialize x <&> \case
    SpDouble Refl -> rowOpAux c_rowOpD c a i1 i2 j1 j2 x
    SpInt64 r -> case modulo r of
        Just m' -> rowOpAux (c_rowOpML m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
        Nothing -> rowOpAux c_rowOpL c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
    SpFloat Refl -> rowOpAux c_rowOpF c a i1 i2 j1 j2 x
    SpInt32 r -> case modulo r of
        Just m' -> rowOpAux (coerce c_rowOpMI m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
        Nothing -> rowOpAux (coerce c_rowOpI) c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
    SpCDouble Refl -> rowOpAux c_rowOpC c a i1 i2 j1 j2 x
    SpCFloat Refl -> rowOpAux c_rowOpQ c a i1 i2 j1 j2 x

-- | Call the C gemm routine with @u@, @a@, @b@ and finally @c@. The action
-- returns unit, so @c@ is presumably the in-place output. The contents of
-- @u@ (scalar coefficients?) are defined by the C side; TODO confirm.
-- Modular integer types use the gemm_mod kernels with their modulus.
gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
gemm u a b c = fromMaybe (error "gemm") $ specialize u <&> \case
    SpDouble Refl -> gemmg c_gemmD u a b c
    SpInt64 r -> case modulo r of
        Just m' -> gemmg (c_gemmML m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
        Nothing -> gemmg c_gemmL (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
    SpFloat Refl -> gemmg c_gemmF u a b c
    SpInt32 r -> case modulo r of
        Just m' -> gemmg (coerce c_gemmMI m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
        Nothing -> gemmg (coerce c_gemmI) (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
    SpCDouble Refl -> gemmg c_gemmC u a b c
    SpCFloat Refl -> gemmg c_gemmQ u a b c

-- | Reorder the elements of @v@ according to @strides@ and @dims@.
-- See reorderVector for documentation.
reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a
reorderV strides dims v = fromMaybe (error "reorderV") $ specialize v <&> \case
    SpDouble Refl -> reorderAux c_reorderD strides dims v
    SpInt64 r -> i2rep r $ reorderAux c_reorderL strides dims (rep2i r v)
    SpFloat Refl -> reorderAux c_reorderF strides dims v
    SpInt32 r -> i2rep r $ reorderAux (coerce c_reorderI) strides dims (rep2i r v)
    SpCDouble Refl -> reorderAux c_reorderC strides dims v
    SpCFloat Refl -> reorderAux c_reorderQ strides dims v

-- A matrix is passed to C either as (rows, cols, ptr) in the raw form, or
-- as (rows, cols, rowStride, colStride, ptr).
instance Storable t => TransArray (Matrix t) where
    type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
    type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
    apply = amat
    {-# INLINE apply #-}
    applyRaw = amatr
    {-# INLINE applyRaw #-}

-- C-Haskell matrix adapters

-- | Pass a matrix to a C function as rows, columns and data pointer.
amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
amatr x f g = unsafeWith (xdat x) (f . g r c)
  where
    r = fi (rows x)
    c = fi (cols x)
{-# INLINE amatr #-}

-- | Pass a matrix to a C function as rows, columns, row stride, column
-- stride and data pointer.
amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
  where
    r = fi (rows x)
    c = fi (cols x)
    sr = fi (xRow x)
    sc = fi (xCol x)
{-# INLINE amat #-}

-- | Number of rows.
rows :: Matrix t -> Int
rows = irows
{-# INLINE rows #-}

-- | Number of columns.
cols :: Matrix t -> Int
cols = icols
{-# INLINE cols #-}

infixr 1 #

-- Supply an array's C-level arguments to a foreign function, continuing
-- with the given continuation (operator form of apply).
(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
a # b = apply a b
{-# INLINE (#) #-}

-- Like the operator above, but terminates the chain by passing c1 and then
-- c with id as the final continuation.
(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r
a #! b = a # b # id
{-# INLINE (#!) #-}

-- | Allocate a length-@n@ vector and let the C routine @fun@ fill it, given a
-- pointer to @x@ and the vector in raw form.
constantAux :: (Storable a1, Storable a) => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
constantAux fun x n = unsafePerformIO $ do
    v <- createVector n
    px <- newArray [x]
    -- NOTE(review): if the C call reports failure and the status check
    -- throws, px is never freed. An exception-safe form would use
    -- Foreign.Marshal.Utils.with, which this module does not import yet.
    (applyRaw v id) (fun px) #|"constantAux"
    free px
    return v

type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt
foreign import ccall unsafe "constantF" cconstantF :: TConst Float
foreign import ccall unsafe "constantR" cconstantR :: TConst Double
foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
foreign import ccall unsafe "constantL" cconstantL :: TConst Int64

-- | Shared driver for 'extractR'. It sizes the result from the mode flags and
-- index vectors, allocates it with order @ord@, and runs the C extract
-- routine over (rows selector, columns selector, source, result).
extractAux
    :: ( Eq t3, Eq t2, TransArray c, Storable a, Storable t1, Storable t
       , Num t3, Num t2, Integral t1, Integral t )
    => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
    -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
extractAux f ord m moder vr modec vc = do
    -- Mode 0: extent is last - first + 1 taken from the first two entries;
    -- any other mode: extent is the number of indices supplied.
    let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
        nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
    r <- createMatrix ord nr nc
    (vr # vc # m #! r) (f moder modec) #|"extract"
    return r

type Extr x = CInt -> CInt -> CInt -> Ptr CInt ->    -- CIdxs
              CInt -> Ptr CInt ->                    -- CIdxs
              CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x
              CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x
              IO CInt

foreign import ccall unsafe "extractD" c_extractD :: Extr Double
foreign import ccall unsafe "extractF" c_extractF :: Extr Float
foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
foreign import ccall unsafe "extractL" c_extractL :: Extr Int64

-- | Shared driver for 'setRect': pass the offsets, then @m@ and @r@, to @f@.
setRectAux :: (TransArray c1, TransArray c) => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) -> Int -> Int -> c1 -> c -> IO ()
setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"

type SetRect x = I -> I -> x ::> x::> Ok

foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z

-- | Shared driver for the sort routines: allocate a result as long as the
-- input and let @f@ fill it.
sortG :: (Storable t, Storable a) => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
sortG f v = unsafePerformIO $ do
    r <- createVector (dim v)
    (v #! r) f #|"sortG"
    return r

sortIdxD :: Vector Double -> Vector CInt
sortIdxD = sortG c_sort_indexD

sortIdxF :: Vector Float -> Vector CInt
sortIdxF = sortG c_sort_indexF

sortIdxI :: Vector CInt -> Vector CInt
sortIdxI = sortG c_sort_indexI

sortIdxL :: Vector Z -> Vector I
sortIdxL = sortG c_sort_indexL

foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok

sortValD :: Vector Double -> Vector Double
sortValD = sortG c_sort_valD

sortValF :: Vector Float -> Vector Float
sortValF = sortG c_sort_valF

sortValI :: Vector CInt -> Vector CInt
sortValI = sortG c_sort_valI

sortValL :: Vector Z -> Vector Z
sortValL = sortG c_sort_valL

foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok

-- | Shared driver for the compare routines: the result is as long as @v@.
compareG :: (TransArray c, Storable t, Storable a) => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> c -> Vector t -> Vector a
compareG f u v = unsafePerformIO $ do
    r <- createVector (dim v)
    (u # v #! r) f #|"compareG"
    return r

compareD :: Vector Double -> Vector Double -> Vector CInt
compareD = compareG c_compareD

compareF :: Vector Float -> Vector Float -> Vector CInt
compareF = compareG c_compareF

compareI :: Vector CInt -> Vector CInt -> Vector CInt
compareI = compareG c_compareI

compareL :: Vector Z -> Vector Z -> Vector CInt
compareL = compareG c_compareL

foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok

-- | Shared driver for the choose routines: the result is as long as the
-- third argument @v@.
selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) -> c2 -> c1 -> Vector t -> c -> Vector a
selectG f c u v w = unsafePerformIO $ do
    r <- createVector (dim v)
    (c # u # v # w #! r) f #|"selectG"
    return r

selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
selectD = selectG c_selectD

selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
selectF = selectG c_selectF

selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
selectI = selectG c_selectI

selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
selectL = selectG c_selectL

selectC :: Vector CInt -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double)
selectC = selectG c_selectC

selectQ :: Vector CInt -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float)
selectQ = selectG c_selectQ

type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))

foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
foreign import ccall unsafe "chooseL" c_selectL :: Sel Z

-- | Shared driver for the remap routines: allocate a row-major result with
-- the shape of @i@ and let @f@ fill it from @i@, @j@ and @m@.
remapG :: (TransArray c, TransArray c1, Storable t, Storable a) => (CInt -> CInt -> CInt -> CInt -> Ptr t -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) -> Matrix t -> c1 -> c -> Matrix a
remapG f i j m = unsafePerformIO $ do
    r <- createMatrix RowMajor (rows i) (cols i)
    (i # j # m #! r) f #|"remapG"
    return r

remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
remapD = remapG c_remapD

remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
remapF = remapG c_remapF

remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
remapI = remapG c_remapI

remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
remapL = remapG c_remapL

remapC :: Matrix CInt -> Matrix CInt -> Matrix (Complex Double) -> Matrix (Complex Double)
remapC = remapG c_remapC

remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
remapQ = remapG c_remapQ

type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))

foreign import ccall unsafe "remapD" c_remapD :: Rem Double
foreign import ccall unsafe "remapF" c_remapF :: Rem Float
foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
foreign import ccall unsafe "remapL" c_remapL :: Rem Z

-- | Shared driver for 'rowOp'. It passes the operation code, a pointer to
-- the scalar @x@ and the four range bounds, then the matrix @m@, to @f@.
rowOpAux :: (TransArray c, Storable a) => (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
rowOpAux f c x i1 i2 j1 j2 m = do
    px <- newArray [x]
    -- NOTE(review): px leaks if the status check below throws; see the note
    -- in constantAux.
    (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
    free px

type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok

foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z

-- | Shared driver for 'gemm': pass @v@, @m1@, @m2@ and lastly @m3@ to @f@.
gemmg :: Storable x => Tgemm x -> Vector x -> Matrix x -> Matrix x -> Matrix x -> IO ()
gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"

type Tgemm x = x :> x ::> x ::> x ::> Ok

foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z

-- | Shared driver for 'reorderV'. It allocates a scratch vector @k@ as long
-- as the strides and a result as long as @v@, then runs @f@.
reorderAux :: Storable x => Reorder x -> Vector CInt -> Vector CInt -> Vector x -> Vector x
reorderAux f s d v = unsafePerformIO $ do
    k <- createVector (dim s)
    r <- createVector (dim v)
    (k # s # d # v #! r) f #| "reorderV"
    return r

type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))

foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z