From badcbdfddc4be31fc79a6df4553795af18069efe Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Thu, 8 Aug 2019 02:22:30 -0400 Subject: Removed the Element class. --- packages/base/hmatrix.cabal | 1 + packages/base/src/Internal/Algorithms.hs | 5 +- packages/base/src/Internal/Container.hs | 96 ++++++- packages/base/src/Internal/Conversion.hs | 7 +- packages/base/src/Internal/Convolution.hs | 7 +- packages/base/src/Internal/Devel.hs | 21 +- packages/base/src/Internal/Element.hs | 84 +++--- packages/base/src/Internal/Extract.hs | 145 ++++++++++ packages/base/src/Internal/IO.hs | 9 +- packages/base/src/Internal/LAPACK.hs | 19 +- packages/base/src/Internal/Matrix.hs | 307 +++++++++------------ packages/base/src/Internal/Modular.hs | 6 +- packages/base/src/Internal/Numeric.hs | 80 ++++-- packages/base/src/Internal/ST.hs | 131 ++++++++- packages/base/src/Internal/Sparse.hs | 16 +- packages/base/src/Internal/Util.hs | 15 +- packages/base/src/Internal/Vector.hs | 10 +- packages/base/src/Internal/Vectorized.hs | 133 ++++----- packages/base/src/Numeric/LinearAlgebra.hs | 2 +- .../src/Numeric/LinearAlgebra/Tests/Instances.hs | 7 +- 20 files changed, 725 insertions(+), 376 deletions(-) create mode 100644 packages/base/src/Internal/Extract.hs diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal index 4dc62e5..476a293 100644 --- a/packages/base/hmatrix.cabal +++ b/packages/base/hmatrix.cabal @@ -66,6 +66,7 @@ library Internal.Devel Internal.Vectorized Internal.Matrix + Internal.Extract Internal.ST Internal.IO Internal.Element diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index f5bddc6..aa51792 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs @@ -39,6 +39,7 @@ import qualified Data.Vector.Storable as Vector import Internal.ST import Internal.Vectorized(range) import Control.DeepSeq +import Foreign.Storable {- | Generic linear algebra functions for double precision real and complex matrices. @@ -742,7 +743,7 @@ pinvTol t m = v' `mXm` diag s' `mXm` ctrans u' where -- | Numeric rank of a matrix from the SVD decomposition. -rankSVD :: Element t +rankSVD :: Storable t => Double -- ^ numeric zero (e.g. 1*'eps') -> Matrix t -- ^ input matrix m -> Vector Double -- ^ 'sv' of m @@ -1003,7 +1004,7 @@ fixPerm' s = res $ mutable f s0 s0 = reshape 1 (range (length s)) res = flatten . fst swap m i j = rowOper (SWAP i j AllCols) m - f :: (Num t, Element t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies + f :: (Num t, Storable t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies f _ p = sequence_ $ zipWith (swap p) [0..] s triang r c h v = (r>)) #endif @@ -227,7 +236,7 @@ meanCov x = (med,cov) where -------------------------------------------------------------------------------- -sortVector :: (Ord t, Element t) => Vector t -> Vector t +sortVector :: (Ord t, Storable t) => Vector t -> Vector t sortVector = sortV {- | @@ -248,7 +257,7 @@ sortVector = sortV -2.20 0.11 -1.58 -0.01 0.19 -0.29 1.04 1.06 -2.09 -0.75 -} -sortIndex :: (Ord t, Element t) => Vector t -> Vector I +sortIndex :: (Ord t, Storable t) => Vector t -> Vector I sortIndex = sortI ccompare :: (Ord t, Container c t) => c t -> c t -> c I @@ -296,10 +305,91 @@ The indexes are autoconformable. , 10, 16, 22 ] -} -remap :: Element t => Matrix I -> Matrix I -> Matrix t -> Matrix t +remap :: Storable t => Matrix I -> Matrix I -> Matrix t -> Matrix t remap i j m | minElement i >= 0 && maxElement i < fromIntegral (rows m) && minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i' j' m | otherwise = error $ "out of range index in remap" where [i',j'] = conformMs [i,j] + +sortI :: (Storable a, Ord a) => Vector a -> Vector Int32 +sortI = sortG sort_index + +type C_Compare a = Ptr a -> Ptr a -> IO Int32 + +foreign import ccall "wrapper" wrapCompare :: C_Compare a -> IO (FunPtr (C_Compare a)) + +foreign import ccall "qsort" + c_qsort :: Ptr a -- ^ base + -> Word -- ^ nmemb + -> Word -- ^ size + -> FunPtr (C_Compare a) -- ^ compar + -> IO () + +sizeOfElem :: forall a. Storable a => Ptr a -> Int +sizeOfElem _ = sizeOf (undefined :: a) + +sort_index :: (Storable a, Ord a) => + Int32 -> Ptr a + -> Int32 -> Ptr Int32 + -> IO Int32 +sort_index vn vp rn rp = do + requires (vn == rn) BAD_SIZE $ do + comp <- wrapCompare $ \ap bp -> do + a <- peekElemOff vp . fromIntegral =<< peek (ap :: Ptr Int32) + b <- peekElemOff vp . fromIntegral =<< peek bp + return $ case compare a b of + LT -> -1 + GT -> 1 + EQ -> 0 + sequence_ [ pokeElemOff rp (fromIntegral i) i | i <- [0 .. rn-1] ] + c_qsort rp (fromIntegral rn) 4 comp + freeHaskellFunPtr comp + return 0 + +sortV :: (Storable a, Ord a) => Vector a -> Vector a +sortV = sortG sortStorable + +sortStorable :: (Storable a, Ord a) => + Int32 -> Ptr a + -> Int32 -> Ptr a + -> IO Int32 +sortStorable vn vp rn rp = do + requires (vn == rn) BAD_SIZE $ do + copyArray rp vp (fromIntegral vn * sizeOfElem vp) + comp <- wrapCompare $ \ap bp -> do + a <- peek ap + b <- peek bp + return $ case compare a b of + LT -> -1 + GT -> 1 + EQ -> 0 + c_qsort rp (fromIntegral rn) (fromIntegral $ sizeOfElem rp) comp + freeHaskellFunPtr comp + return 0 + +remapM :: Storable a => Matrix Int32 -> Matrix Int32 -> Matrix a -> Matrix a +remapM = remapG remapStorable + +remapStorable :: Storable a => + Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- i + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 -- j + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- m + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -- r + -> IO Int32 +remapStorable ir ic iXr iXc ip + jr jc jXr jXc jp + mr mc mXr mXc mp + rr rc rXr rXc rp = do + requires (ir==jr && ic==jc && ir==rr && ic==rc) BAD_SIZE $ do + ($ 0) $ fix $ \aloop a -> when (a when (b Precision s d | s -> d, d -> s where +class (Storable s, Storable d) => Precision s d | s -> d, d -> s where double2FloatG :: Vector d -> Vector s float2DoubleG :: Vector s -> Vector d @@ -50,7 +51,7 @@ instance Precision I Z where -- | Supported real types -class (Element t, Element (Complex t), RealFloat t) +class (Storable t, Storable (Complex t), RealFloat t) => RealElement t instance RealElement Double @@ -69,7 +70,7 @@ class Complexable c where instance Complexable Vector where toComplex' = toComplexV fromComplex' = fromComplexV - comp' v = toComplex' (v,constantD 0 (dim v)) + comp' v = toComplex' (v,constantAux 0 (dim v)) single' = double2FloatG double' = float2DoubleG diff --git a/packages/base/src/Internal/Convolution.hs b/packages/base/src/Internal/Convolution.hs index 75fbef4..ae8ebc6 100644 --- a/packages/base/src/Internal/Convolution.hs +++ b/packages/base/src/Internal/Convolution.hs @@ -24,12 +24,13 @@ import Internal.Numeric import Internal.Element import Internal.Conversion import Internal.Container +import Foreign.Storable #if MIN_VERSION_base(4,11,0) import Prelude hiding ((<>)) #endif -vectSS :: Element t => Int -> Vector t -> Matrix t +vectSS :: Storable t => Int -> Vector t -> Matrix t vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ] @@ -82,7 +83,7 @@ corrMin ker v -matSS :: Element t => Int -> Matrix t -> [Matrix t] +matSS :: Storable t => Int -> Matrix t -> [Matrix t] matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ] where v = flatten m @@ -155,7 +156,7 @@ conv2 k m empty = r == 0 || c == 0 -separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t +separable :: Storable t => (Vector t -> Vector t) -> Matrix t -> Matrix t -- ^ matrix computation implemented as separated vector operations by rows and columns. separable f = fromColumns . map f . toColumns . fromRows . map f . toRows diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs index f72d8aa..b0594d4 100644 --- a/packages/base/src/Internal/Devel.hs +++ b/packages/base/src/Internal/Devel.hs @@ -13,6 +13,7 @@ module Internal.Devel where import Control.Monad ( when ) +import Data.Int import Foreign.C.Types ( CInt ) --import Foreign.Storable.Complex () import Foreign.Ptr(Ptr) @@ -28,7 +29,7 @@ infixl 0 // -- GSL error codes are <= 1024 -- | error codes for the auxiliary functions required by the wrappers -errorCode :: CInt -> String +errorCode :: Int32 -> String errorCode 2000 = "bad size" errorCode 2001 = "bad function code" errorCode 2002 = "memory problem" @@ -44,7 +45,7 @@ errorCode n = "code "++show n foreign import ccall unsafe "asm_finit" finit :: IO () -- | check the error code -check :: String -> IO CInt -> IO () +check :: String -> IO Int32 -> IO () check msg f = do -- finit err <- f @@ -54,7 +55,7 @@ check msg f = do -- | postfix error code check infixl 0 #| -(#|) :: IO CInt -> String -> IO () +(#|) :: IO Int32 -> String -> IO () (#|) = flip check -- | Error capture and conversion to Maybe @@ -65,12 +66,12 @@ mbCatch act = E.catch (Just `fmap` act) f -------------------------------------------------------------------------------- -type CM b r = CInt -> CInt -> Ptr b -> r -type CV b r = CInt -> Ptr b -> r -type OM b r = CInt -> CInt -> CInt -> CInt -> Ptr b -> r +type CM b r = Int32 -> Int32 -> Ptr b -> r +type CV b r = Int32 -> Ptr b -> r +type OM b r = Int32 -> Int32 -> Int32 -> Int32 -> Ptr b -> r -type CIdxs r = CV CInt r -type Ok = IO CInt +type CIdxs r = CV Int32 r +type Ok = IO Int32 infixr 5 :>, ::>, ..> type (:>) t r = CV t r @@ -87,8 +88,8 @@ class TransArray c instance Storable t => TransArray (Vector t) where - type Trans (Vector t) b = CInt -> Ptr t -> b - type TransRaw (Vector t) b = CInt -> Ptr t -> b + type Trans (Vector t) b = Int32 -> Ptr t -> b + type TransRaw (Vector t) b = Int32 -> Ptr t -> b apply = avec {-# INLINE apply #-} applyRaw = avec diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index 2e330ee..80eda8d 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs @@ -33,14 +33,14 @@ import Data.List.Split(chunksOf) import Foreign.Storable(Storable) import System.IO.Unsafe(unsafePerformIO) import Control.Monad(liftM) -import Foreign.C.Types(CInt) +import Data.Int ------------------------------------------------------------------- import Data.Binary -instance (Binary (Vector a), Element a) => Binary (Matrix a) where +instance (Binary (Vector a), Storable a) => Binary (Matrix a) where put m = do put (cols m) put (flatten m) @@ -52,7 +52,7 @@ instance (Binary (Vector a), Element a) => Binary (Matrix a) where ------------------------------------------------------------------- -instance (Show a, Element a) => (Show (Matrix a)) where +instance (Show a, Storable a) => (Show (Matrix a)) where show m | rows m == 0 || cols m == 0 = sizes m ++" []" show m = (sizes m++) . dsp . map (map show) . toLists $ m @@ -70,7 +70,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw ------------------------------------------------------------------ -instance (Element a, Read a) => Read (Matrix a) where +instance (Storable a, Read a) => Read (Matrix a) where readsPrec _ s = [((rs> Matrix t -> (Extractor,Extractor) -> Matrix t +(??) :: Storable t => Matrix t -> (Extractor,Extractor) -> Matrix t -minEl :: Vector CInt -> CInt +minEl :: Vector Int32 -> Int32 minEl = toScalarI Min -maxEl :: Vector CInt -> CInt +maxEl :: Vector Int32 -> Int32 maxEl = toScalarI Max -cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt +cmodi :: Int32 -> Vector Int32 -> Vector Int32 cmodi = vectorMapValI ModVS extractError :: Matrix t1 -> (Extractor, Extractor) -> t @@ -181,7 +181,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n)) m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) -m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs +m ?? (er,ec) = unsafePerformIO $ extractAux (orderOf m) m moder rs modec cs where (moder,rs) = mkExt (rows m) er (modec,cs) = mkExt (cols m) ec @@ -209,14 +209,14 @@ common f = commonval . map f -- | creates a matrix from a vertical list of matrices -joinVert :: Element t => [Matrix t] -> Matrix t +joinVert :: Storable t => [Matrix t] -> Matrix t joinVert [] = emptyM 0 0 joinVert ms = case common cols ms of Nothing -> error "(impossible) joinVert on matrices with different number of columns" Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) -- | creates a matrix from a horizontal list of matrices -joinHoriz :: Element t => [Matrix t] -> Matrix t +joinHoriz :: Storable t => [Matrix t] -> Matrix t joinHoriz ms = trans. joinVert . map trans $ ms {- | Create a matrix from blocks given as a list of lists of matrices. @@ -240,13 +240,13 @@ disp = putStr . dispf 2 3 3 3 3 3 0 0 3 0 0 -} -fromBlocks :: Element t => [[Matrix t]] -> Matrix t +fromBlocks :: Storable t => [[Matrix t]] -> Matrix t fromBlocks = fromBlocksRaw . adaptBlocks -fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t +fromBlocksRaw :: Storable t => [[Matrix t]] -> Matrix t fromBlocksRaw mms = joinVert . map joinHoriz $ mms -adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] +adaptBlocks :: Storable t => [[Matrix t]] -> [[Matrix t]] adaptBlocks ms = ms' where bc = case common length ms of Just c -> c @@ -258,7 +258,7 @@ adaptBlocks ms = ms' where g [Just nr,Just nc] m | nr == r && nc == c = m - | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) + | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantAux x (nr*nc)) | r == 1 = fromRows (replicate nr (flatten m)) | otherwise = fromColumns (replicate nc (flatten m)) where @@ -288,7 +288,7 @@ adaptBlocks ms = ms' where , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] -} -diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t +diagBlock :: (Storable t, Num t) => [Matrix t] -> Matrix t diagBlock ms = fromBlocks $ zipWith f ms [0..] where f m k = take n $ replicate k z ++ m : repeat z @@ -299,13 +299,13 @@ diagBlock ms = fromBlocks $ zipWith f ms [0..] -- | Reverse rows -flipud :: Element t => Matrix t -> Matrix t +flipud :: Storable t => Matrix t -> Matrix t flipud m = extractRows [r-1,r-2 .. 0] $ m where r = rows m -- | Reverse columns -fliprl :: Element t => Matrix t -> Matrix t +fliprl :: Storable t => Matrix t -> Matrix t fliprl m = extractColumns [c-1,c-2 .. 0] $ m where c = cols m @@ -330,7 +330,7 @@ diagRect z v r c = ST.runSTMatrix $ do return m -- | extracts the diagonal from a rectangular matrix -takeDiag :: (Element t) => Matrix t -> Vector t +takeDiag :: (Storable t) => Matrix t -> Vector t takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] ------------------------------------------------------------ @@ -363,32 +363,32 @@ r >< c = f where ---------------------------------------------------------------- -takeRows :: Element t => Int -> Matrix t -> Matrix t +takeRows :: Storable t => Int -> Matrix t -> Matrix t takeRows n mt = subMatrix (0,0) (n, cols mt) mt -- | Creates a matrix with the last n rows of another matrix -takeLastRows :: Element t => Int -> Matrix t -> Matrix t +takeLastRows :: Storable t => Int -> Matrix t -> Matrix t takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt -dropRows :: Element t => Int -> Matrix t -> Matrix t +dropRows :: Storable t => Int -> Matrix t -> Matrix t dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt -- | Creates a copy of a matrix without the last n rows -dropLastRows :: Element t => Int -> Matrix t -> Matrix t +dropLastRows :: Storable t => Int -> Matrix t -> Matrix t dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt -takeColumns :: Element t => Int -> Matrix t -> Matrix t +takeColumns :: Storable t => Int -> Matrix t -> Matrix t takeColumns n mt = subMatrix (0,0) (rows mt, n) mt -- |Creates a matrix with the last n columns of another matrix -takeLastColumns :: Element t => Int -> Matrix t -> Matrix t +takeLastColumns :: Storable t => Int -> Matrix t -> Matrix t takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt -dropColumns :: Element t => Int -> Matrix t -> Matrix t +dropColumns :: Storable t => Int -> Matrix t -> Matrix t dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt -- | Creates a copy of a matrix without the last n columns -dropLastColumns :: Element t => Int -> Matrix t -> Matrix t +dropLastColumns :: Storable t => Int -> Matrix t -> Matrix t dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt ---------------------------------------------------------------- @@ -402,7 +402,7 @@ dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt , 5.0, 6.0 ] -} -fromLists :: Element t => [[t]] -> Matrix t +fromLists :: Storable t => [[t]] -> Matrix t fromLists = fromRows . map fromList -- | creates a 1-row matrix from a vector @@ -443,7 +443,7 @@ Hilbert matrix of order N: @hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@ -} -buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a +buildMatrix :: Storable a => Int -> Int -> ((Int, Int) -> a) -> Matrix a buildMatrix rc cc f = fromLists $ map (map f) $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)] @@ -458,11 +458,11 @@ fromArray2D m = (r> [Int] -> Matrix t -> Matrix t +extractRows :: Storable t => [Int] -> Matrix t -> Matrix t extractRows l m = m ?? (Pos (idxs l), All) -- | rearranges the rows of a matrix according to the order given in a list of integers. -extractColumns :: Element t => [Int] -> Matrix t -> Matrix t +extractColumns :: Storable t => [Int] -> Matrix t -> Matrix t extractColumns l m = m ?? (All, Pos (idxs l)) @@ -476,13 +476,13 @@ extractColumns l m = m ?? (All, Pos (idxs l)) , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ] -} -repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t +repmat :: (Storable t) => Matrix t -> Int -> Int -> Matrix t repmat m r c | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) | otherwise = fromBlocks $ replicate r $ replicate c $ m -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. -liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t +liftMatrix2Auto :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2Auto f m1 m2 | compat' m1 m2 = lM f m1 m2 | ok = lM f m1' m2' @@ -499,7 +499,7 @@ liftMatrix2Auto f m1 m2 m2' = conformMTo (r,c) m2 -- FIXME do not flatten if equal order -lM :: (Storable t, Element t1, Element t2) +lM :: (Storable t, Storable t1, Storable t2) => (Vector t1 -> Vector t2 -> Vector t) -> Matrix t1 -> Matrix t2 -> Matrix t lM f m1 m2 = matrixFromVector @@ -520,7 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 ------------------------------------------------------------ -toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] +toBlockRows :: Storable t => [Int] -> Matrix t -> [Matrix t] toBlockRows [r] m | r == rows m = [m] toBlockRows rs m @@ -530,13 +530,13 @@ toBlockRows rs m szs = map (* cols m) rs g k = (k><0)[] -toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] +toBlockCols :: Storable t => [Int] -> Matrix t -> [Matrix t] toBlockCols [c] m | c == cols m = [m] toBlockCols cs m = map trans . toBlockRows cs . trans $ m -- | Partition a matrix into blocks with the given numbers of rows and columns. -- The remaining rows and columns are discarded. -toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] +toBlocks :: (Storable t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] toBlocks rs cs m | ok = map (toBlockCols cs) . toBlockRows rs $ m | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs @@ -546,7 +546,7 @@ toBlocks rs cs m -- | Fully partition a matrix into blocks of the same size. If the dimensions are not -- a multiple of the given size the last blocks will be smaller. -toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] +toBlocksEvery :: (Storable t) => Int -> Int -> Matrix t -> [[Matrix t]] toBlocksEvery r c m | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c | otherwise = toBlocks rs cs m @@ -576,7 +576,7 @@ m[1,2] = 6 -} mapMatrixWithIndexM_ - :: (Element a, Num a, Monad m) => + :: (Storable a, Num a, Monad m) => ((Int, Int) -> a -> m ()) -> Matrix a -> m () mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m where @@ -592,7 +592,7 @@ Just (3><3) -} mapMatrixWithIndexM - :: (Element a, Storable b, Monad m) => + :: (Storable a, Storable b, Monad m) => ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m where @@ -608,11 +608,11 @@ mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . fla -} mapMatrixWithIndex - :: (Element a, Storable b) => + :: (Storable a, Storable b) => ((Int, Int) -> a -> b) -> Matrix a -> Matrix b mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m where c = cols m -mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b +mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b mapMatrix f = liftMatrix (mapVector f) diff --git a/packages/base/src/Internal/Extract.hs b/packages/base/src/Internal/Extract.hs new file mode 100644 index 0000000..84ee20f --- /dev/null +++ b/packages/base/src/Internal/Extract.hs @@ -0,0 +1,145 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE NondecreasingIndentation #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE UnboxedTuples #-} +module Internal.Extract where +import Control.Monad +import Data.Complex +import Data.Function +import Data.Int +import Foreign.Ptr +import Foreign.Storable + +type ConstPtr a = Ptr a +pattern ConstPtr a = a + +extractStorable :: Storable t => + Int32 -- int modei + -> Int32 -- int modej + -> Int32 -- / KIVEC(i) + -> ConstPtr Int32 -- \ + -> Int32 -- / KIVEC(j) + -> ConstPtr Int32 -- \ + -> Int32 -- / + -> Int32 -- / + -> Int32 -- { KO##T##MAT(m) + -> Int32 -- \ + -> ConstPtr t -- \ + -> Int32 -- / + -> Int32 -- / + -> Int32 -- { O##T##MAT(r) + -> Int32 -- \ + -> Ptr t -- \ + -> IO Int32 +extractStorable modei + modej + in_ (ConstPtr ip) + jn (ConstPtr jp) + mr mc mXr mXc (ConstPtr mp) + rr rc rXr rXc rp = do + -- int i,j,si,sj,ni,nj; + ni <- if modei/=0 then return in_ + else fmap succ $ (-) <$> peekElemOff ip 1 <*> peekElemOff ip 0 + nj <- if modej/=0 then return jn + else fmap succ $ (-) <$> peekElemOff jp 1 <*> peekElemOff jp 0 + ($ 0) $ fix $ \iloop i -> when (i peek ip + ($ 0) $ fix $ \jloop j -> when (j peek jp + pokeElemOff rp (fromIntegral $ i*rXr + j*rXc) + =<< peekElemOff mp (fromIntegral $ si*mXr + sj*mXc) + jloop $! succ j + iloop $! succ i + return 0 + +{-# SPECIALIZE extractStorable :: + Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Double + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Double + -> IO Int32 #-} + +{-# SPECIALIZE extractStorable :: + Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Float + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Float + -> IO Int32 #-} + +{-# SPECIALIZE extractStorable :: + Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Double) + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Double) + -> IO Int32 #-} + +{-# SPECIALIZE extractStorable :: + Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr (Complex Float) + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr (Complex Float) + -> IO Int32 #-} + +{-# SPECIALIZE extractStorable :: + Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int32 + -> IO Int32 #-} + +{-# SPECIALIZE extractStorable :: + Int32 -> Int32 -> Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> ConstPtr Int64 + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr Int64 + -> IO Int32 #-} + +{- +type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32))))) + +foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double +foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float +foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32 +foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) +foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) +foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z +-} + +-- #define ERROR(CODE) MACRO(return CODE;) +-- #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) + +requires :: Monad m => Bool -> Int32 -> m Int32 -> m Int32 +requires cond code go = + if cond then go + else return code + +pattern BAD_SIZE = 2000 + +reorderStorable :: Storable a => + Int32 -> Ptr Int32 -- k + -> Int32 -> ConstPtr Int32 -- strides + -> Int32 -> ConstPtr Int32 -- dims + -> Int32 -> ConstPtr a -- v + -> Int32 -> Ptr a -- r + -> IO Int32 +reorderStorable kn kp stridesn stridesp dimsn dimsp vn vp rn rp = do + requires (kn == stridesn && stridesn == dimsn) BAD_SIZE $ do + let ijlloop !i !j l fin = do + pokeElemOff kp (fromIntegral l) 0 + dimspl <- peekElemOff dimsp (fromIntegral l) + stridespl <- peekElemOff stridesp (fromIntegral l) + if (l do + requires (i <= vn && j < rn) BAD_SIZE $ do + (\go -> go 0 0) $ fix $ \ijloop i j -> do + pokeElemOff rp (fromIntegral i) =<< peekElemOff vp (fromIntegral j) + (\go -> go (kn - 1) j) $ fix $ \lloop l !j -> do + kpl <- succ <$> peekElemOff kp (fromIntegral l) + pokeElemOff kp (fromIntegral l) kpl + dimspl <- peekElemOff dimsp (fromIntegral l) + if (kpl < dimspl) + then do + stridespl <- peekElemOff stridesp (fromIntegral l) + ijloop (succ i) (j + stridespl) + else do + if l == 0 then return 0 else do + pokeElemOff kp (fromIntegral l) 0 + stridespl <- peekElemOff stridesp (fromIntegral l) + lloop (pred l) (j - stridespl*(dimspl-1)) diff --git a/packages/base/src/Internal/IO.hs b/packages/base/src/Internal/IO.hs index b0f5606..de5eea5 100644 --- a/packages/base/src/Internal/IO.hs +++ b/packages/base/src/Internal/IO.hs @@ -23,6 +23,7 @@ import Internal.Vectorized import Text.Printf(printf, PrintfArg, PrintfType) import Data.List(intersperse,transpose) import Data.Complex +import Foreign.Storable -- | Formatting tool @@ -45,7 +46,7 @@ this function the user can easily define any desired display function: @disp = putStr . format \" \" (printf \"%.2f\")@ -} -format :: (Element t) => String -> (t -> String) -> Matrix t -> String +format :: (Storable t) => String -> (t -> String) -> Matrix t -> String format sep f m = table sep . map (map f) . toLists $ m {- | Show a matrix with \"autoscaling\" and a given number of decimal places. @@ -81,14 +82,14 @@ dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x sdims :: Matrix t -> [Char] sdims x = show (rows x) ++ "x" ++ show (cols x) -formatFixed :: (Show a, Text.Printf.PrintfArg t, Element t) +formatFixed :: (Show a, Text.Printf.PrintfArg t, Storable t) => a -> Matrix t -> String formatFixed d x = format " " (printf ("%."++show d++"f")) $ x isInt :: Matrix Double -> Bool isInt = all lookslikeInt . toList . flatten -formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Element b, Show t) +formatScaled :: (Text.Printf.PrintfArg b, RealFrac b, Floating b, Num t, Storable b, Show t) => t -> Matrix b -> [Char] formatScaled dec t = "E"++show o++"\n" ++ ss where ss = format " " (printf fmt. g) t @@ -104,7 +105,7 @@ formatScaled dec t = "E"++show o++"\n" ++ ss 10 |> 0.00 0.11 0.22 0.33 0.44 0.56 0.67 0.78 0.89 1.00 -} -vecdisp :: (Element t) => (Matrix t -> String) -> Vector t -> String +vecdisp :: (Storable t) => (Matrix t -> String) -> Vector t -> String vecdisp f v = ((show (dim v) ++ " |> ") ++) . (++"\n") . unwords . lines . tail . dropWhile (not . (`elem` " \n")) diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 27d1f95..d88ff6b 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -22,9 +22,12 @@ import Data.Bifunctor (first) import Internal.Devel import Internal.Vector +import Internal.Vectorized (constantAux) import Internal.Matrix hiding ((#), (#!)) import Internal.Conversion import Internal.Element +import Internal.ST (setRect) +import Data.Int import Foreign.Ptr(nullPtr) import Foreign.C.Types import Control.Monad(when) @@ -46,10 +49,10 @@ type TMMM t = t ::> t ::> t ::> Ok type F = Float type Q = Complex Float -foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R -foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C -foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F -foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q +foreign import ccall unsafe "multiplyR" dgemmc :: Int32 -> Int32 -> TMMM R +foreign import ccall unsafe "multiplyC" zgemmc :: Int32 -> Int32 -> TMMM C +foreign import ccall unsafe "multiplyF" sgemmc :: Int32 -> Int32 -> TMMM F +foreign import ccall unsafe "multiplyQ" cgemmc :: Int32 -> Int32 -> TMMM Q foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z @@ -82,7 +85,7 @@ multiplyF a b = multiplyAux sgemmc "sgemmc" a b multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) multiplyQ a b = multiplyAux cgemmc "cgemmc" a b -multiplyI :: I -> Matrix CInt -> Matrix CInt -> Matrix CInt +multiplyI :: I -> Matrix Int32 -> Matrix Int32 -> Matrix Int32 multiplyI m a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b @@ -239,8 +242,8 @@ foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok foreign import ccall unsafe "eig_l_G" dggev :: R ::> R ::> C :> R :> R ::> R ::> Ok foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok foreign import ccall unsafe "eig_l_GC" zggev :: C ::> C ::> C :> C :> C ::> C ::> Ok -foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R :> R ::> Ok -foreign import ccall unsafe "eig_l_H" zheev :: CInt -> R :> C ::> Ok +foreign import ccall unsafe "eig_l_S" dsyev :: Int32 -> R :> R ::> Ok +foreign import ccall unsafe "eig_l_H" zheev :: Int32 -> R :> C ::> Ok eigAux f st m = unsafePerformIO $ do a <- copy ColumnMajor m @@ -636,7 +639,7 @@ qrgrAux f st n (a, tau) = unsafePerformIO $ do ((subVector 0 n tau') #! res) f #| st return res where - tau' = vjoin [tau, constantD 0 n] + tau' = vjoin [tau, constantAux 0 n] ----------------------------------------------------------------------------------- foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 5436e59..04092f9 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -2,6 +2,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ViewPatterns #-} @@ -22,12 +23,14 @@ module Internal.Matrix where import Internal.Vector import Internal.Devel +import Internal.Extract import Internal.Vectorized hiding ((#), (#!)) import Foreign.Marshal.Alloc ( free ) import Foreign.Marshal.Array(newArray) import Foreign.Ptr ( Ptr ) import Foreign.Storable ( Storable ) import Data.Complex ( Complex ) +import Data.Int import Foreign.C.Types ( CInt(..) ) import Foreign.C.String ( CString, newCString ) import System.IO.Unsafe ( unsafePerformIO ) @@ -61,19 +64,23 @@ size :: Matrix t -> (Int, Int) size m = (irows m, icols m) {-# INLINE size #-} +-- | True if the matrix is in RowMajor form. rowOrder :: Matrix t -> Bool rowOrder m = xCol m == 1 || cols m == 1 {-# INLINE rowOrder #-} +-- | True if the matrix is in ColMajor form or if their is only one row. colOrder :: Matrix t -> Bool colOrder m = xRow m == 1 || rows m == 1 {-# INLINE colOrder #-} +-- | True if the matrix is a single row or column vector. is1d :: Matrix t -> Bool is1d (size->(r,c)) = r==1 || c==1 {-# INLINE is1d #-} --- data is not contiguous +-- | True if the matrix is not contiguous. This usually +-- means it is a slice of some larger matrix. isSlice :: Storable t => Matrix t -> Bool isSlice m@(size->(r,c)) = r*c < dim (xdat m) {-# INLINE isSlice #-} @@ -95,19 +102,23 @@ showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv -------------------------------------------------------------------------------- --- | Matrix transpose. +-- | O(1) Matrix transpose. This is only a logical transposition that does not +-- re-order the element storage. If the storage order is important, use 'cmat' +-- or 'fmat'. trans :: Matrix t -> Matrix t trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = m { irows = c, icols = r, xRow = xc, xCol = xr } -cmat :: (Element t) => Matrix t -> Matrix t +-- | Obtain the RowMajor equivalent of a given Matrix. +cmat :: (Storable t) => Matrix t -> Matrix t cmat m | rowOrder m = m | otherwise = extractAll RowMajor m -fmat :: (Element t) => Matrix t -> Matrix t +-- | Obtain the ColumnMajor equivalent of a given Matrix. +fmat :: (Storable t) => Matrix t -> Matrix t fmat m | colOrder m = m | otherwise = extractAll ColumnMajor m @@ -115,14 +126,14 @@ fmat m -- C-Haskell matrix adapters {-# INLINE amatr #-} -amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r +amatr :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Ptr a -> f) -> IO r amatr x f g = unsafeWith (xdat x) (f . g r c) where r = fi (rows x) c = fi (cols x) {-# INLINE amat #-} -amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r +amat :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> f) -> IO r amat x f g = unsafeWith (xdat x) (f . g r c sr sc) where r = fi (rows x) @@ -133,8 +144,8 @@ amat x f g = unsafeWith (xdat x) (f . g r c sr sc) instance Storable t => TransArray (Matrix t) where - type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b - type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b + type TransRaw (Matrix t) b = Int32 -> Int32 -> Ptr t -> b + type Trans (Matrix t) b = Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -> b apply = amat {-# INLINE apply #-} applyRaw = amatr @@ -151,10 +162,10 @@ a #! b = a # b # id -------------------------------------------------------------------------------- -copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) -copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) +copy :: Storable t => MatrixOrder -> Matrix t -> IO (Matrix t) +copy ord m = extractAux ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) -extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t +extractAll :: Storable t => MatrixOrder -> Matrix t -> Matrix t extractAll ord m = unsafePerformIO (copy ord m) {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. @@ -164,14 +175,14 @@ extractAll ord m = unsafePerformIO (copy ord m) it :: (Num t, Element t) => Vector t -} -flatten :: Element t => Matrix t -> Vector t +flatten :: Storable t => Matrix t -> Vector t flatten m | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) | otherwise = xdat m -- | the inverse of 'Data.Packed.Matrix.fromLists' -toLists :: (Element t) => Matrix t -> [[t]] +toLists :: (Storable t) => Matrix t -> [[t]] toLists = map toList . toRows @@ -192,7 +203,7 @@ compatdim (a:b:xs) -- | Create a matrix from a list of vectors. -- All vectors must have the same dimension, -- or dimension 1, which is are automatically expanded. -fromRows :: Element t => [Vector t] -> Matrix t +fromRows :: Storable t => [Vector t] -> Matrix t fromRows [] = emptyM 0 0 fromRows vs = case compatdim (map dim vs) of Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) @@ -203,25 +214,25 @@ fromRows vs = case compatdim (map dim vs) of adapt c v | c == 0 = fromList[] | dim v == c = v - | otherwise = constantD (v@>0) c + | otherwise = constantAux (v@>0) c -- | extracts the rows of a matrix as a list of vectors -toRows :: Element t => Matrix t -> [Vector t] +toRows :: Storable t => Matrix t -> [Vector t] toRows m | rowOrder m = map sub rowRange | otherwise = map ext rowRange where rowRange = [0..rows m-1] sub k = subVector (k*xRow m) (cols m) (xdat m) - ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) + ext k = xdat $ unsafePerformIO $ extractAux RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) -- | Creates a matrix from a list of vectors, as columns -fromColumns :: Element t => [Vector t] -> Matrix t +fromColumns :: Storable t => [Vector t] -> Matrix t fromColumns m = trans . fromRows $ m -- | Creates a list of vectors from the columns of a matrix -toColumns :: Element t => Matrix t -> [Vector t] +toColumns :: Storable t => Matrix t -> [Vector t] toColumns m = toRows . trans $ m -- | Reads a matrix position. @@ -271,13 +282,13 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v -- | application of a vector function on the flattened matrix elements -liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b +liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) | otherwise = matrixFromVector (orderOf m) r c (f d) -- | application of a vector function on the flattened matrices elements -liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t +liftMatrix2 :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2 f m1@(size->(r,c)) m2 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) @@ -285,103 +296,8 @@ liftMatrix2 f m1@(size->(r,c)) m2 ------------------------------------------------------------------ --- | Supported matrix elements. -class (Storable a) => Element a where - constantD :: a -> Int -> Vector a - extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) - setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () - sortI :: Ord a => Vector a -> Vector CInt - sortV :: Ord a => Vector a -> Vector a - compareV :: Ord a => Vector a -> Vector a -> Vector CInt - selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a - remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a - rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () - gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () - reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation - - -instance Element Float where - constantD = constantAux cconstantF - extractR = extractAux c_extractF - setRect = setRectAux c_setRectF - sortI = sortIdxF - sortV = sortValF - compareV = compareF - selectV = selectF - remapM = remapF - rowOp = rowOpAux c_rowOpF - gemm = gemmg c_gemmF - reorderV = reorderAux c_reorderF - -instance Element Double where - constantD = constantAux cconstantR - extractR = extractAux c_extractD - setRect = setRectAux c_setRectD - sortI = sortIdxD - sortV = sortValD - compareV = compareD - selectV = selectD - remapM = remapD - rowOp = rowOpAux c_rowOpD - gemm = gemmg c_gemmD - reorderV = reorderAux c_reorderD - -instance Element (Complex Float) where - constantD = constantAux cconstantQ - extractR = extractAux c_extractQ - setRect = setRectAux c_setRectQ - sortI = undefined - sortV = undefined - compareV = undefined - selectV = selectQ - remapM = remapQ - rowOp = rowOpAux c_rowOpQ - gemm = gemmg c_gemmQ - reorderV = reorderAux c_reorderQ - -instance Element (Complex Double) where - constantD = constantAux cconstantC - extractR = extractAux c_extractC - setRect = setRectAux c_setRectC - sortI = undefined - sortV = undefined - compareV = undefined - selectV = selectC - remapM = remapC - rowOp = rowOpAux c_rowOpC - gemm = gemmg c_gemmC - reorderV = reorderAux c_reorderC - -instance Element (CInt) where - constantD = constantAux cconstantI - extractR = extractAux c_extractI - setRect = setRectAux c_setRectI - sortI = sortIdxI - sortV = sortValI - compareV = compareI - selectV = selectI - remapM = remapI - rowOp = rowOpAux c_rowOpI - gemm = gemmg c_gemmI - reorderV = reorderAux c_reorderI - -instance Element Z where - constantD = constantAux cconstantL - extractR = extractAux c_extractL - setRect = setRectAux c_setRectL - sortI = sortIdxL - sortV = sortValL - compareV = compareL - selectV = selectL - remapM = remapL - rowOp = rowOpAux c_rowOpL - gemm = gemmg c_gemmL - reorderV = reorderAux c_reorderL - -------------------------------------------------------------------- - -- | reference to a rectangular slice of a matrix (no data copy) -subMatrix :: Element a +subMatrix :: Storable a => (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -- ^ input matrix @@ -402,34 +318,34 @@ subMatrix (r0,c0) (rt,ct) m maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 maxZ xs = if minimum xs == 0 then 0 else maximum xs -conformMs :: Element t => [Matrix t] -> [Matrix t] +conformMs :: Storable t => [Matrix t] -> [Matrix t] conformMs ms = map (conformMTo (r,c)) ms where r = maxZ (map rows ms) c = maxZ (map cols ms) -conformVs :: Element t => [Vector t] -> [Vector t] +conformVs :: Storable t => [Vector t] -> [Vector t] conformVs vs = map (conformVTo n) vs where n = maxZ (map dim vs) -conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t +conformMTo :: Storable t => (Int, Int) -> Matrix t -> Matrix t conformMTo (r,c) m | size m == (r,c) = m - | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) + | size m == (1,1) = matrixFromVector RowMajor r c (constantAux (m@@>(0,0)) (r*c)) | size m == (r,1) = repCols c m | size m == (1,c) = repRows r m | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) -conformVTo :: Element t => Int -> Vector t -> Vector t +conformVTo :: Storable t => Int -> Vector t -> Vector t conformVTo n v | dim v == n = v - | dim v == 1 = constantD (v@>0) n + | dim v == 1 = constantAux (v@>0) n | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n -repRows :: Element t => Int -> Matrix t -> Matrix t +repRows :: Storable t => Int -> Matrix t -> Matrix t repRows n x = fromRows (replicate n (flatten x)) -repCols :: Element t => Int -> Matrix t -> Matrix t +repCols :: Storable t => Int -> Matrix t -> Matrix t repCols n x = fromColumns (replicate n (flatten x)) shSize :: Matrix t -> [Char] @@ -453,32 +369,50 @@ instance (Storable t, NFData t) => NFData (Matrix t) --------------------------------------------------------------- +{- extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, Storable t, Num t3, Num t2, Integral t1, Integral t) - => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t - -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) - -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) -extractAux f ord m moder vr modec vc = do + => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) -- f + -> MatrixOrder -- ord + -> c -- m + -> t3 -- moder + -> Vector t1 -- vr + -> t2 -- modec + -> Vector t -- vc + -> IO (Matrix a) +-} + +extractAux :: Storable a => + MatrixOrder + -> Matrix a + -> Int32 + -> Vector Int32 + -> Int32 + -> Vector Int32 + -> IO (Matrix a) +extractAux ord m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc r <- createMatrix ord nr nc - (vr # vc # m #! r) (f moder modec) #|"extract" + (vr # vc # m #! r) (extractStorable moder modec) #|"extract" return r -type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) +{- +type Extr x = Int32 -> Int32 -> CIdxs (CIdxs (OM x (OM x (IO Int32)))) foreign import ccall unsafe "extractD" c_extractD :: Extr Double foreign import ccall unsafe "extractF" c_extractF :: Extr Float foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) -foreign import ccall unsafe "extractI" c_extractI :: Extr CInt +foreign import ccall unsafe "extractI" c_extractI :: Extr Int32 foreign import ccall unsafe "extractL" c_extractL :: Extr Z +-} --------------------------------------------------------------- setRectAux :: (TransArray c1, TransArray c) - => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) + => (Int32 -> Int32 -> Trans c1 (Trans c (IO Int32))) -> Int -> Int -> c1 -> c -> IO () setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" @@ -494,17 +428,17 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z -------------------------------------------------------------------------------- sortG :: (Storable t, Storable a) - => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a + => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a sortG f v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) f #|"sortG" return r -sortIdxD :: Vector Double -> Vector CInt +sortIdxD :: Vector Double -> Vector Int32 sortIdxD = sortG c_sort_indexD -sortIdxF :: Vector Float -> Vector CInt +sortIdxF :: Vector Float -> Vector Int32 sortIdxF = sortG c_sort_indexF -sortIdxI :: Vector CInt -> Vector CInt +sortIdxI :: Vector Int32 -> Vector Int32 sortIdxI = sortG c_sort_indexI sortIdxL :: Vector Z -> Vector I sortIdxL = sortG c_sort_indexL @@ -513,81 +447,81 @@ sortValD :: Vector Double -> Vector Double sortValD = sortG c_sort_valD sortValF :: Vector Float -> Vector Float sortValF = sortG c_sort_valF -sortValI :: Vector CInt -> Vector CInt +sortValI :: Vector Int32 -> Vector Int32 sortValI = sortG c_sort_valI sortValL :: Vector Z -> Vector Z sortValL = sortG c_sort_valL -foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV Int32 (IO Int32)) +foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV Int32 (IO Int32)) +foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV Int32 (CV Int32 (IO Int32)) foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok -foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) -foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) -foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO Int32)) +foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO Int32)) +foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV Int32 (CV Int32 (IO Int32)) foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok -------------------------------------------------------------------------------- compareG :: (TransArray c, Storable t, Storable a) - => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) + => Trans c (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> c -> Vector t -> Vector a compareG f u v = unsafePerformIO $ do r <- createVector (dim v) (u # v #! r) f #|"compareG" return r -compareD :: Vector Double -> Vector Double -> Vector CInt +compareD :: Vector Double -> Vector Double -> Vector Int32 compareD = compareG c_compareD -compareF :: Vector Float -> Vector Float -> Vector CInt +compareF :: Vector Float -> Vector Float -> Vector Int32 compareF = compareG c_compareF -compareI :: Vector CInt -> Vector CInt -> Vector CInt +compareI :: Vector Int32 -> Vector Int32 -> Vector Int32 compareI = compareG c_compareI -compareL :: Vector Z -> Vector Z -> Vector CInt +compareL :: Vector Z -> Vector Z -> Vector Int32 compareL = compareG c_compareL -foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) -foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) -foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) +foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV Int32 (IO Int32))) +foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV Int32 (IO Int32))) +foreign import ccall unsafe "compareI" c_compareI :: CV Int32 (CV Int32 (CV Int32 (IO Int32))) foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok -------------------------------------------------------------------------------- selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) - => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) + => Trans c2 (Trans c1 (Int32 -> Ptr t -> Trans c (Int32 -> Ptr a -> IO Int32))) -> c2 -> c1 -> Vector t -> c -> Vector a selectG f c u v w = unsafePerformIO $ do r <- createVector (dim v) (c # u # v # w #! r) f #|"selectG" return r -selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double +selectD :: Vector Int32 -> Vector Double -> Vector Double -> Vector Double -> Vector Double selectD = selectG c_selectD -selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float +selectF :: Vector Int32 -> Vector Float -> Vector Float -> Vector Float -> Vector Float selectF = selectG c_selectF -selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt +selectI :: Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 selectI = selectG c_selectI -selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z +selectL :: Vector Int32 -> Vector Z -> Vector Z -> Vector Z -> Vector Z selectL = selectG c_selectL -selectC :: Vector CInt +selectC :: Vector Int32 -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) selectC = selectG c_selectC -selectQ :: Vector CInt +selectQ :: Vector Int32 -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) selectQ = selectG c_selectQ -type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) +type Sel x = CV Int32 (CV x (CV x (CV x (CV x (IO Int32))))) foreign import ccall unsafe "chooseD" c_selectD :: Sel Double foreign import ccall unsafe "chooseF" c_selectF :: Sel Float -foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt +foreign import ccall unsafe "chooseI" c_selectI :: Sel Int32 foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) foreign import ccall unsafe "chooseL" c_selectL :: Sel Z @@ -595,35 +529,35 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z --------------------------------------------------------------------------- remapG :: (TransArray c, TransArray c1, Storable t, Storable a) - => (CInt -> CInt -> CInt -> CInt -> Ptr t - -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) + => (Int32 -> Int32 -> Int32 -> Int32 -> Ptr t + -> Trans c1 (Trans c (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> IO Int32))) -> Matrix t -> c1 -> c -> Matrix a remapG f i j m = unsafePerformIO $ do r <- createMatrix RowMajor (rows i) (cols i) (i # j # m #! r) f #|"remapG" return r -remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double +remapD :: Matrix Int32 -> Matrix Int32 -> Matrix Double -> Matrix Double remapD = remapG c_remapD -remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float +remapF :: Matrix Int32 -> Matrix Int32 -> Matrix Float -> Matrix Float remapF = remapG c_remapF -remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt +remapI :: Matrix Int32 -> Matrix Int32 -> Matrix Int32 -> Matrix Int32 remapI = remapG c_remapI -remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z +remapL :: Matrix Int32 -> Matrix Int32 -> Matrix Z -> Matrix Z remapL = remapG c_remapL -remapC :: Matrix CInt - -> Matrix CInt +remapC :: Matrix Int32 + -> Matrix Int32 -> Matrix (Complex Double) -> Matrix (Complex Double) remapC = remapG c_remapC -remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) +remapQ :: Matrix Int32 -> Matrix Int32 -> Matrix (Complex Float) -> Matrix (Complex Float) remapQ = remapG c_remapQ -type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) +type Rem x = OM Int32 (OM Int32 (OM x (OM x (IO Int32)))) foreign import ccall unsafe "remapD" c_remapD :: Rem Double foreign import ccall unsafe "remapF" c_remapF :: Rem Float -foreign import ccall unsafe "remapI" c_remapI :: Rem CInt +foreign import ccall unsafe "remapI" c_remapI :: Rem Int32 foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) foreign import ccall unsafe "remapL" c_remapL :: Rem Z @@ -631,14 +565,14 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z -------------------------------------------------------------------------------- rowOpAux :: (TransArray c, Storable a) => - (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) + (Int32 -> Ptr a -> Int32 -> Int32 -> Int32 -> Int32 -> Trans c (IO Int32)) -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () rowOpAux f c x i1 i2 j1 j2 m = do px <- newArray [x] (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" free px -type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok +type RowOp x = Int32 -> Ptr x -> Int32 -> Int32 -> Int32 -> Int32 -> x ::> Ok foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float @@ -652,7 +586,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) - => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) + => Trans c3 (Trans c2 (Trans c1 (Trans c (IO Int32)))) -> c3 -> c2 -> c1 -> c -> IO () gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" @@ -669,21 +603,26 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z -------------------------------------------------------------------------------- +{- reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => - (CInt -> Ptr a -> CInt -> Ptr t1 - -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) + (Int32 -> Ptr a -> Int32 -> Ptr t1 + -> Trans c (Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32)) -> Vector t1 -> c -> Vector t -> Vector a1 +-} +reorderAux :: (TransArray c, Storable a, + Trans c (Int32 -> Ptr a -> Int32 -> Ptr a -> IO Int32) ~ (Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr a -> Int32 -> Ptr a -> IO Int32)) => + p -> Vector Int32 -> c -> Vector a -> Vector a reorderAux f s d v = unsafePerformIO $ do k <- createVector (dim s) r <- createVector (dim v) - (k # s # d # v #! r) f #| "reorderV" + (k # s # d # v #! r) reorderStorable #| "reorderV" return r -type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) +type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32))))) foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float -foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt +foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32 foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z @@ -691,12 +630,12 @@ foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ -- This function is intended to be used internally by tensor libraries. -reorderVector :: Element a - => Vector CInt -- ^ @strides@: array strides - -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ +reorderVector :: Storable a + => Vector Int32 -- ^ @strides@: array strides + -> Vector Int32 -- ^ @dims@: array dimensions of new array @v@ -> Vector a -- ^ @v@: flattened input array -> Vector a -- ^ @v'@: flattened output array -reorderVector = reorderV +reorderVector = reorderAux () -------------------------------------------------------------------------------- diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index eb0c5a8..e67aa67 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -135,6 +135,7 @@ instance (Integral t, KnownNat n) => Num (Mod n t) fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) +#if 0 instance KnownNat m => Element (Mod m I) where constantD x n = i2f (constantD (unMod x) n) @@ -168,6 +169,7 @@ instance KnownNat m => Element (Mod m Z) gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) where m' = fromIntegral . natVal $ (undefined :: Proxy m) +#endif instance KnownNat m => CTrans (Mod m I) @@ -306,10 +308,10 @@ f2i :: Storable t => Vector (Mod n t) -> Vector t f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) where (fp,i,n) = unsafeToForeignPtr v -f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t +f2iM :: (Storable t, Storable (Mod n t)) => Matrix (Mod n t) -> Matrix t f2iM m = m { xdat = f2i (xdat m) } -i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t) +i2fM :: (Storable t, Storable (Mod n t)) => Matrix t -> Matrix (Mod n t) i2fM m = m { xdat = i2f (xdat m) } vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) diff --git a/packages/base/src/Internal/Numeric.hs b/packages/base/src/Internal/Numeric.hs index fd0a217..4f7bb82 100644 --- a/packages/base/src/Internal/Numeric.hs +++ b/packages/base/src/Internal/Numeric.hs @@ -4,6 +4,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE PatternSynonyms #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-} @@ -22,12 +23,18 @@ module Internal.Numeric where import Internal.Vector import Internal.Matrix import Internal.Element +import Internal.Extract (requires,pattern BAD_SIZE) import Internal.ST as ST import Internal.Conversion import Internal.Vectorized import Internal.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ,multiplyI,multiplyL) +import Control.Monad +import Data.Function +import Data.Int import Data.List.Split(chunksOf) import qualified Data.Vector.Storable as V +import Foreign.Ptr +import Foreign.Storable -------------------------------------------------------------------------------- @@ -44,7 +51,7 @@ type instance ArgOf Matrix a = a -> a -> a -------------------------------------------------------------------------------- -- | Basic element-by-element functions for numeric containers -class Element e => Container c e +class Storable e => Container c e where conj' :: c e -> c e size' :: c e -> IndexOf c @@ -56,7 +63,7 @@ class Element e => Container c e -- | element by element multiplication mul :: c e -> c e -> c e equal :: c e -> c e -> Bool - cmap' :: (Element b) => (e -> b) -> c e -> c b + cmap' :: (Storable b) => (e -> b) -> c e -> c b konst' :: e -> IndexOf c -> c e build' :: IndexOf c -> (ArgOf c e) -> c e atIndex' :: c e -> IndexOf c -> e @@ -107,7 +114,7 @@ instance Container Vector I mul = vectorZipI Mul equal = (==) scalar' = V.singleton - konst' = constantD + konst' = constantAux build' = buildV cmap' = mapVector atIndex' = (@>) @@ -146,7 +153,7 @@ instance Container Vector Z mul = vectorZipL Mul equal = (==) scalar' = V.singleton - konst' = constantD + konst' = constantAux build' = buildV cmap' = mapVector atIndex' = (@>) @@ -186,7 +193,7 @@ instance Container Vector Float mul = vectorZipF Mul equal = (==) scalar' = V.singleton - konst' = constantD + konst' = constantAux build' = buildV cmap' = mapVector atIndex' = (@>) @@ -223,7 +230,7 @@ instance Container Vector Double mul = vectorZipR Mul equal = (==) scalar' = V.singleton - konst' = constantD + konst' = constantAux build' = buildV cmap' = mapVector atIndex' = (@>) @@ -260,7 +267,7 @@ instance Container Vector (Complex Double) mul = vectorZipC Mul equal = (==) scalar' = V.singleton - konst' = constantD + konst' = constantAux build' = buildV cmap' = mapVector atIndex' = (@>) @@ -296,7 +303,7 @@ instance Container Vector (Complex Float) mul = vectorZipQ Mul equal = (==) scalar' = V.singleton - konst' = constantD + konst' = constantAux build' = buildV cmap' = mapVector atIndex' = (@>) @@ -323,7 +330,7 @@ instance Container Vector (Complex Float) --------------------------------------------------------------- -instance (Num a, Element a, Container Vector a) => Container Matrix a +instance (Num a, Storable a, Container Vector a) => Container Matrix a where conj' = liftMatrix conj' size' = size @@ -418,8 +425,8 @@ fromZ = fromZ' toZ :: (Container c e) => c e -> c Z toZ = toZ' --- | like 'fmap' (cannot implement instance Functor because of Element class constraint) -cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b +-- | like 'fmap' (cannot implement instance Functor because of Storable class constraint) +cmap :: (Storable b, Container c e) => (e -> b) -> c e -> c b cmap = cmap' -- | generic indexing function @@ -470,7 +477,7 @@ step step = step' --- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. +-- | Storable by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@. -- -- Arguments with any dimension = 1 are automatically expanded: -- @@ -598,7 +605,7 @@ instance Numeric Z -------------------------------------------------------------------------------- -- | Matrix product and related functions -class (Num e, Element e) => Product e where +class (Num e, Storable e) => Product e where -- | matrix product multiply :: Matrix e -> Matrix e -> Matrix e -- | sum of absolute value of elements (differs in complex case from @norm1@) @@ -823,12 +830,12 @@ buildV n f = fromList [f k | k <- ks] -------------------------------------------------------- -- | Creates a square matrix with a given diagonal. -diag :: (Num a, Element a) => Vector a -> Matrix a +diag :: (Num a, Storable a) => Vector a -> Matrix a diag v = diagRect 0 v n n where n = dim v -- | creates the identity matrix of given dimension -ident :: (Num a, Element a) => Int -> Matrix a -ident n = diag (constantD 1 n) +ident :: (Num a, Storable a) => Int -> Matrix a +ident n = diag (constantAux 1 n) -------------------------------------------------------- @@ -943,3 +950,44 @@ class Testable t -------------------------------------------------------------------------------- +compareV :: (Storable a, Ord a) => Vector a -> Vector a -> Vector Int32 +compareV = compareG compareStorable + +compareStorable :: (Storable a, Ord a) => + Int32 -> Ptr a + -> Int32 -> Ptr a + -> Int32 -> Ptr Int32 + -> IO Int32 +compareStorable xn xp yn yp rn rp = do + requires (xn==yn && xn==rn) BAD_SIZE $ do + ($ 0) $ fix $ \kloop k -> when (k -1 + GT -> 1 + EQ -> 0 + kloop (succ k) + return 0 + +selectV :: Storable a => Vector Int32 -> Vector a -> Vector a -> Vector a -> Vector a +selectV = selectG selectStorable + +selectStorable :: Storable a => + Int32 -> Ptr Int32 + -> Int32 -> Ptr a + -> Int32 -> Ptr a + -> Int32 -> Ptr a + -> Int32 -> Ptr a + -> IO Int32 +selectStorable condn condp ltn ltp eqn eqp gtn gtp rn rp = do + requires (condn==ltn && ltn==eqn && ltn==gtn && ltn==rn) BAD_SIZE $ do + ($ 0) $ fix $ \kloop k -> when (k peekElemOff ltp (fromIntegral k) + GT -> peekElemOff gtp (fromIntegral k) + EQ -> peekElemOff eqp (fromIntegral k) + kloop (succ k) + return 0 + diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 7d54e6d..326b90a 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -1,6 +1,7 @@ {-# LANGUAGE Rank2Types #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} ----------------------------------------------------------------------------- -- | @@ -30,14 +31,20 @@ module Internal.ST ( unsafeThawVector, unsafeFreezeVector, newUndefinedMatrix, unsafeReadMatrix, unsafeWriteMatrix, - unsafeThawMatrix, unsafeFreezeMatrix + unsafeThawMatrix, unsafeFreezeMatrix, + setRect ) where import Internal.Vector import Internal.Matrix import Internal.Vectorized +import Internal.Devel ((#|)) import Control.Monad.ST(ST, runST) -import Foreign.Storable(Storable, peekElemOff, pokeElemOff) +import Control.Monad +import Data.Function +import Data.Int +import Foreign.Ptr +import Foreign.Storable import Control.Monad.ST.Unsafe(unsafeIOToST) {-# INLINE ioReadV #-} @@ -121,7 +128,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val newtype STMatrix s t = STMatrix (Matrix t) -thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t) +thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) @@ -142,17 +149,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c -liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a +liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x -freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) +freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) freezeMatrix m = liftSTMatrix id m -cloneMatrix :: Element t => Matrix t -> IO (Matrix t) +cloneMatrix :: Storable t => Matrix t -> IO (Matrix t) cloneMatrix m = copy (orderOf m) m {-# INLINE safeIndexM #-} @@ -172,7 +179,7 @@ readMatrix = safeIndexM unsafeReadMatrix writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () writeMatrix = safeIndexM unsafeWriteMatrix -setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () +setMatrix :: Storable t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) @@ -210,7 +217,7 @@ data RowOper t = AXPY t Int Int ColRange | SCAL t RowRange ColRange | SWAP Int Int ColRange -rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () +rowOper :: (Num t, Storable t) => RowOper t -> STMatrix s t -> ST s () rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m where @@ -230,8 +237,8 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m i2' = i2 `mod` (rows m) -extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) -extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) +extractMatrix :: Storable a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) +extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractAux (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) where (i1,i2) = getRowRange (rows m) rr (j1,j2) = getColRange (cols m) rc @@ -239,19 +246,117 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ -- | r0 c0 height width data Slice s t = Slice (STMatrix s t) Int Int Int Int -slice :: Element a => Slice t a -> Matrix a +slice :: Storable a => Slice t a -> Matrix a slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m -gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () +gemmm :: (Storable t, Num t) => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () gemmm beta (slice->r) alpha (slice->a) (slice->b) = res where res = unsafeIOToST (gemm v a b r) v = fromList [alpha,beta] -mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) +mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) mutable f a = runST $ do x <- thawMatrix a info <- f (rows a, cols a) x r <- unsafeFreezeMatrix x return (r,info) + + + +setRect :: Storable t => Int -> Int -> Matrix t -> Matrix t -> IO () +setRect i j m r = (m Internal.Matrix.#! r) (setRectStorable (fi i) (fi j)) #|"setRect" + +setRectStorable :: Storable t => + Int32 -> Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> {- const -} Ptr t + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t + -> IO Int32 +setRectStorable i j mr mc mXr mXc mp rr rc rXr rXc rp = do + ($ 0) $ fix $ \aloop a -> when (a when (b Int -> t -> Int -> Int -> Int -> Int -> Matrix t -> IO () +rowOp = rowOpAux rowOpStorable + +pattern BAD_CODE = 2001 + +rowOpStorable :: (Storable t, Num t) => + Int32 -> Ptr t -> Int32 -> Int32 -> Int32 -> Int32 + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t + -> IO Int32 +rowOpStorable 0 pa i1 i2 j1 j2 rr rc rXr rXc rp = do + -- AXPY_IMP + a <- peek pa + ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do + ri1j <- peekElemOff rp $ fromIntegral $ rXr*i1 + rXc*j + let i2j = fromIntegral $ rXr*i2 + rXc*j + ri2j <- peekElemOff rp i2j + pokeElemOff rp i2j $ ri2j + a*ri1j + jloop (succ j) + return 0 +rowOpStorable 1 pa i1 i2 j1 j2 rr rc rXr rXc rp = do + -- SCAL_IMP + a <- peek pa + ($ i1) $ fix $ \iloop i -> when (i<=i2) $ do + ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do + let rijp = rp `plusPtr` fromIntegral (rXr*i + rXc*j) + rij <- peek rijp + poke rijp $ a * rij + jloop (succ j) + iloop (succ i) + return 0 +rowOpStorable 2 pa i1 i2 j1 j2 rr rc rXr rXc rp | i1 == i2 = return 0 +rowOpStorable 2 pa i1 i2 j1 j2 rr rc rXr rXc rp = do + -- SWAP_IMP + ($ j1) $ fix $ \kloop k -> when (k<=j2) $ do + let i1k = fromIntegral $ rXr*i1 + rXc*k + i2k = fromIntegral $ rXr*i2 + rXc*k + aux <- peekElemOff rp i1k + pokeElemOff rp i1k =<< peekElemOff rp i2k + pokeElemOff rp i2k aux + kloop (succ k) + return 0 +rowOpStorable _ pa i1 i2 j1 j2 rr rc rXr rXc rp = do + return BAD_CODE + +gemm :: (Storable t, Num t) => Vector t -> Matrix t -> Matrix t -> Matrix t -> IO () +gemm v m1 m2 m3 = (v Internal.Matrix.# m1 Internal.Matrix.# m2 Internal.Matrix.#! m3) gemmStorable #|"gemm" + +-- ScalarLike t +gemmStorable :: (Storable t, Num t) => + Int32 -> Ptr t -- VECG(T,c) + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,a) + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,b) + -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,r) + -> IO Int32 +gemmStorable cn cp + ar ac aXr aXc ap + br bc bXr bXc bp + rr rc rXr rXc rp = do + a <- peek cp + b <- peekElemOff cp 1 + ($ 0) $ fix $ \iloop i -> when (i when (j do + let ij = fromIntegral $ i*rXr + j*rXc + rij <- peekElemOff rp ij + pokeElemOff rp ij (b*rij + a*t) + jloop (succ j) + iloop (succ i) + return 0 diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs index fbea11a..423b169 100644 --- a/packages/base/src/Internal/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs @@ -20,7 +20,7 @@ import Data.Function(on) import Control.Arrow((***)) import Control.Monad(when) import Data.List(groupBy, sort) -import Foreign.C.Types(CInt(..)) +import Data.Int import Internal.Devel import System.IO.Unsafe(unsafePerformIO) @@ -34,16 +34,16 @@ type AssocMatrix = [((Int,Int),Double)] data CSR = CSR { csrVals :: Vector Double - , csrCols :: Vector CInt - , csrRows :: Vector CInt + , csrCols :: Vector Int32 + , csrRows :: Vector Int32 , csrNRows :: Int , csrNCols :: Int } deriving Show data CSC = CSC { cscVals :: Vector Double - , cscRows :: Vector CInt - , cscCols :: Vector CInt + , cscRows :: Vector Int32 + , cscCols :: Vector Int32 , cscNRows :: Int , cscNCols :: Int } deriving Show @@ -138,9 +138,9 @@ mkDiagR r c v diagVals = v -type IV t = CInt -> Ptr CInt -> t -type V t = CInt -> Ptr Double -> t -type SMxV = V (IV (IV (V (V (IO CInt))))) +type IV t = Int32 -> Ptr Int32 -> t +type V t = Int32 -> Ptr Double -> t +type SMxV = V (IV (IV (V (V (IO Int32))))) gmXv :: GMatrix -> Vector Double -> Vector Double gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index f642e8d..6f3b4c8 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -83,6 +83,7 @@ import Control.Arrow((&&&),(***)) import Data.Complex import Data.Function(on) import Internal.ST +import Foreign.Storable #if MIN_VERSION_base(4,11,0) import Prelude hiding ((<>)) #endif @@ -174,7 +175,7 @@ a & b = vjoin [a,b] -} infixl 3 ||| -(|||) :: Element t => Matrix t -> Matrix t -> Matrix t +(|||) :: Storable t => Matrix t -> Matrix t -> Matrix t a ||| b = fromBlocks [[a,b]] -- | a synonym for ('|||') (unicode 0x00a6, broken bar) @@ -185,7 +186,7 @@ infixl 3 ¦ -- | vertical concatenation -- -(===) :: Element t => Matrix t -> Matrix t -> Matrix t +(===) :: Storable t => Matrix t -> Matrix t -> Matrix t infixl 2 === a === b = fromBlocks [[a],[b]] @@ -225,7 +226,7 @@ col = asColumn . fromList -} infixl 9 ? -(?) :: Element t => Matrix t -> [Int] -> Matrix t +(?) :: Storable t => Matrix t -> [Int] -> Matrix t (?) = flip extractRows {- | extract columns @@ -240,7 +241,7 @@ infixl 9 ? -} infixl 9 ¿ -(¿) :: Element t => Matrix t -> [Int] -> Matrix t +(¿) :: Storable t => Matrix t -> [Int] -> Matrix t (¿)= flip extractColumns @@ -329,7 +330,7 @@ instance Normed (Vector (Complex Float)) norm_Inf = norm_Inf . double -- | Frobenius norm (Schatten p-norm with p=2) -norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> R +norm_Frob :: (Normed (Vector t), Storable t) => Matrix t -> R norm_Frob = norm_2 . flatten -- | Sum of singular values (Schatten p-norm with p=1) @@ -346,7 +347,7 @@ True True -} -magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool +magnit :: (Storable t, Normed (Vector t)) => R -> t -> Bool magnit e x = norm_1 (fromList [x]) > e @@ -415,7 +416,7 @@ instance Indexable (Vector (Complex Float)) (Complex Float) where (!) = (@>) -instance Element t => Indexable (Matrix t) (Vector t) +instance Storable t => Indexable (Matrix t) (Vector t) where m!j = subVector (j*c) c (flatten m) where diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index 6271bb6..3037019 100644 --- a/packages/base/src/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs @@ -32,7 +32,7 @@ import Foreign.ForeignPtr import Foreign.Ptr import Foreign.Storable import Foreign.C.Types(CInt) -import Data.Int(Int64) +import Data.Int import Data.Complex import System.IO.Unsafe(unsafePerformIO) import GHC.ForeignPtr(mallocPlainForeignPtrBytes) @@ -46,18 +46,18 @@ import Control.Monad(replicateM) import qualified Data.ByteString.Internal as BS import Data.Vector.Storable.Internal(updPtr) -type I = CInt +type I = Int32 type Z = Int64 type R = Double type C = Complex Double -- | specialized fromIntegral -fi :: Int -> CInt +fi :: Int -> Int32 fi = fromIntegral -- | specialized fromIntegral -ti :: CInt -> Int +ti :: Int32 -> Int ti = fromIntegral @@ -69,7 +69,7 @@ dim = Vector.length -- C-Haskell vector adapter {-# INLINE avec #-} -avec :: Storable a => Vector a -> (f -> IO r) -> ((CInt -> Ptr a -> f) -> IO r) +avec :: Storable a => Vector a -> (f -> IO r) -> ((Int32 -> Ptr a -> f) -> IO r) avec v f g = unsafeWith v $ \ptr -> f (g (fromIntegral (Vector.length v)) ptr) -- allocates memory for a new vector diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index 32430c6..ede3826 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs @@ -18,10 +18,12 @@ module Internal.Vectorized where import Internal.Vector import Internal.Devel import Data.Complex +import Data.Function +import Data.Int import Foreign.Marshal.Alloc(free,malloc) import Foreign.Marshal.Array(newArray,copyArray) import Foreign.Ptr(Ptr) -import Foreign.Storable(peek,Storable) +import Foreign.Storable(peek,pokeElemOff,Storable) import Foreign.C.Types import Foreign.C.String import System.IO.Unsafe(unsafePerformIO) @@ -36,8 +38,8 @@ a # b = applyRaw a b a #! b = a # b # id {-# INLINE (#!) #-} -fromei :: Enum a => a -> CInt -fromei x = fromIntegral (fromEnum x) :: CInt +fromei :: Enum a => a -> Int32 +fromei x = fromIntegral (fromEnum x) :: Int32 data FunCodeV = Sin | Cos @@ -103,20 +105,20 @@ sumQ = sumg c_sumQ sumC :: Vector (Complex Double) -> Complex Double sumC = sumg c_sumC -sumI :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr I -> I :> Ok) +sumI :: ( TransRaw c (Int32 -> Ptr a -> IO Int32) ~ (Int32 -> Ptr I -> I :> Ok) , TransArray c , Storable a ) => I -> c -> a sumI m = sumg (c_sumI m) -sumL :: ( TransRaw c (CInt -> Ptr a -> IO CInt) ~ (CInt -> Ptr Z -> Z :> Ok) +sumL :: ( TransRaw c (Int32 -> Ptr a -> IO Int32) ~ (Int32 -> Ptr Z -> Z :> Ok) , TransArray c , Storable a ) => Z -> c -> a sumL m = sumg (c_sumL m) -sumg :: (TransArray c, Storable a) => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a +sumg :: (TransArray c, Storable a) => TransRaw c (Int32 -> Ptr a -> IO Int32) -> c -> a sumg f x = unsafePerformIO $ do r <- createVector 1 (x #! r) f #| "sum" @@ -154,7 +156,7 @@ prodL :: Z-> Vector Z -> Z prodL = prodg . c_prodL prodg :: (TransArray c, Storable a) - => TransRaw c (CInt -> Ptr a -> IO CInt) -> c -> a + => TransRaw c (Int32 -> Ptr a -> IO Int32) -> c -> a prodg f x = unsafePerformIO $ do r <- createVector 1 (x #! r) f #| "prod" @@ -171,7 +173,7 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z ------------------------------------------------------------------ toScalarAux :: (Enum a, TransArray c, Storable a1) - => (CInt -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) -> a -> c -> a1 + => (Int32 -> TransRaw c (Int32 -> Ptr a1 -> IO Int32)) -> a -> c -> a1 toScalarAux fun code v = unsafePerformIO $ do r <- createVector 1 (v #! r) (fun (fromei code)) #|"toScalarAux" @@ -179,7 +181,7 @@ toScalarAux fun code v = unsafePerformIO $ do vectorMapAux :: (Enum a, Storable t, Storable a1) - => (CInt -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) + => (Int32 -> Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32) -> a -> Vector t -> Vector a1 vectorMapAux fun code v = unsafePerformIO $ do r <- createVector (dim v) @@ -187,7 +189,7 @@ vectorMapAux fun code v = unsafePerformIO $ do return r vectorMapValAux :: (Enum a, Storable a2, Storable t, Storable a1) - => (CInt -> Ptr a2 -> CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt) + => (Int32 -> Ptr a2 -> Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32) -> a -> a2 -> Vector t -> Vector a1 vectorMapValAux fun code val v = unsafePerformIO $ do r <- createVector (dim v) @@ -197,7 +199,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do return r vectorZipAux :: (Enum a, TransArray c, Storable t, Storable a1) - => (CInt -> CInt -> Ptr t -> TransRaw c (CInt -> Ptr a1 -> IO CInt)) + => (Int32 -> Int32 -> Ptr t -> TransRaw c (Int32 -> Ptr a1 -> IO Int32)) -> a -> Vector t -> c -> Vector a1 vectorZipAux fun code u v = unsafePerformIO $ do r <- createVector (dim u) @@ -210,37 +212,37 @@ vectorZipAux fun code u v = unsafePerformIO $ do toScalarR :: FunCodeS -> Vector Double -> Double toScalarR oper = toScalarAux c_toScalarR (fromei oper) -foreign import ccall unsafe "toScalarR" c_toScalarR :: CInt -> TVV Double +foreign import ccall unsafe "toScalarR" c_toScalarR :: Int32 -> TVV Double -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. toScalarF :: FunCodeS -> Vector Float -> Float toScalarF oper = toScalarAux c_toScalarF (fromei oper) -foreign import ccall unsafe "toScalarF" c_toScalarF :: CInt -> TVV Float +foreign import ccall unsafe "toScalarF" c_toScalarF :: Int32 -> TVV Float -- | obtains different functions of a vector: only norm1, norm2 toScalarC :: FunCodeS -> Vector (Complex Double) -> Double toScalarC oper = toScalarAux c_toScalarC (fromei oper) -foreign import ccall unsafe "toScalarC" c_toScalarC :: CInt -> Complex Double :> Double :> Ok +foreign import ccall unsafe "toScalarC" c_toScalarC :: Int32 -> Complex Double :> Double :> Ok -- | obtains different functions of a vector: only norm1, norm2 toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float toScalarQ oper = toScalarAux c_toScalarQ (fromei oper) -foreign import ccall unsafe "toScalarQ" c_toScalarQ :: CInt -> Complex Float :> Float :> Ok +foreign import ccall unsafe "toScalarQ" c_toScalarQ :: Int32 -> Complex Float :> Float :> Ok -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. -toScalarI :: FunCodeS -> Vector CInt -> CInt +toScalarI :: FunCodeS -> Vector Int32 -> Int32 toScalarI oper = toScalarAux c_toScalarI (fromei oper) -foreign import ccall unsafe "toScalarI" c_toScalarI :: CInt -> TVV CInt +foreign import ccall unsafe "toScalarI" c_toScalarI :: Int32 -> TVV Int32 -- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc. toScalarL :: FunCodeS -> Vector Z -> Z toScalarL oper = toScalarAux c_toScalarL (fromei oper) -foreign import ccall unsafe "toScalarL" c_toScalarL :: CInt -> TVV Z +foreign import ccall unsafe "toScalarL" c_toScalarL :: Int32 -> TVV Z ------------------------------------------------------------------ @@ -249,37 +251,37 @@ foreign import ccall unsafe "toScalarL" c_toScalarL :: CInt -> TVV Z vectorMapR :: FunCodeV -> Vector Double -> Vector Double vectorMapR = vectorMapAux c_vectorMapR -foreign import ccall unsafe "mapR" c_vectorMapR :: CInt -> TVV Double +foreign import ccall unsafe "mapR" c_vectorMapR :: Int32 -> TVV Double -- | map of complex vectors with given function vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double) vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper) -foreign import ccall unsafe "mapC" c_vectorMapC :: CInt -> TVV (Complex Double) +foreign import ccall unsafe "mapC" c_vectorMapC :: Int32 -> TVV (Complex Double) -- | map of real vectors with given function vectorMapF :: FunCodeV -> Vector Float -> Vector Float vectorMapF = vectorMapAux c_vectorMapF -foreign import ccall unsafe "mapF" c_vectorMapF :: CInt -> TVV Float +foreign import ccall unsafe "mapF" c_vectorMapF :: Int32 -> TVV Float -- | map of real vectors with given function vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float) vectorMapQ = vectorMapAux c_vectorMapQ -foreign import ccall unsafe "mapQ" c_vectorMapQ :: CInt -> TVV (Complex Float) +foreign import ccall unsafe "mapQ" c_vectorMapQ :: Int32 -> TVV (Complex Float) -- | map of real vectors with given function -vectorMapI :: FunCodeV -> Vector CInt -> Vector CInt +vectorMapI :: FunCodeV -> Vector Int32 -> Vector Int32 vectorMapI = vectorMapAux c_vectorMapI -foreign import ccall unsafe "mapI" c_vectorMapI :: CInt -> TVV CInt +foreign import ccall unsafe "mapI" c_vectorMapI :: Int32 -> TVV Int32 -- | map of real vectors with given function vectorMapL :: FunCodeV -> Vector Z -> Vector Z vectorMapL = vectorMapAux c_vectorMapL -foreign import ccall unsafe "mapL" c_vectorMapL :: CInt -> TVV Z +foreign import ccall unsafe "mapL" c_vectorMapL :: Int32 -> TVV Z ------------------------------------------------------------------- @@ -287,37 +289,37 @@ foreign import ccall unsafe "mapL" c_vectorMapL :: CInt -> TVV Z vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper) -foreign import ccall unsafe "mapValR" c_vectorMapValR :: CInt -> Ptr Double -> TVV Double +foreign import ccall unsafe "mapValR" c_vectorMapValR :: Int32 -> Ptr Double -> TVV Double -- | map of complex vectors with given function vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double) vectorMapValC = vectorMapValAux c_vectorMapValC -foreign import ccall unsafe "mapValC" c_vectorMapValC :: CInt -> Ptr (Complex Double) -> TVV (Complex Double) +foreign import ccall unsafe "mapValC" c_vectorMapValC :: Int32 -> Ptr (Complex Double) -> TVV (Complex Double) -- | map of real vectors with given function vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper) -foreign import ccall unsafe "mapValF" c_vectorMapValF :: CInt -> Ptr Float -> TVV Float +foreign import ccall unsafe "mapValF" c_vectorMapValF :: Int32 -> Ptr Float -> TVV Float -- | map of complex vectors with given function vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float) vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper) -foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: CInt -> Ptr (Complex Float) -> TVV (Complex Float) +foreign import ccall unsafe "mapValQ" c_vectorMapValQ :: Int32 -> Ptr (Complex Float) -> TVV (Complex Float) -- | map of real vectors with given function -vectorMapValI :: FunCodeSV -> CInt -> Vector CInt -> Vector CInt +vectorMapValI :: FunCodeSV -> Int32 -> Vector Int32 -> Vector Int32 vectorMapValI oper = vectorMapValAux c_vectorMapValI (fromei oper) -foreign import ccall unsafe "mapValI" c_vectorMapValI :: CInt -> Ptr CInt -> TVV CInt +foreign import ccall unsafe "mapValI" c_vectorMapValI :: Int32 -> Ptr Int32 -> TVV Int32 -- | map of real vectors with given function vectorMapValL :: FunCodeSV -> Z -> Vector Z -> Vector Z vectorMapValL oper = vectorMapValAux c_vectorMapValL (fromei oper) -foreign import ccall unsafe "mapValL" c_vectorMapValL :: CInt -> Ptr Z -> TVV Z +foreign import ccall unsafe "mapValL" c_vectorMapValL :: Int32 -> Ptr Z -> TVV Z ------------------------------------------------------------------- @@ -328,42 +330,42 @@ type TVVV t = t :> t :> t :> Ok vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double vectorZipR = vectorZipAux c_vectorZipR -foreign import ccall unsafe "zipR" c_vectorZipR :: CInt -> TVVV Double +foreign import ccall unsafe "zipR" c_vectorZipR :: Int32 -> TVVV Double -- | elementwise operation on complex vectors vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) vectorZipC = vectorZipAux c_vectorZipC -foreign import ccall unsafe "zipC" c_vectorZipC :: CInt -> TVVV (Complex Double) +foreign import ccall unsafe "zipC" c_vectorZipC :: Int32 -> TVVV (Complex Double) -- | elementwise operation on real vectors vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float vectorZipF = vectorZipAux c_vectorZipF -foreign import ccall unsafe "zipF" c_vectorZipF :: CInt -> TVVV Float +foreign import ccall unsafe "zipF" c_vectorZipF :: Int32 -> TVVV Float -- | elementwise operation on complex vectors vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float) vectorZipQ = vectorZipAux c_vectorZipQ -foreign import ccall unsafe "zipQ" c_vectorZipQ :: CInt -> TVVV (Complex Float) +foreign import ccall unsafe "zipQ" c_vectorZipQ :: Int32 -> TVVV (Complex Float) --- | elementwise operation on CInt vectors -vectorZipI :: FunCodeVV -> Vector CInt -> Vector CInt -> Vector CInt +-- | elementwise operation on Int32 vectors +vectorZipI :: FunCodeVV -> Vector Int32 -> Vector Int32 -> Vector Int32 vectorZipI = vectorZipAux c_vectorZipI -foreign import ccall unsafe "zipI" c_vectorZipI :: CInt -> TVVV CInt +foreign import ccall unsafe "zipI" c_vectorZipI :: Int32 -> TVVV Int32 --- | elementwise operation on CInt vectors +-- | elementwise operation on Int32 vectors vectorZipL :: FunCodeVV -> Vector Z -> Vector Z -> Vector Z vectorZipL = vectorZipAux c_vectorZipL -foreign import ccall unsafe "zipL" c_vectorZipL :: CInt -> TVVV Z +foreign import ccall unsafe "zipL" c_vectorZipL :: Int32 -> TVVV Z -------------------------------------------------------------------------------- foreign import ccall unsafe "vectorScan" c_vectorScan - :: CString -> Ptr CInt -> Ptr (Ptr Double) -> IO CInt + :: CString -> Ptr Int32 -> Ptr (Ptr Double) -> IO Int32 vectorScan :: FilePath -> IO (Vector Double) vectorScan s = do @@ -401,7 +403,7 @@ randomVector seed dist n = unsafePerformIO $ do (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector" return r -foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok +foreign import ccall unsafe "random_vector" c_random_vector :: Int32 -> Int32 -> Double :> Ok -------------------------------------------------------------------------------- @@ -426,7 +428,7 @@ range n = unsafePerformIO $ do (r # id) c_range_vector #|"range" return r -foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok +foreign import ccall unsafe "range_vector" c_range_vector :: Int32 :> Ok float2DoubleV :: Vector Float -> Vector Double @@ -435,10 +437,10 @@ float2DoubleV = tog c_float2double double2FloatV :: Vector Double -> Vector Float double2FloatV = tog c_double2float -double2IntV :: Vector Double -> Vector CInt +double2IntV :: Vector Double -> Vector Int32 double2IntV = tog c_double2int -int2DoubleV :: Vector CInt -> Vector Double +int2DoubleV :: Vector Int32 -> Vector Double int2DoubleV = tog c_int2double double2longV :: Vector Double -> Vector Z @@ -448,10 +450,10 @@ long2DoubleV :: Vector Z -> Vector Double long2DoubleV = tog c_long2double -float2IntV :: Vector Float -> Vector CInt +float2IntV :: Vector Float -> Vector Int32 float2IntV = tog c_float2int -int2floatV :: Vector CInt -> Vector Float +int2floatV :: Vector Int32 -> Vector Float int2floatV = tog c_int2float int2longV :: Vector I -> Vector Z @@ -462,7 +464,7 @@ long2intV = tog c_long2int tog :: (Storable t, Storable a) - => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a + => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a tog f v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) f #|"tog" @@ -470,12 +472,12 @@ tog f v = unsafePerformIO $ do foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok foreign import ccall unsafe "double2float" c_double2float :: Double :> Float :> Ok -foreign import ccall unsafe "int2double" c_int2double :: CInt :> Double :> Ok -foreign import ccall unsafe "double2int" c_double2int :: Double :> CInt :> Ok +foreign import ccall unsafe "int2double" c_int2double :: Int32 :> Double :> Ok +foreign import ccall unsafe "double2int" c_double2int :: Double :> Int32 :> Ok foreign import ccall unsafe "long2double" c_long2double :: Z :> Double :> Ok foreign import ccall unsafe "double2long" c_double2long :: Double :> Z :> Ok -foreign import ccall unsafe "int2float" c_int2float :: CInt :> Float :> Ok -foreign import ccall unsafe "float2int" c_float2int :: Float :> CInt :> Ok +foreign import ccall unsafe "int2float" c_int2float :: Int32 :> Float :> Ok +foreign import ccall unsafe "float2int" c_float2int :: Float :> Int32 :> Ok foreign import ccall unsafe "int2long" c_int2long :: I :> Z :> Ok foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok @@ -483,7 +485,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok --------------------------------------------------------------- stepg :: (Storable t, Storable a) - => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a + => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a stepg f v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) f #|"step" @@ -495,7 +497,7 @@ stepD = stepg c_stepD stepF :: Vector Float -> Vector Float stepF = stepg c_stepF -stepI :: Vector CInt -> Vector CInt +stepI :: Vector Int32 -> Vector Int32 stepI = stepg c_stepI stepL :: Vector Z -> Vector Z @@ -504,13 +506,13 @@ stepL = stepg c_stepL foreign import ccall unsafe "stepF" c_stepF :: TVV Float foreign import ccall unsafe "stepD" c_stepD :: TVV Double -foreign import ccall unsafe "stepI" c_stepI :: TVV CInt +foreign import ccall unsafe "stepI" c_stepI :: TVV Int32 foreign import ccall unsafe "stepL" c_stepL :: TVV Z -------------------------------------------------------------------------------- conjugateAux :: (Storable t, Storable a) - => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a + => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a conjugateAux fun x = unsafePerformIO $ do v <- createVector (dim x) (x #! v) fun #|"conjugateAux" @@ -536,22 +538,29 @@ cloneVector v = do -------------------------------------------------------------------------------- -constantAux :: (Storable a1, Storable a) - => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a -constantAux fun x n = unsafePerformIO $ do +constantAux :: Storable a => a -> Int -> Vector a +constantAux x n = unsafePerformIO $ do v <- createVector n px <- newArray [x] - (v # id) (fun px) #|"constantAux" + (v # id) (constantStorable px) #|"constantAux" free px return v +constantStorable :: Storable a => Ptr a -> Int32 -> Ptr a -> IO Int32 +constantStorable pval n p = do + val <- peek pval + ($ 0) $ fix $ \iloop i -> when (i t :> Ok foreign import ccall unsafe "constantF" cconstantF :: TConst Float foreign import ccall unsafe "constantR" cconstantR :: TConst Double foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) -foreign import ccall unsafe "constantI" cconstantI :: TConst CInt +foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 foreign import ccall unsafe "constantL" cconstantL :: TConst Z ---------------------------------------------------------------------- diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 9670187..a0a23bd 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -167,7 +167,7 @@ module Numeric.LinearAlgebra ( haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, iC, sym, mTm, trustSym, unSym, -- * Auxiliary classes - Element, Container, Product, Numeric, LSDiv, Herm, + Container, Product, Numeric, LSDiv, Herm, Complexable, RealElement, RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index 97cfd01..12eddb2 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs @@ -33,6 +33,7 @@ import System.Random import Numeric.LinearAlgebra.HMatrix hiding (vector) import Control.Monad(replicateM) import Test.QuickCheck(Arbitrary,arbitrary,choose,vector,sized,shrink) +import Foreign.Storable import GHC.TypeLits import Data.Proxy (Proxy(..)) @@ -69,7 +70,7 @@ instance KnownNat n => Arbitrary (Static.R n) where shrink _v = [] -instance (Element a, Arbitrary a) => Arbitrary (Matrix a) where +instance (Storable a, Arbitrary a) => Arbitrary (Matrix a) where arbitrary = do m <- chooseDim n <- chooseDim @@ -98,7 +99,7 @@ instance (KnownNat n, KnownNat m) => Arbitrary (Static.L m n) where -- a square matrix newtype (Sq a) = Sq (Matrix a) deriving Show -instance (Element a, Arbitrary a) => Arbitrary (Sq a) where +instance (Storable a, Arbitrary a) => Arbitrary (Sq a) where arbitrary = do n <- chooseDim l <- vector (n*n) @@ -141,7 +142,7 @@ instance (Field a, Arbitrary a, Num (Vector a)) => Arbitrary (Herm a) where return $ sym m' -class (Field a, Arbitrary a, Element (RealOf a), Random (RealOf a)) => ArbitraryField a +class (Field a, Arbitrary a, Storable (RealOf a), Random (RealOf a)) => ArbitraryField a instance ArbitraryField Double instance ArbitraryField (Complex Double) -- cgit v1.2.3