From badcbdfddc4be31fc79a6df4553795af18069efe Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Thu, 8 Aug 2019 02:22:30 -0400 Subject: Removed the Element class. --- packages/base/src/Internal/Matrix.hs | 307 ++++++++++++++--------------------- 1 file changed, 123 insertions(+), 184 deletions(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 5436e59..04092f9 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -2,6 +2,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ViewPatterns #-} @@ -22,12 +23,14 @@ module Internal.Matrix where import Internal.Vector import Internal.Devel +import Internal.Extract import Internal.Vectorized hiding ((#), (#!)) import Foreign.Marshal.Alloc ( free ) import Foreign.Marshal.Array(newArray) import Foreign.Ptr ( Ptr ) import Foreign.Storable ( Storable ) import Data.Complex ( Complex ) +import Data.Int import Foreign.C.Types ( CInt(..) ) import Foreign.C.String ( CString, newCString ) import System.IO.Unsafe ( unsafePerformIO ) @@ -61,19 +64,23 @@ size :: Matrix t -> (Int, Int) size m = (irows m, icols m) {-# INLINE size #-} +-- | True if the matrix is in RowMajor form. rowOrder :: Matrix t -> Bool rowOrder m = xCol m == 1 || cols m == 1 {-# INLINE rowOrder #-} +-- | True if the matrix is in ColMajor form or if their is only one row. colOrder :: Matrix t -> Bool colOrder m = xRow m == 1 || rows m == 1 {-# INLINE colOrder #-} +-- | True if the matrix is a single row or column vector. is1d :: Matrix t -> Bool is1d (size->(r,c)) = r==1 || c==1 {-# INLINE is1d #-} --- data is not contiguous +-- | True if the matrix is not contiguous. This usually +-- means it is a slice of some larger matrix. isSlice :: Storable t => Matrix t -> Bool isSlice m@(size->(r,c)) = r*c < dim (xdat m) {-# INLINE isSlice #-} @@ -95,19 +102,23 @@ showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv -------------------------------------------------------------------------------- --- | Matrix transpose. +-- | O(1) Matrix transpose. This is only a logical transposition that does not +-- re-order the element storage. If the storage order is important, use 'cmat' +-- or 'fmat'. trans :: Matrix t -> Matrix t trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = m { irows = c, icols = r, xRow = xc, xCol = xr } -cmat :: (Element t) => Matrix t -> Matrix t +-- | Obtain the RowMajor equivalent of a given Matrix. +cmat :: (Storable t) => Matrix t -> Matrix t cmat m | rowOrder m = m | otherwise = extractAll RowMajor m -fmat :: (Element t) => Matrix t -> Matrix t +-- | Obtain the ColumnMajor equivalent of a given Matrix. +fmat :: (Storable t) => Matrix t -> Matrix t fmat m | colOrder m = m | otherwise = extractAll ColumnMajor m @@ -115,14 +126,14 @@ fmat m -- C-Haskell matrix adapters {-# INLINE amatr #-} -amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r +amatr :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Ptr a -> f) -> IO r amatr x f g = unsafeWith (xdat x) (f . g r c) where r = fi (rows x) c = fi (cols x) {-# INLINE amat #-} -amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r +amat :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> f) -> IO r amat x f g = unsafeWith (xdat x) (f . g r c sr sc) where r = fi (rows x) @@ -133,8 +144,8 @@ amat x f g = unsafeWith (xdat x) (f . g r c sr sc) instance Storable t => TransArray (Matrix t) where - type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b - type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b + type TransRaw (Matrix t) b = Int32 -> Int32 -> Ptr t -> b + type Trans (Matrix t) b = Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -> b apply = amat {-# INLINE apply #-} applyRaw = amatr @@ -151,10 +162,10 @@ a #! b = a # b # id -------------------------------------------------------------------------------- -copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) -copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) +copy :: Storable t => MatrixOrder -> Matrix t -> IO (Matrix t) +copy ord m = extractAux ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) -extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t +extractAll :: Storable t => MatrixOrder -> Matrix t -> Matrix t extractAll ord m = unsafePerformIO (copy ord m) {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. @@ -164,14 +175,14 @@ extractAll ord m = unsafePerformIO (copy ord m) it :: (Num t, Element t) => Vector t -} -flatten :: Element t => Matrix t -> Vector t +flatten :: Storable t => Matrix t -> Vector t flatten m | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) | otherwise = xdat m -- | the inverse of 'Data.Packed.Matrix.fromLists' -toLists :: (Element t) => Matrix t -> [[t]] +toLists :: (Storable t) => Matrix t -> [[t]] toLists = map toList . toRows @@ -192,7 +203,7 @@ compatdim (a:b:xs) -- | Create a matrix from a list of vectors. -- All vectors must have the same dimension, -- or dimension 1, which is are automatically expanded. -fromRows :: Element t => [Vector t] -> Matrix t +fromRows :: Storable t => [Vector t] -> Matrix t fromRows [] = emptyM 0 0 fromRows vs = case compatdim (map dim vs) of Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) @@ -203,25 +214,25 @@ fromRows vs = case compatdim (map dim vs) of adapt c v | c == 0 = fromList[] | dim v == c = v - | otherwise = constantD (v@>0) c + | otherwise = constantAux (v@>0) c -- | extracts the rows of a matrix as a list of vectors -toRows :: Element t => Matrix t -> [Vector t] +toRows :: Storable t => Matrix t -> [Vector t] toRows m | rowOrder m = map sub rowRange | otherwise = map ext rowRange where rowRange = [0..rows m-1] sub k = subVector (k*xRow m) (cols m) (xdat m) - ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) + ext k = xdat $ unsafePerformIO $ extractAux RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) -- | Creates a matrix from a list of vectors, as columns -fromColumns :: Element t => [Vector t] -> Matrix t +fromColumns :: Storable t => [Vector t] -> Matrix t fromColumns m = trans . fromRows $ m -- | Creates a list of vectors from the columns of a matrix -toColumns :: Element t => Matrix t -> [Vector t] +toColumns :: Storable t => Matrix t -> [Vector t] toColumns m = toRows . trans $ m -- | Reads a matrix position. @@ -271,13 +282,13 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v -- | application of a vector function on the flattened matrix elements -liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b +liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) | otherwise = matrixFromVector (orderOf m) r c (f d) -- | application of a vector function on the flattened matrices elements -liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t +liftMatrix2 :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2 f m1@(size->(r,c)) m2 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) @@ -285,103 +296,8 @@ liftMatrix2 f m1@(size->(r,c)) m2 ------------------------------------------------------------------ --- | Supported matrix elements. -class (Storable a) => Element a where - constantD :: a -> Int -> Vector a - extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) - setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () - sortI :: Ord a => Vector a -> Vector CInt - sortV :: Ord a => Vector a -> Vector a - compareV :: Ord a => Vector a -> Vector a -> Vector CInt - selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a - remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a - rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () - gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () - reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation - - -instance Element Float where - constantD = constantAux cconstantF - extractR = extractAux c_extractF - setRect = setRectAux c_setRectF - sortI = sortIdxF - sortV = sortValF - compareV = compareF - selectV = selectF - remapM = remapF - rowOp = rowOpAux c_rowOpF - gemm = gemmg c_gemmF - reorderV = reorderAux c_reorderF - -instance Element Double where - constantD = constantAux cconstantR - extractR = extractAux c_extractD - setRect = setRectAux c_setRectD - sortI = sortIdxD - sortV = sortValD - compareV = compareD - selectV = selectD - remapM = remapD - rowOp = rowOpAux c_rowOpD - gemm = gemmg c_gemmD - reorderV = reorderAux c_reorderD - -instance Element (Complex Float) where - constantD = constantAux cconstantQ - extractR = extractAux c_extractQ - setRect = setRectAux c_setRectQ - sortI = undefined - sortV = undefined - compareV = undefined - selectV = selectQ - remapM = remapQ - rowOp = rowOpAux c_rowOpQ - gemm = gemmg c_gemmQ - reorderV = reorderAux c_reorderQ - -instance Element (Complex Double) where - constantD = constantAux cconstantC - extractR = extractAux c_extractC - setRect = setRectAux c_setRectC - sortI = undefined - sortV = undefined - compareV = undefined - selectV = selectC - remapM = remapC - rowOp = rowOpAux c_rowOpC - gemm = gemmg c_gemmC - reorderV = reorderAux c_reorderC - -instance Element (CInt) where - constantD = constantAux cconstantI - extractR = extractAux c_extractI - setRect = setRectAux c_setRectI - sortI = sortIdxI - sortV = sortValI - compareV = compareI - selectV = selectI - remapM = remapI - rowOp = rowOpAux c_rowOpI - gemm = gemmg c_gemmI - reorderV = reorderAux c_reorderI - -instance Element Z where - constantD = constantAux cconstantL - extractR = extractAux c_extractL - setRect = setRectAux c_setRectL - sortI = sortIdxL - sortV = sortValL - compareV = compareL - selectV = selectL - remapM = remapL - rowOp = rowOpAux c_rowOpL - gemm = gemmg c_gemmL - reorderV = reorderAux c_reorderL - -------------------------------------------------------------------- - -- | reference to a rectangular slice of a matrix (no data copy) -subMatrix :: Element a +subMatrix :: Storable a => (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -- ^ input matrix @@ -402,34 +318,34 @@ subMatrix (r0,c0) (rt,ct) m maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 maxZ xs = if minimum xs == 0 then 0 else maximum xs -conformMs :: Element t => [Matrix t] -> [Matrix t] +conformMs :: Storable t => [Matrix t] -> [Matrix t] conformMs ms = map (conformMTo (r,c)) ms where r = maxZ (map rows ms) c = maxZ (map cols ms) -conformVs :: Element t => [Vector t] -> [Vector t] +conformVs :: Storable t => [Vector t] -> [Vector t] conformVs vs = map (conformVTo n) vs where n = maxZ (map dim vs) -conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t +conformMTo :: Storable t => (Int, Int) -> Matrix t -> Matrix t conformMTo (r,c) m | size m == (r,c) = m - | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) + | size m == (1,1) = matrixFromVector RowMajor r c (constantAux (m@@>(0,0)) (r*c)) | size m == (r,1) = repCols c m | size m == (1,c) = repRows r m | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) -conformVTo :: Element t => Int -> Vector t -> Vector t +conformVTo :: Storable t => Int -> Vector t -> Vector t conformVTo n v | dim v == n = v - | dim v == 1 = constantD (v@>0) n + | dim v == 1 = constantAux (v@>0) n | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n -repRows :: Element t => Int -> Matrix t -> Matrix t +repRows :: Storable t => Int -> Matrix t -> Matrix t repRows n x = fromRows (replicate n (flatten x)) -repCols :: Element t => Int -> Matrix t -> Matrix t +repCols :: Storable t => Int -> Matrix t -> Matrix t repCols n x = fromColumns (replicate n (flatten x)) shSize :: Matrix t -> [Char] @@ -453,32 +369,50 @@ instance (Storable t, NFData t) => NFData (Matrix t) --------------------------------------------------------------- +{- extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, Storable t, Num t3, Num t2, Integral t1, Integral t) - => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t - -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) - -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) -extractAux f ord m moder vr modec vc = do + => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) -- f + -> MatrixOrder -- ord + -> c -- m + -> t3 -- moder + -> Vector t1 -- vr + -> t2 -- modec + -> Vector t -- vc + -> IO (Matrix a) +-} + +extractAux :: Storable a => + MatrixOrder + -> Matrix a + -> Int32 + -> Vector Int32 + -> Int32 + -> Vector Int32 + -> IO (Matrix a) +extractAux ord m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc r <- createMatrix ord nr nc - (vr # vc # m #! r) (f moder modec) #|"extract" + (vr # vc # m #! r) (extractStorable moder modec) #|"extract" return r -type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) +{- +type Extr x = Int32 -> Int32 -> CIdxs (CIdxs (OM x (OM x (IO Int32)))) foreign import ccall unsafe "extractD" c_extractD :: Extr Double foreign import ccall unsafe "extractF" c_extractF :: Extr Float foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) -foreign import ccall unsafe "extractI" c_extractI :: Extr CInt +foreign import ccall unsafe "extractI" c_extractI :: Extr Int32 foreign import ccall unsafe "extractL" c_extractL :: Extr Z +-} --------------------------------------------------------------- setRectAux :: (TransArray c1, TransArray c) - => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) + => (Int32 -> Int32 -> Trans c1 (Trans c (IO Int32))) -> Int -> Int -> c1 -> c -> IO () setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" @@ -494,17 +428,17 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z -------------------------------------------------------------------------------- sortG :: (Storable t, Storable a) - => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a + => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a sortG f v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) f #|"sortG" return r -sortIdxD :: Vector Double -> Vector CInt +sortIdxD :: Vector Double -> Vector Int32 sortIdxD = sortG c_sort_indexD -sortIdxF :: Vector Float -> Vector CInt +sortIdxF :: Vector Float -> Vector Int32 sortIdxF = sortG c_sort_indexF -sortIdxI :: Vector CInt -> Vector CInt +sortIdxI :: Vector Int32 -> Vector Int32 sortIdxI = sortG c_sort_indexI sortIdxL :: Vector Z -> Vector I sortIdxL = sortG c_sort_indexL @@ -513,81 +447,81 @@ sortValD :: Vector Double -> Vector Double sortValD = sortG c_sort_valD sortValF :: Vector Float -> Vector Float sortValF = sortG c_sort_valF -sortValI :: Vector CInt -> Vector CInt +sortValI :: Vector Int32 -> Vector Int32 sortValI = sortG c_sort_valI sortValL :: Vector Z -> Vector Z sortValL = sortG c_sort_valL -foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV Int32 (IO Int32)) +foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV Int32 (IO Int32)) +foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV Int32 (CV Int32 (IO Int32)) foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok -foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) -foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) -foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO Int32)) +foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO Int32)) +foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV Int32 (CV Int32 (IO Int32)) foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok -------------------------------------------------------------------------------- compareG :: (TransArray c, Storable t, Storable a) - => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) + => Trans c (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> c -> Vector t -> Vector a compareG f u v = unsafePerformIO $ do r <- createVector (dim v) (u # v #! r) f #|"compareG" return r -compareD :: Vector Double -> Vector Double -> Vector CInt +compareD :: Vector Double -> Vector Double -> Vector Int32 compareD = compareG c_compareD -compareF :: Vector Float -> Vector Float -> Vector CInt +compareF :: Vector Float -> Vector Float -> Vector Int32 compareF = compareG c_compareF -compareI :: Vector CInt -> Vector CInt -> Vector CInt +compareI :: Vector Int32 -> Vector Int32 -> Vector Int32 compareI = compareG c_compareI -compareL :: Vector Z -> Vector Z -> Vector CInt +compareL :: Vector Z -> Vector Z -> Vector Int32 compareL = compareG c_compareL -foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) -foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) -foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) +foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV Int32 (IO Int32))) +foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV Int32 (IO Int32))) +foreign import ccall unsafe "compareI" c_compareI :: CV Int32 (CV Int32 (CV Int32 (IO Int32))) foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok -------------------------------------------------------------------------------- selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) - => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) + => Trans c2 (Trans c1 (Int32 -> Ptr t -> Trans c (Int32 -> Ptr a -> IO Int32))) -> c2 -> c1 -> Vector t -> c -> Vector a selectG f c u v w = unsafePerformIO $ do r <- createVector (dim v) (c # u # v # w #! r) f #|"selectG" return r -selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double +selectD :: Vector Int32 -> Vector Double -> Vector Double -> Vector Double -> Vector Double selectD = selectG c_selectD -selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float +selectF :: Vector Int32 -> Vector Float -> Vector Float -> Vector Float -> Vector Float selectF = selectG c_selectF -selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt +selectI :: Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 selectI = selectG c_selectI -selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z +selectL :: Vector Int32 -> Vector Z -> Vector Z -> Vector Z -> Vector Z selectL = selectG c_selectL -selectC :: Vector CInt +selectC :: Vector Int32 -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) selectC = selectG c_selectC -selectQ :: Vector CInt +selectQ :: Vector Int32 -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) selectQ = selectG c_selectQ -type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) +type Sel x = CV Int32 (CV x (CV x (CV x (CV x (IO Int32))))) foreign import ccall unsafe "chooseD" c_selectD :: Sel Double foreign import ccall unsafe "chooseF" c_selectF :: Sel Float -foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt +foreign import ccall unsafe "chooseI" c_selectI :: Sel Int32 foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) foreign import ccall unsafe "chooseL" c_selectL :: Sel Z @@ -595,35 +529,35 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z --------------------------------------------------------------------------- remapG :: (TransArray c, TransArray c1, Storable t, Storable a) - => (CInt -> CInt -> CInt -> CInt -> Ptr t - -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) + => (Int32 -> Int32 -> Int32 -> Int32 -> Ptr t + -> Trans c1 (Trans c (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> IO Int32))) -> Matrix t -> c1 -> c -> Matrix a remapG f i j m = unsafePerformIO $ do r <- createMatrix RowMajor (rows i) (cols i) (i # j # m #! r) f #|"remapG" return r -remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double +remapD :: Matrix Int32 -> Matrix Int32 -> Matrix Double -> Matrix Double remapD = remapG c_remapD -remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float +remapF :: Matrix Int32 -> Matrix Int32 -> Matrix Float -> Matrix Float remapF = remapG c_remapF -remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt +remapI :: Matrix Int32 -> Matrix Int32 -> Matrix Int32 -> Matrix Int32 remapI = remapG c_remapI -remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z +remapL :: Matrix Int32 -> Matrix Int32 -> Matrix Z -> Matrix Z remapL = remapG c_remapL -remapC :: Matrix CInt - -> Matrix CInt +remapC :: Matrix Int32 + -> Matrix Int32 -> Matrix (Complex Double) -> Matrix (Complex Double) remapC = remapG c_remapC -remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) +remapQ :: Matrix Int32 -> Matrix Int32 -> Matrix (Complex Float) -> Matrix (Complex Float) remapQ = remapG c_remapQ -type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) +type Rem x = OM Int32 (OM Int32 (OM x (OM x (IO Int32)))) foreign import ccall unsafe "remapD" c_remapD :: Rem Double foreign import ccall unsafe "remapF" c_remapF :: Rem Float -foreign import ccall unsafe "remapI" c_remapI :: Rem CInt +foreign import ccall unsafe "remapI" c_remapI :: Rem Int32 foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) foreign import ccall unsafe "remapL" c_remapL :: Rem Z @@ -631,14 +565,14 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z -------------------------------------------------------------------------------- rowOpAux :: (TransArray c, Storable a) => - (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) + (Int32 -> Ptr a -> Int32 -> Int32 -> Int32 -> Int32 -> Trans c (IO Int32)) -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () rowOpAux f c x i1 i2 j1 j2 m = do px <- newArray [x] (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" free px -type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok +type RowOp x = Int32 -> Ptr x -> Int32 -> Int32 -> Int32 -> Int32 -> x ::> Ok foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float @@ -652,7 +586,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) - => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) + => Trans c3 (Trans c2 (Trans c1 (Trans c (IO Int32)))) -> c3 -> c2 -> c1 -> c -> IO () gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" @@ -669,21 +603,26 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z -------------------------------------------------------------------------------- +{- reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => - (CInt -> Ptr a -> CInt -> Ptr t1 - -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) + (Int32 -> Ptr a -> Int32 -> Ptr t1 + -> Trans c (Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32)) -> Vector t1 -> c -> Vector t -> Vector a1 +-} +reorderAux :: (TransArray c, Storable a, + Trans c (Int32 -> Ptr a -> Int32 -> Ptr a -> IO Int32) ~ (Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr a -> Int32 -> Ptr a -> IO Int32)) => + p -> Vector Int32 -> c -> Vector a -> Vector a reorderAux f s d v = unsafePerformIO $ do k <- createVector (dim s) r <- createVector (dim v) - (k # s # d # v #! r) f #| "reorderV" + (k # s # d # v #! r) reorderStorable #| "reorderV" return r -type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) +type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32))))) foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float -foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt +foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32 foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z @@ -691,12 +630,12 @@ foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ -- This function is intended to be used internally by tensor libraries. -reorderVector :: Element a - => Vector CInt -- ^ @strides@: array strides - -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ +reorderVector :: Storable a + => Vector Int32 -- ^ @strides@: array strides + -> Vector Int32 -- ^ @dims@: array dimensions of new array @v@ -> Vector a -- ^ @v@: flattened input array -> Vector a -- ^ @v'@: flattened output array -reorderVector = reorderV +reorderVector = reorderAux () -------------------------------------------------------------------------------- -- cgit v1.2.3