{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}

module Internal.Specialized
    ( Mod(..)
    , f2i
    , i2f
    , f2iM
    , MatrixOrder(..)
    , Matrix(..)
    , createMatrix
    , matrixFromVector
    , cols
    , rows
    , size
    , shSize
    , shDim
    , constantD
    , extractR
    , setRect
    , sortI
    , sortV
    , compareV
    , selectV
    , remapM
    , rowOp
    , gemm
    , reorderV
    , specialize
    ) where

import Control.Monad
import Control.DeepSeq ( NFData(..) )
import Data.Coerce
import Data.Complex
import Data.Functor
import Data.Int
import Data.IORef
import Data.Maybe
import Data.Typeable (eqT,Proxy(..),cast)
import Type.Reflection
import Foreign.Marshal.Alloc(free,malloc)
import Foreign.Marshal.Array(newArray,copyArray)
import Foreign.ForeignPtr(castForeignPtr)
import Foreign.Ptr
import Foreign.Storable
import Foreign.C.Types (CInt(..))
import System.IO.Unsafe
#if MIN_VERSION_base(4,11,0)
import GHC.TypeLits hiding (Mod)
#else
import GHC.TypeLits
#endif

import Internal.Vector -- (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr,(@>))
import Internal.Devel

-- | Test whether the type indexed by an arbitrary proxy-like value equals
-- another 'Typeable' type.  The proxy argument itself is never inspected;
-- only its type matters.
eqp :: (Typeable a, Typeable b) => proxy a -> Maybe (a :~: b)
eqp = const eqT

-- | 'eqp' specialised to 'Int32'.
ep32 :: (Typeable a) => proxy a -> Maybe (a :~: Int32)
ep32 = eqp

-- | 'eqp' specialised to 'Int64'.
ep64 :: (Typeable a) => proxy a -> Maybe (a :~: Int64)
ep64 = eqp

-- | 'eqp' specialised to 'CInt'.
epint :: (Typeable a) => proxy a -> Maybe (a :~: CInt)
epint = eqp

-- Constraint bundle for element types: storable in a vector and inspectable
-- at runtime.
type Element t = (Storable t, Typeable t)

-- | Wrapper with a phantom integer for statically checked modular arithmetic.
newtype Mod (n :: Nat) t = Mod {unMod:: t}
    deriving (Storable)

instance (NFData t) => NFData (Mod n t) where
    rnf (Mod x) = rnf x

-- | Reinterpret a matrix of raw integers as a matrix of 'Mod' values.
-- Shares the underlying storage (see 'i2f').
i2fM :: Storable t => Matrix t -> Matrix (Mod n t)
i2fM m = m { xdat = i2f (xdat m) }

-- | Reinterpret a vector of raw integers as a vector of 'Mod' values by
-- casting the underlying foreign pointer; no elements are copied.
i2f :: Storable t => Vector t -> Vector (Mod n t)
i2f v = unsafeFromForeignPtr (castForeignPtr fp) i n
  where
    (fp, i, n) = unsafeToForeignPtr v

-- | Inverse of 'i2f': view 'Mod' values as their raw integers, sharing storage.
f2i :: Storable t => Vector (Mod n t) -> Vector t
f2i v = unsafeFromForeignPtr (castForeignPtr fp) i n
  where
    (fp, i, n) = unsafeToForeignPtr v

-- | Inverse of 'i2fM'.
f2iM :: Storable t => Matrix (Mod n t) -> Matrix t
f2iM m = m { xdat = f2i (xdat m) }

-- | Conversions that let an element type @a@ reuse the C kernels written
-- for its machine representation @t@.
data IntegralRep t a = IntegralRep
    { i2rep   :: Vector t -> Vector a   -- ^ representation vector to @a@ vector
    , i2repM  :: Matrix t -> Matrix a   -- ^ representation matrix to @a@ matrix
    , rep2i   :: Vector a -> Vector t   -- ^ @a@ vector to representation vector
    , rep2iM  :: Matrix a -> Matrix t   -- ^ @a@ matrix to representation matrix
    , rep2one :: a -> t                 -- ^ convert a single element
    , modulo  :: !(Maybe t)             -- ^ modulus for 'Mod' types, 'Nothing' otherwise
    }

instance Show (IntegralRep t a) where
    show _ = "IntegralRep"

-- | Representation of a type by itself.
idint :: Storable t => IntegralRep t t
idint = IntegralRep id id id id id Nothing

