From 53559833d2166010eed754027484fb8d5525e710 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 27 May 2014 20:21:47 +0200 Subject: expose CSR --- packages/base/src/Numeric/LinearAlgebra/Data.hs | 6 +- packages/base/src/Numeric/LinearAlgebra/Real.hs | 10 +++ packages/base/src/Numeric/LinearAlgebra/Util.hs | 7 ++ packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 2 +- packages/base/src/Numeric/Sparse.hs | 92 ++++++++++++++-------- 5 files changed, 82 insertions(+), 35 deletions(-) (limited to 'packages/base/src/Numeric') diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 3417a5e..33a2c9a 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs @@ -49,8 +49,8 @@ module Numeric.LinearAlgebra.Data( find, maxIndex, minIndex, maxElement, minElement, atIndex, -- * Sparse - GMatrix, AssocMatrix, mkSparse, toDense, - mkDiagR, dense, + AssocMatrix, toDense, + mkSparse, mkDiagR, mkDense, -- * IO disp, @@ -68,7 +68,7 @@ module Numeric.LinearAlgebra.Data( module Data.Complex, - Vector, Matrix + Vector, Matrix, GMatrix, CSR(..), mkCSR ) where diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index db15705..1e8b544 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs @@ -141,6 +141,16 @@ instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) negate = (lift1F . lift1F) negate fromInteger x = Dim (Dim (fromInteger x)) +instance Fractional (Dim n (Vector Double)) + where + fromRational x = Dim (fromRational x) + (/) = lift2F (/) + +instance Fractional (Dim m (Dim n (Matrix Double))) + where + fromRational x = Dim (Dim (fromRational x)) + (/) = (lift2F.lift2F) (/) + -------------------------------------------------------------------------------- class Konst t diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs index 47b1090..aee21b8 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs @@ -37,6 +37,7 @@ module Numeric.LinearAlgebra.Util( mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, unitary, mt, + (~!~), pairwiseD2, rowOuters, null1, @@ -65,6 +66,7 @@ import Numeric.Matrix() import Numeric.Vector() import Numeric.LinearAlgebra.Random import Numeric.LinearAlgebra.Util.Convolution +import Control.Monad(when) type ℝ = Double type ℕ = Int @@ -385,3 +387,8 @@ vtrans p m | r == 0 = fromBlocks . map (map asColumn . takesV (replicate q p)) . where (q,r) = divMod (rows m) p +-------------------------------------------------------------------------------- + +infixl 0 ~!~ +c ~!~ msg = when c (error msg) + diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 50372f1..f821b57 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs @@ -115,7 +115,7 @@ instance Testable GMatrix sma = convo2 20 3 x1 = vect [1..20] x2 = vect [1..40] - sm = mkSparse sma + sm = (mkSparse . mkCSR) sma dm = toDense sma s1 = sm !#> x1 diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 4d05bdc..3c19c93 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs @@ -3,8 +3,8 @@ {-# LANGUAGE FlexibleInstances #-} module Numeric.Sparse( - GMatrix(..), - mkSparse, mkDiagR, dense, + GMatrix, CSR(..), mkCSR, + mkSparse, mkDiagR, mkDense, AssocMatrix, toDense, gmXv, (!#>) @@ -28,18 +28,49 @@ c ~!~ msg = when c (error msg) type AssocMatrix = [((Int,Int),Double)] +data CSR = CSR + { csrVals :: Vector Double + , csrCols :: Vector CInt + , csrRows :: Vector CInt + , csrNRows :: Int + , csrNCols :: Int + } deriving Show + +data CSC = CSC + { cscVals :: Vector Double + , cscRows :: Vector CInt + , cscCols :: Vector CInt + , cscNRows :: Int + , cscNCols :: Int + } deriving Show + + +mkCSR :: AssocMatrix -> CSR +mkCSR sm' = CSR{..} + where + sm = sort sm' + rws = map ((fromList *** fromList) + . unzip + . map ((succ.fi.snd) *** id) + ) + . groupBy ((==) `on` (fst.fst)) + $ sm + rszs = map (fi . dim . fst) rws + csrRows = fromList (scanl (+) 1 rszs) + csrVals = vjoin (map snd rws) + csrCols = vjoin (map fst rws) + csrNRows = dim csrRows - 1 + csrNCols = fromIntegral (V.maximum csrCols) + + data GMatrix - = CSR - { csrVals :: Vector Double - , csrCols :: Vector CInt - , csrRows :: Vector CInt + = SparseR + { gmCSR :: CSR , nRows :: Int , nCols :: Int } - | CSC - { cscVals :: Vector Double - , cscRows :: Vector CInt - , cscCols :: Vector CInt + | SparseC + { gmCSC :: CSC , nRows :: Int , nCols :: Int } @@ -56,29 +87,21 @@ data GMatrix -- | Banded deriving Show -dense :: Matrix Double -> GMatrix -dense m = Dense{..} + +mkDense :: Matrix Double -> GMatrix +mkDense m = Dense{..} where gmDense = m nRows = rows m nCols = cols m -mkSparse :: AssocMatrix -> GMatrix -mkSparse sm' = CSR{..} + +mkSparse :: CSR -> GMatrix +mkSparse csr = SparseR {..} where - sm = sort sm' - rws = map ((fromList *** fromList) - . unzip - . map ((succ.fi.snd) *** id) - ) - . groupBy ((==) `on` (fst.fst)) - $ sm - rszs = map (fi . dim . fst) rws - csrRows = fromList (scanl (+) 1 rszs) - csrVals = vjoin (map snd rws) - csrCols = vjoin (map fst rws) - nRows = dim csrRows - 1 - nCols = fromIntegral (V.maximum csrCols) + gmCSR @ CSR {..} = csr + nRows = csrNRows + nCols = csrNCols mkDiagR r c v @@ -95,13 +118,13 @@ type V t = CInt -> Ptr Double -> t type SMxV = V (IV (IV (V (V (IO CInt))))) gmXv :: GMatrix -> Vector Double -> Vector Double -gmXv CSR{..} v = unsafePerformIO $ do +gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" return r -gmXv CSC{..} v = unsafePerformIO $ do +gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" @@ -147,11 +170,18 @@ toDense asm = assoc (r+1,c+1) 0 asm (r,c) = (maximum *** maximum) . unzip . map fst $ asm - -instance Transposable GMatrix GMatrix +instance Transposable CSR CSC where tr (CSR vs cs rs n m) = CSC vs cs rs m n + +instance Transposable CSC CSR + where tr (CSC vs rs cs n m) = CSR vs rs cs m n + +instance Transposable GMatrix GMatrix + where + tr (SparseR s n m) = SparseC (tr s) m n + tr (SparseC s n m) = SparseR (tr s) m n tr (Diag v n m) = Diag v m n tr (Dense a n m) = Dense (tr a) m n -- cgit v1.2.3