From cf3c788f0c44577ac1a5365e8154200b53a36409 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 27 May 2014 10:41:40 +0200 Subject: static dimensions, cont. --- packages/base/src/Numeric/Sparse.hs | 127 ++++++++++++------------------------ 1 file changed, 42 insertions(+), 85 deletions(-) (limited to 'packages/base/src/Numeric/Sparse.hs') diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 2df4578..4d05bdc 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs @@ -3,11 +3,11 @@ {-# LANGUAGE FlexibleInstances #-} module Numeric.Sparse( - SMatrix(..), - mkCSR, mkDiag, + GMatrix(..), + mkSparse, mkDiagR, dense, AssocMatrix, toDense, - smXv + gmXv, (!#>) )where import Data.Packed.Numeric @@ -17,8 +17,7 @@ import Control.Arrow((***)) import Control.Monad(when) import Data.List(groupBy, sort) import Foreign.C.Types(CInt(..)) -import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) -import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) + import Data.Packed.Development import System.IO.Unsafe(unsafePerformIO) import Foreign(Ptr) @@ -29,7 +28,7 @@ c ~!~ msg = when c (error msg) type AssocMatrix = [((Int,Int),Double)] -data SMatrix +data GMatrix = CSR { csrVals :: Vector Double , csrCols :: Vector CInt @@ -46,14 +45,26 @@ data SMatrix } | Diag { diagVals :: Vector Double + , nRows :: Int + , nCols :: Int + } + | Dense + { gmDense :: Matrix Double , nRows :: Int , nCols :: Int } -- | Banded deriving Show -mkCSR :: AssocMatrix -> SMatrix -mkCSR sm' = CSR{..} +dense :: Matrix Double -> GMatrix +dense m = Dense{..} + where + gmDense = m + nRows = rows m + nCols = cols m + +mkSparse :: AssocMatrix -> GMatrix +mkSparse sm' = CSR{..} where sm = sort sm' rws = map ((fromList *** fromList) @@ -78,37 +89,47 @@ mkDiagR r c v nCols = c diagVals = v -mkDiag v = mkDiagR (dim v) (dim v) v - type IV t = CInt -> Ptr CInt -> t type V t = CInt -> Ptr Double -> t type SMxV = V (IV (IV (V (V (IO CInt))))) -smXv :: SMatrix -> Vector Double -> Vector Double -smXv CSR{..} v = unsafePerformIO $ do - dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) +gmXv :: GMatrix -> Vector Double -> Vector Double +gmXv CSR{..} v = unsafePerformIO $ do + dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" return r -smXv CSC{..} v = unsafePerformIO $ do - dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) +gmXv CSC{..} v = unsafePerformIO $ do + dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" return r -smXv Diag{..} v +gmXv Diag{..} v | dim v == nCols = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals , konst 0 (nRows - dim diagVals) ] - | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" + | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" nRows nCols (dim diagVals) (dim v) +gmXv Dense{..} v + | dim v == nCols + = mXv gmDense v + | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d" + nRows nCols (dim v) + -instance Contraction SMatrix (Vector Double) (Vector Double) +-- | general matrix - vector product +infixr 8 !#> +(!#>) :: GMatrix -> Vector Double -> Vector Double +(!#>) = gmXv + + +instance Contraction GMatrix (Vector Double) (Vector Double) where - contraction = smXv + contraction = gmXv -------------------------------------------------------------------------------- @@ -127,75 +148,11 @@ toDense asm = assoc (r+1,c+1) 0 asm -instance Transposable SMatrix +instance Transposable GMatrix GMatrix where tr (CSR vs cs rs n m) = CSC vs cs rs m n tr (CSC vs rs cs n m) = CSR vs rs cs m n tr (Diag v n m) = Diag v m n + tr (Dense a n m) = Dense (tr a) m n -instance CGMat SMatrix -instance CGMat (Matrix Double) - --------------------------------------------------------------------------------- - -instance Testable SMatrix - where - checkT _ = (ok,info) - where - sma = convo2 20 3 - x1 = vect [1..20] - x2 = vect [1..40] - sm = mkCSR sma - dm = toDense sma - - s1 = sm ◇ x1 - d1 = dm ◇ x1 - - s2 = tr sm ◇ x2 - d2 = tr dm ◇ x2 - - sdia = mkDiagR 40 20 (vect [1..10]) - s3 = sdia ◇ x1 - s4 = tr sdia ◇ x2 - ddia = diagRect 0 (vect [1..10]) 40 20 - d3 = ddia ◇ x1 - d4 = tr ddia ◇ x2 - - v = testb 40 - s5 = cgSolve False sm v - d5 = denseSolve dm v - - info = do - print sm - disp (toDense sma) - print s1; print d1 - print s2; print d2 - print s3; print d3 - print s4; print d4 - print s5; print d5 - print $ relativeError Infinity s5 d5 - - ok = s1==d1 - && s2==d2 - && s3==d3 - && s4==d4 - && relativeError Infinity s5 d5 < 1E-10 - - disp = putStr . dispf 2 - - vect = fromList :: [Double] -> Vector Double - - convomat :: Int -> Int -> AssocMatrix - convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] - - convo2 :: Int -> Int -> AssocMatrix - convo2 n k = m1 ++ m2 - where - m1 = convomat n k - m2 = map (((+n) *** id) *** id) m1 - - testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) - - denseSolve a = flatten . linearSolveLS a . asColumn - -- cgit v1.2.3