-- | Representation of a type through a zero-cost 'coerce'.
coerceint :: Coercible t a => IntegralRep t a
coerceint = IntegralRep coerce coerce coerce coerce coerce Nothing

-- Type.Reflection has no direct way to obtain the value of a type-level
-- natural from a TypeRep, so 'modint' falls back to parsing the shown
-- TypeRep.  This cache holds the most recently parsed natural so that
-- repeated lookups of the same modulus skip the parse.
cachedNat :: IORef SomeNat
cachedNat = unsafePerformIO $ newIORef (SomeNat (Proxy :: Proxy 3))
{-# NOINLINE cachedNat #-}

-- | Identity function used only to pin the type indices of its argument to
-- those of the two proxies.
withTypes :: p (a::k) -> q (b::h) -> f a b -> f a b
withTypes _ _ = id

-- | Representation of @'Mod' n t@ by @t@, carrying the modulus @n@.
--
-- The modulus is recovered from the 'TypeRep' (via 'cachedNat', or by
-- parsing its 'show' output on a cache miss).  It must be representable in
-- @t@: a modulus that would wrap when converted to @t@ raises an error
-- instead of silently producing a wrong (possibly zero) modulus.
modint :: forall t n. (Read t, Storable t, Integral t) => TypeRep n -> IntegralRep t (Mod n t)
modint r = IntegralRep i2f i2fM f2i f2iM unMod (n `seq` Just n)
  where
    -- Ideally: withTypeable r (fromIntegral (natVal (Proxy :: Proxy n))),
    -- but no KnownNat dictionary can be obtained from a TypeRep.
    n :: t
    n = narrow modulus

    -- The modulus as an unbounded integer.
    modulus :: Integer
    modulus = case unsafePerformIO (readIORef cachedNat) of
        SomeNat c -> withTypeable r $ case withTypes c r <$> eqT of
            Just Refl -> natVal c
            _ -> unsafePerformIO $ do
                -- XXX: Hack to get nat value from Type.Reflection
                let k = parseNat (show r)
                case someNatVal k of
                    Just somenat@(SomeNat nt) -> nt `seq` writeIORef cachedNat somenat
                    _ -> return ()
                return k

    -- Convert to the element type, refusing values that do not round-trip.
    narrow :: Integer -> t
    narrow k
        | toInteger m == k = m
        | otherwise = error $ "modint: modulus " ++ show k ++ " does not fit in the element type"
      where
        m = fromInteger k

    -- Parse the decimal rendering of a type-level natural.
    parseNat :: String -> Integer
    parseNat s = case reads s of
        [(k, "")] -> k
        _ -> error $ "modint: cannot read a type-level natural from " ++ show s

-- | Change the representation type of an 'IntegralRep' by coercion.
coercerep :: Coercible s t => IntegralRep s a -> IntegralRep t a
coercerep = coerce

-- | The 'TypeRep' of the type indexed by a proxy-like value.
typeRepOf :: Typeable a => proxy a -> TypeRep a
typeRepOf _ = typeRep

-- | Which C kernel family handles an element type.
data Specialized a
    = SpFloat !(a :~: Float)
    | SpDouble !(a :~: Double)
    | SpCFloat !(a :~: Complex Float)
    | SpCDouble !(a :~: Complex Double)
    | SpInt32 !(IntegralRep Int32 a)
    | SpInt64 !(IntegralRep Int64 a)
    deriving Show

-- | Inspect the element type of @x@ at runtime.
--
-- Recognises 'Double', 'Int64', 'Float', 'Int32', @Complex Double@,
-- @Complex Float@, 'CInt' (handled as 'Int32' via 'coerce') and
-- @'Mod' n t@ for @t@ in 'Int32', 'CInt' or 'Int64'.  Returns 'Nothing' for
-- anything else.
specialize :: forall m a. Typeable a => m a -> Maybe (Specialized a)
specialize x = foldr1 mplus
    [ SpDouble <$> eqp x
    , ep64 x <&> \Refl -> SpInt64 idint
    , SpFloat <$> eqp x
    , ep32 x <&> \Refl -> SpInt32 idint
    , SpCDouble <$> eqp x
    , SpCFloat <$> eqp x
    , epint x <&> \Refl -> SpInt32 coerceint
    , case typeRepOf x of
        App (App modtyp n) inttyp -> do
            HRefl <- eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp
            mplus
                (mplus
                    (eqTypeRep (typeRep :: TypeRep Int32) inttyp <&> \HRefl -> SpInt32 $ modint n)
                    (eqTypeRep (typeRep :: TypeRep CInt) inttyp <&> \HRefl -> SpInt32 $ coercerep $ modint n))
                (eqTypeRep (typeRep :: TypeRep Int64) inttyp <&> \HRefl -> SpInt64 $ modint n)
        _ -> Nothing
    ]

-- | A vector of length @n@ produced from @x@ by the element type's
-- @constant*@ C kernel (presumably every element equal to @x@).  Calls
-- 'error' for element types 'specialize' does not recognise.
constantD :: Typeable a => a -> Int -> Vector a
constantD x = fromMaybe (error "constantD") $ specialize (const x) <&> \case
    SpDouble Refl -> constantAux cconstantR x
    SpInt64 r -> i2rep r . constantAux cconstantL (rep2one r x)
    SpFloat Refl -> constantAux cconstantF x
    SpInt32 r -> i2rep r . constantAux cconstantI (rep2one r x)
    SpCDouble Refl -> constantAux cconstantC x
    SpCFloat Refl -> constantAux cconstantQ x

-- | Storage order of matrix elements.
data MatrixOrder = RowMajor | ColumnMajor
    deriving (Show,Eq)

-- | Matrix representation suitable for BLAS\/LAPACK computations.
data Matrix t = Matrix
    { irows :: {-# UNPACK #-} !Int         -- ^ number of rows
    , icols :: {-# UNPACK #-} !Int         -- ^ number of columns
    , xRow  :: {-# UNPACK #-} !Int         -- ^ distance in 'xdat' between vertically adjacent elements
    , xCol  :: {-# UNPACK #-} !Int         -- ^ distance in 'xdat' between horizontally adjacent elements
    , xdat  :: {-# UNPACK #-} !(Vector t)  -- ^ element storage
    }

-- | Allocate an @r x c@ matrix with the given storage order.
--
-- The storage comes straight from 'createVector'.  Within this module the
-- result is always passed to a C kernel as its output argument.
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix ord r c = do
    p <- createVector (r*c)
    return (matrixFromVector ord r c p)

-- | View a vector as an @r x c@ matrix with the given storage order.
--
-- A single-row (@r == 1@) or single-column (@c == 1@) request takes the
-- other dimension from the vector length, ignoring the supplied one and the
-- storage order.  Otherwise @r * c@ must equal the vector length, or an
-- error is raised.
matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
matrixFromVector o r c v
    | r * c == dim v = m
    | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
  where
    m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
      | otherwise     = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }

-- | Render the dimensions of a matrix, e.g. @(3x4)@.
shSize :: Matrix t -> [Char]
shSize = shDim . size

-- | Render a pair of dimensions, e.g. @(3x4)@.
shDim :: (Show a, Show a1) => (a1, a) -> [Char]
shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"

-- | Number of rows and columns.
size :: Matrix t -> (Int, Int)
size m = (irows m, icols m)
{-# INLINE size #-}

-- | Dispatch to the element-specific @extract*@ kernel.  Integral and 'Mod'
-- elements run through the kernel of their representation type.
extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
extractR ord m = fromMaybe (\_ _ _ _ -> error "extractR") $ specialize m <&> \case
    SpDouble Refl -> extractAux c_extractD ord m
    SpInt64 r -> \mi is mj js -> i2repM r <$> extractAux c_extractL ord (rep2iM r m) mi is mj js
    SpFloat Refl -> extractAux c_extractF ord m
    SpInt32 r -> \mi is mj js -> i2repM r <$> extractAux (coerce c_extractI) ord (rep2iM r m) mi is mj js
    SpCDouble Refl -> extractAux c_extractC ord m
    SpCFloat Refl -> extractAux c_extractQ ord m

-- | Dispatch to the element-specific @setRect*@ kernel with offsets @i@, @j@.
setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO ()
setRect i j m x = fromMaybe (error "setRect") $ specialize m <&> \case
    SpDouble Refl -> setRectAux c_setRectD i j m x
    SpInt64 r -> setRectAux c_setRectL i j (rep2iM r m) (rep2iM r x)
    SpFloat Refl -> setRectAux c_setRectF i j m x
    SpInt32 r -> setRectAux (coerce c_setRectI) i j (rep2iM r m) (rep2iM r x)
    SpCDouble Refl -> setRectAux c_setRectC i j m x
    SpCFloat Refl -> setRectAux c_setRectQ i j m x

-- | Dispatch to the @sort_index*@ kernels (presumably the permutation that
-- sorts @v@).
sortI :: (Typeable a , Ord a) => Vector a -> Vector CInt
sortI v = maybe (error "sortI") ($ v) $ specialize v <&> \case
    SpDouble Refl -> sortIdxD
    SpInt64 r -> sortIdxL . rep2i r
    SpFloat Refl -> sortIdxF
    SpInt32 r -> coerce sortIdxI . rep2i r
    SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex
    SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex

-- | Dispatch to the @sort_values*@ kernels.
sortV :: (Typeable a , Ord a ) => Vector a -> Vector a
sortV v = maybe (error "sortV") ($ v) $ specialize v <&> \case
    SpDouble Refl -> sortValD
    SpInt64 r -> i2rep r . sortValL . rep2i r
    SpFloat Refl -> sortValF
    SpInt32 r -> i2rep r . coerce sortValI . rep2i r
    SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex
    SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex

-- | Element-wise comparison through the @compare*@ kernels.
compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt
compareV u v = fromMaybe (error "compareV") $ specialize u <&> \case
    SpDouble Refl -> compareD u v
    SpInt64 r -> compareL (rep2i r u) (rep2i r v)
    SpFloat Refl -> compareF u v
    SpInt32 r -> coerce compareI (rep2i r u) (rep2i r v)
    SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex
    SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex

-- | Element-wise selection among @l@, @e@, @g@ driven by the codes in @c@
-- (through the @choose*@ kernels).
selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
selectV c l e g = fromMaybe (error "selectV") $ specialize l <&> \case
    SpDouble Refl -> selectD c l e g
    SpInt64 r -> i2rep r (selectL c (rep2i r l) (rep2i r e) (rep2i r g))
    SpFloat Refl -> selectF c l e g
    SpInt32 r -> i2rep r (coerce selectI c (rep2i r l) (rep2i r e) (rep2i r g))
    SpCDouble Refl -> selectC c l e g
    SpCFloat Refl -> selectQ c l e g

-- | Build a matrix shaped like @i@ from @m@ using the index matrices @i@ and
-- @j@ (through the @remap*@ kernels).
remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
remapM i j m = fromMaybe (error "remapM") $ specialize m <&> \case
    SpDouble Refl -> remapD i j m
    SpInt64 r -> i2repM r (remapL i j (rep2iM r m))
    SpFloat Refl -> remapF i j m
    SpInt32 r -> i2repM r (coerce remapI i j (rep2iM r m))
    SpCDouble Refl -> remapC i j m
    SpCFloat Refl -> remapQ i j m

-- | In-place row operation through the @rowop*@ kernels.  'Mod' elements
-- use the modular kernels with their modulus.
rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
rowOp c a i1 i2 j1 j2 x = fromMaybe (error "rowOp") $ specialize x <&> \case
    SpDouble Refl -> rowOpAux c_rowOpD c a i1 i2 j1 j2 x
    SpInt64 r -> case modulo r of
        Just m' -> rowOpAux (c_rowOpML m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
        Nothing -> rowOpAux c_rowOpL c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
    SpFloat Refl -> rowOpAux c_rowOpF c a i1 i2 j1 j2 x
    SpInt32 r -> case modulo r of
        Just m' -> rowOpAux (coerce c_rowOpMI m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
        Nothing -> rowOpAux (coerce c_rowOpI) c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
    SpCDouble Refl -> rowOpAux c_rowOpC c a i1 i2 j1 j2 x
    SpCFloat Refl -> rowOpAux c_rowOpQ c a i1 i2 j1 j2 x

-- | Matrix product through the @gemm*@ kernels, writing into @c@.  'Mod'
-- elements use the modular kernels with their modulus.
gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
gemm u a b c = fromMaybe (error "gemm") $ specialize u <&> \case
    SpDouble Refl -> gemmg c_gemmD u a b c
    SpInt64 r -> case modulo r of
        Just m' -> gemmg (c_gemmML m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
        Nothing -> gemmg c_gemmL (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
    SpFloat Refl -> gemmg c_gemmF u a b c
    SpInt32 r -> case modulo r of
        Just m' -> gemmg (coerce c_gemmMI m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
        Nothing -> gemmg (coerce c_gemmI) (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
    SpCDouble Refl -> gemmg c_gemmC u a b c
    SpCFloat Refl -> gemmg c_gemmQ u a b c

-- | Reorder the elements of @v@ by @strides@ and @dims@ through the
-- @reorder*@ kernels.  See reorderVector for documentation.
reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a
reorderV strides dims v = fromMaybe (error "reorderV") $ specialize v <&> \case
    SpDouble Refl -> reorderAux c_reorderD strides dims v
    SpInt64 r -> i2rep r $ reorderAux c_reorderL strides dims (rep2i r v)
    SpFloat Refl -> reorderAux c_reorderF strides dims v
    SpInt32 r -> i2rep r $ reorderAux (coerce c_reorderI) strides dims (rep2i r v)
    SpCDouble Refl -> reorderAux c_reorderC strides dims v
    SpCFloat Refl -> reorderAux c_reorderQ strides dims v

-- Matrices are passed to C as (rows, cols, rowStride, colStride, ptr), or
-- as (rows, cols, ptr) in the raw form.
instance Storable t => TransArray (Matrix t)
  where
    type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
    type Trans (Matrix t) b    = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
    apply = amat
    {-# INLINE apply #-}
    applyRaw = amatr
    {-# INLINE applyRaw #-}

-- C-Haskell matrix adapters

-- Pass rows, cols and the data pointer.
{-# INLINE amatr #-}
amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
amatr x f g = unsafeWith (xdat x) (f . g r c)
  where
    r  = fi (rows x)
    c  = fi (cols x)

-- Pass rows, cols, both strides and the data pointer.
{-# INLINE amat #-}
amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
  where
    r  = fi (rows x)
    c  = fi (cols x)
    sr = fi (xRow x)
    sc = fi (xCol x)

-- | Number of rows.
rows :: Matrix t -> Int
rows = irows
{-# INLINE rows #-}

-- | Number of columns.
cols :: Matrix t -> Int
cols = icols
{-# INLINE cols #-}

infixr 1 #
(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
a # b = apply a b
{-# INLINE (#) #-}

(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r
a #! b = a # b # id
{-# INLINE (#!) #-}

-- Run a @constant*@ kernel that receives a pointer to a one-element copy of
-- @x@, the length @n@ and the output vector.  The temporary buffer is freed
-- before the status code is checked, so a failing kernel does not leak it.
constantAux :: (Storable a1, Storable a) => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
constantAux fun x n = unsafePerformIO $ do
    v <- createVector n
    px <- newArray [x]
    (applyRaw v id (fun px) <* free px) #| "constantAux"
    return v

type TConst t = Ptr t -> CInt -> Ptr t -> IO CInt

foreign import ccall unsafe "constantF" cconstantF :: TConst Float
foreign import ccall unsafe "constantR" cconstantR :: TConst Double
foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
foreign import ccall unsafe "constantL" cconstantL :: TConst Int64

-- Allocate the output of an @extract*@ kernel and run it.  A mode of 0 means
-- the index vector holds an inclusive range [first, last]; otherwise the
-- index vector lists the wanted indices and its length is the output size.
extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, Storable t, Num t3, Num t2, Integral t1, Integral t)
           => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
           -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
extractAux f ord m moder vr modec vc = do
    let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
        nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
    r <- createMatrix ord nr nc
    (vr # vc # m #! r) (f moder modec) #| "extract"
    return r

type Extr x = CInt -> CInt -> CInt -> Ptr CInt -> CInt -> Ptr CInt -- CIdxs, CIdxs
           -> CInt -> CInt -> CInt -> CInt -> Ptr x                 -- OM x
           -> CInt -> CInt -> CInt -> CInt -> Ptr x                 -- OM x
           -> IO CInt

foreign import ccall unsafe "extractD" c_extractD :: Extr Double
foreign import ccall unsafe "extractF" c_extractF :: Extr Float
foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
foreign import ccall unsafe "extractL" c_extractL :: Extr Int64

-- Run a @setRect*@ kernel with offsets @i@, @j@ on @m@ and @r@.
setRectAux :: (TransArray c1, TransArray c) => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) -> Int -> Int -> c1 -> c -> IO ()
setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #| "setRect"

type SetRect x = I -> I -> x ::> x::> Ok

foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z

-- Run a one-input kernel into a fresh output vector of the same length.
sortG :: (Storable t, Storable a) => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
sortG f v = unsafePerformIO $ do
    r <- createVector (dim v)
    (v #! r) f #| "sortG"
    return r

sortIdxD :: Vector Double -> Vector CInt
sortIdxD = sortG c_sort_indexD
sortIdxF :: Vector Float -> Vector CInt
sortIdxF = sortG c_sort_indexF
sortIdxI :: Vector CInt -> Vector CInt
sortIdxI = sortG c_sort_indexI
sortIdxL :: Vector Z -> Vector I
sortIdxL = sortG c_sort_indexL

foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok

sortValD :: Vector Double -> Vector Double
sortValD = sortG c_sort_valD
sortValF :: Vector Float -> Vector Float
sortValF = sortG c_sort_valF
sortValI :: Vector CInt -> Vector CInt
sortValI = sortG c_sort_valI
sortValL :: Vector Z -> Vector Z
sortValL = sortG c_sort_valL

foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok

-- Run a two-input kernel into a fresh output vector sized like @v@.
compareG :: (TransArray c, Storable t, Storable a) => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> c -> Vector t -> Vector a
compareG f u v = unsafePerformIO $ do
    r <- createVector (dim v)
    (u # v #! r) f #| "compareG"
    return r

compareD :: Vector Double -> Vector Double -> Vector CInt
compareD = compareG c_compareD
compareF :: Vector Float -> Vector Float -> Vector CInt
compareF = compareG c_compareF
compareI :: Vector CInt -> Vector CInt -> Vector CInt
compareI = compareG c_compareI
compareL :: Vector Z -> Vector Z -> Vector CInt
compareL = compareG c_compareL

foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok

-- Run a four-input kernel into a fresh output vector sized like @v@.
selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
        => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
        -> c2 -> c1 -> Vector t -> c -> Vector a
selectG f c u v w = unsafePerformIO $ do
    r <- createVector (dim v)
    (c # u # v # w #! r) f #| "selectG"
    return r

selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
selectD = selectG c_selectD
selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
selectF = selectG c_selectF
selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
selectI = selectG c_selectI
selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
selectL = selectG c_selectL
selectC :: Vector CInt -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double)
selectC = selectG c_selectC
selectQ :: Vector CInt -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float)
selectQ = selectG c_selectQ

type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))

foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
foreign import ccall unsafe "chooseL" c_selectL :: Sel Z

-- Run a remap kernel into a fresh row-major matrix shaped like @i@.
remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
       => (CInt -> CInt -> CInt -> CInt -> Ptr t -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
       -> Matrix t -> c1 -> c -> Matrix a
remapG f i j m = unsafePerformIO $ do
    r <- createMatrix RowMajor (rows i) (cols i)
    (i # j # m #! r) f #| "remapG"
    return r

remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
remapD = remapG c_remapD
remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
remapF = remapG c_remapF
remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
remapI = remapG c_remapI
remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
remapL = remapG c_remapL
remapC :: Matrix CInt -> Matrix CInt -> Matrix (Complex Double) -> Matrix (Complex Double)
remapC = remapG c_remapC
remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
remapQ = remapG c_remapQ

type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))

foreign import ccall unsafe "remapD" c_remapD :: Rem Double
foreign import ccall unsafe "remapF" c_remapF :: Rem Float
foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
foreign import ccall unsafe "remapL" c_remapL :: Rem Z

-- Run a @rowop*@ kernel, passing the scalar @x@ through a temporary
-- one-element buffer.  The buffer is freed before the status code is
-- checked, so a failing kernel does not leak it.
rowOpAux :: (TransArray c, Storable a) =>
            (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
            -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
rowOpAux f c x i1 i2 j1 j2 m = do
    px <- newArray [x]
    ((m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) <* free px) #| "rowOp"

type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok

foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z

-- Run a @gemm*@ kernel on a vector and three matrices; the last matrix is
-- the output argument.
gemmg :: Storable x => Tgemm x -> Vector x -> Matrix x -> Matrix x -> Matrix x -> IO ()
gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #| "gemmg"

type Tgemm x = x :> x ::> x ::> x ::> Ok

foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z

-- Run a @reorder*@ kernel.  @k@ is a scratch vector with one slot per stride;
-- the result is a fresh vector sized like @v@.
reorderAux :: Storable x => Reorder x -> Vector CInt -> Vector CInt -> Vector x -> Vector x
reorderAux f s d v = unsafePerformIO $ do
    k <- createVector (dim s)
    r <- createVector (dim v)
    (k # s # d # v #! r) f #| "reorderV"
    return r

type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))

foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z