From 85af0a1d5ba2d1c03f05458f9689195e82f6ae7e Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Thu, 22 May 2014 20:09:41 +0200 Subject: cgSolve --- packages/base/src/Data/Packed/Internal/Sparse.hs | 191 -------------------- packages/base/src/Numeric/LinearAlgebra/Data.hs | 2 +- packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 96 +++++++++++ packages/base/src/Numeric/Sparse.hs | 192 +++++++++++++++++++++ 4 files changed, 289 insertions(+), 192 deletions(-) delete mode 100644 packages/base/src/Data/Packed/Internal/Sparse.hs create mode 100644 packages/base/src/Numeric/LinearAlgebra/Util/CG.hs create mode 100644 packages/base/src/Numeric/Sparse.hs (limited to 'packages/base/src') diff --git a/packages/base/src/Data/Packed/Internal/Sparse.hs b/packages/base/src/Data/Packed/Internal/Sparse.hs deleted file mode 100644 index 544c913..0000000 --- a/packages/base/src/Data/Packed/Internal/Sparse.hs +++ /dev/null @@ -1,191 +0,0 @@ -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE FlexibleInstances #-} - - - -module Data.Packed.Internal.Sparse( - SMatrix(..), - mkCSR, mkDiag, - AssocMatrix, - toDense, - smXv -)where - -import Numeric.Container -import qualified Data.Vector.Storable as V -import Data.Function(on) -import Control.Arrow((***)) -import Control.Monad(when) -import Data.List(groupBy, sort) -import Foreign.C.Types(CInt(..)) -import Numeric.LinearAlgebra.Devel -import System.IO.Unsafe(unsafePerformIO) -import Foreign(Ptr) -import Text.Printf(printf) - -infixl 0 ~!~ -c ~!~ msg = when c (error msg) - -type AssocMatrix = [((Int,Int),Double)] - -data SMatrix - = CSR - { csrVals :: Vector Double - , csrCols :: Vector CInt - , csrRows :: Vector CInt - , nRows :: Int - , nCols :: Int - } - | CSC - { cscVals :: Vector Double - , cscRows :: Vector CInt - , cscCols :: Vector CInt - , nRows :: Int - , nCols :: Int - } - | Diag - { diagVals :: Vector Double - , nRows :: Int - , nCols :: Int - } --- | Banded - deriving Show - -mkCSR :: AssocMatrix -> SMatrix -mkCSR sm' = CSR{..} - where - sm = sort sm' - rws = map ((fromList *** fromList) - . unzip - . map ((succ.fi.snd) *** id) - ) - . groupBy ((==) `on` (fst.fst)) - $ sm - rszs = map (fi . dim . fst) rws - csrRows = fromList (scanl (+) 1 rszs) - csrVals = vjoin (map snd rws) - csrCols = vjoin (map fst rws) - nRows = dim csrRows - 1 - nCols = fromIntegral (V.maximum csrCols) - - -mkDiagR r c v - | dim v <= min r c = Diag{..} - | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v) - where - nRows = r - nCols = c - diagVals = v - -mkDiag v = mkDiagR (dim v) (dim v) v - - -type IV t = CInt -> Ptr CInt -> t -type V t = CInt -> Ptr Double -> t -type SMxV = V (IV (IV (V (V (IO CInt))))) - -smXv :: SMatrix -> Vector Double -> Vector Double -smXv CSR{..} v = unsafePerformIO $ do - dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) - r <- createVector nRows - app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" - return r - -smXv CSC{..} v = unsafePerformIO $ do - dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) - r <- createVector nRows - app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" - return r - -smXv Diag{..} v - | dim v == nCols - = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals - , konst 0 (nRows - dim diagVals) ] - | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" - nRows nCols (dim diagVals) (dim v) - - -instance Contraction SMatrix (Vector Double) (Vector Double) - where - contraction = smXv - --------------------------------------------------------------------------------- - -foreign import ccall unsafe "smXv" - c_smXv :: SMxV - -foreign import ccall unsafe "smTXv" - c_smTXv :: SMxV - --------------------------------------------------------------------------------- - -toDense :: AssocMatrix -> Matrix Double -toDense asm = assoc (r+1,c+1) 0 asm - where - (r,c) = (maximum *** maximum) . unzip . map fst $ asm - - - -instance Transposable (SMatrix) - where - tr (CSR vs cs rs n m) = CSC vs cs rs m n - tr (CSC vs rs cs n m) = CSR vs rs cs m n - tr (Diag v n m) = Diag v m n - -instance Transposable (Matrix Double) - where - tr = trans - - - --------------------------------------------------------------------------------- - -instance Testable SMatrix - where - checkT _ = (ok,info) - where - sma = convo2 20 3 - x1 = vect [1..20] - x2 = vect [1..40] - sm = mkCSR sma - - s1 = sm ◇ x1 - d1 = toDense sma ◇ x1 - - s2 = tr sm ◇ x2 - d2 = tr (toDense sma) ◇ x2 - - sdia = mkDiagR 40 20 (vect [1..10]) - s3 = sdia ◇ x1 - s4 = tr sdia ◇ x2 - ddia = diagRect 0 (vect [1..10]) 40 20 - d3 = ddia ◇ x1 - d4 = tr ddia ◇ x2 - - info = do - print sm - disp (toDense sma) - print s1; print d1 - print s2; print d2 - print s3; print d3 - print s4; print d4 - - ok = s1==d1 - && s2==d2 - && s3==d3 - && s4==d4 - - disp = putStr . dispf 2 - - vect = fromList :: [Double] -> Vector Double - - convomat :: Int -> Int -> AssocMatrix - convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] - - convo2 :: Int -> Int -> AssocMatrix - convo2 n k = m1 ++ m2 - where - m1 = convomat n k - m2 = map (((+n) *** id) *** id) m1 - diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 49bc1c0..e3cbe31 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs @@ -69,5 +69,5 @@ import Data.Packed.Matrix import Numeric.Container import Numeric.LinearAlgebra.Util import Data.Complex -import Data.Packed.Internal.Sparse +import Numeric.Sparse diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs new file mode 100644 index 0000000..2c782e8 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE FlexibleContexts, FlexibleInstances #-} +{-# LANGUAGE RecordWildCards #-} + +module Numeric.LinearAlgebra.Util.CG( + cgSolve, + CGMat +) where + +import Numeric.Container +import Numeric.Vector() + +{- +import Util.Misc(debug, debugMat) + +(//) :: Show a => a -> String -> a +infix 0 // -- , /// +a // b = debug b id a + +(///) :: DV -> String -> DV +infix 0 /// +v /// b = debugMat b 2 asRow v +-} + + +type DV = Vector Double + +data CGState = CGState + { cgp :: DV + , cgr :: DV + , cgr2 :: Double + , cgx :: DV + , cgdx :: Double + } + +cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState +cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx + where + ap1 = a p + ap | sym = ap1 + | otherwise = at ap1 + pap | sym = p ◇ ap1 + | otherwise = norm2 ap1 ** 2 + alpha = r2 / pap + dx = scale alpha p + x' = x + dx + r' = r - scale alpha ap + r'2 = r' ◇ r' + beta = r'2 / r2 + p' = r' + scale beta p + + rdx = norm2 dx / max 1 (norm2 x) + +conjugrad + :: (Transposable m, Contraction m DV DV) + => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] +conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b + +solveG + :: (DV -> DV) -> (DV -> DV) + -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) + -> DV + -> DV + -> Double -> Double + -> [CGState] +solveG mat ma meth rawb x0' ϵb ϵx + = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 + where + a = mat . ma + b = mat rawb + x0 = if x0' == 0 then konst 0 (dim b) else x0' + r0 = b - a x0 + r20 = r0 ◇ r0 + p0 = r0 + nb2 = b ◇ b + ok CGState {..} + = cgr2 Bool) -> [a] -> [a] +takeUntil q xs = a++ take 1 b + where + (a,b) = break q xs + +class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m + +cgSolve + :: CGMat m + => Bool -- ^ symmetric + -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) + -> Double -- ^ relative tolerance for δx (e.g. 1E-3) + -> m -- ^ coefficient matrix + -> Vector Double -- ^ right-hand side + -> Vector Double -- ^ solution +cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es + diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs new file mode 100644 index 0000000..3835590 --- /dev/null +++ b/packages/base/src/Numeric/Sparse.hs @@ -0,0 +1,192 @@ +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FlexibleInstances #-} + +module Numeric.Sparse( + SMatrix(..), + mkCSR, mkDiag, + AssocMatrix, + toDense, + smXv +)where + +import Numeric.Container +import qualified Data.Vector.Storable as V +import Data.Function(on) +import Control.Arrow((***)) +import Control.Monad(when) +import Data.List(groupBy, sort) +import Foreign.C.Types(CInt(..)) +import Numeric.LinearAlgebra.Util.CG(CGMat) +import Data.Packed.Development +import System.IO.Unsafe(unsafePerformIO) +import Foreign(Ptr) +import Text.Printf(printf) + +infixl 0 ~!~ +c ~!~ msg = when c (error msg) + +type AssocMatrix = [((Int,Int),Double)] + +data SMatrix + = CSR + { csrVals :: Vector Double + , csrCols :: Vector CInt + , csrRows :: Vector CInt + , nRows :: Int + , nCols :: Int + } + | CSC + { cscVals :: Vector Double + , cscRows :: Vector CInt + , cscCols :: Vector CInt + , nRows :: Int + , nCols :: Int + } + | Diag + { diagVals :: Vector Double + , nRows :: Int + , nCols :: Int + } +-- | Banded + deriving Show + +mkCSR :: AssocMatrix -> SMatrix +mkCSR sm' = CSR{..} + where + sm = sort sm' + rws = map ((fromList *** fromList) + . unzip + . map ((succ.fi.snd) *** id) + ) + . groupBy ((==) `on` (fst.fst)) + $ sm + rszs = map (fi . dim . fst) rws + csrRows = fromList (scanl (+) 1 rszs) + csrVals = vjoin (map snd rws) + csrCols = vjoin (map fst rws) + nRows = dim csrRows - 1 + nCols = fromIntegral (V.maximum csrCols) + + +mkDiagR r c v + | dim v <= min r c = Diag{..} + | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v) + where + nRows = r + nCols = c + diagVals = v + +mkDiag v = mkDiagR (dim v) (dim v) v + + +type IV t = CInt -> Ptr CInt -> t +type V t = CInt -> Ptr Double -> t +type SMxV = V (IV (IV (V (V (IO CInt))))) + +smXv :: SMatrix -> Vector Double -> Vector Double +smXv CSR{..} v = unsafePerformIO $ do + dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) + r <- createVector nRows + app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" + return r + +smXv CSC{..} v = unsafePerformIO $ do + dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) + r <- createVector nRows + app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" + return r + +smXv Diag{..} v + | dim v == nCols + = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals + , konst 0 (nRows - dim diagVals) ] + | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" + nRows nCols (dim diagVals) (dim v) + + +instance Contraction SMatrix (Vector Double) (Vector Double) + where + contraction = smXv + +-------------------------------------------------------------------------------- + +foreign import ccall unsafe "smXv" + c_smXv :: SMxV + +foreign import ccall unsafe "smTXv" + c_smTXv :: SMxV + +-------------------------------------------------------------------------------- + +toDense :: AssocMatrix -> Matrix Double +toDense asm = assoc (r+1,c+1) 0 asm + where + (r,c) = (maximum *** maximum) . unzip . map fst $ asm + + + +instance Transposable SMatrix + where + tr (CSR vs cs rs n m) = CSC vs cs rs m n + tr (CSC vs rs cs n m) = CSR vs rs cs m n + tr (Diag v n m) = Diag v m n + +instance Transposable (Matrix Double) + where + tr = trans + + +instance CGMat SMatrix +instance CGMat (Matrix Double) + +-------------------------------------------------------------------------------- + +instance Testable SMatrix + where + checkT _ = (ok,info) + where + sma = convo2 20 3 + x1 = vect [1..20] + x2 = vect [1..40] + sm = mkCSR sma + + s1 = sm ◇ x1 + d1 = toDense sma ◇ x1 + + s2 = tr sm ◇ x2 + d2 = tr (toDense sma) ◇ x2 + + sdia = mkDiagR 40 20 (vect [1..10]) + s3 = sdia ◇ x1 + s4 = tr sdia ◇ x2 + ddia = diagRect 0 (vect [1..10]) 40 20 + d3 = ddia ◇ x1 + d4 = tr ddia ◇ x2 + + info = do + print sm + disp (toDense sma) + print s1; print d1 + print s2; print d2 + print s3; print d3 + print s4; print d4 + + ok = s1==d1 + && s2==d2 + && s3==d3 + && s4==d4 + + disp = putStr . dispf 2 + + vect = fromList :: [Double] -> Vector Double + + convomat :: Int -> Int -> AssocMatrix + convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] + + convo2 :: Int -> Int -> AssocMatrix + convo2 n k = m1 ++ m2 + where + m1 = convomat n k + m2 = map (((+n) *** id) *** id) m1 + -- cgit v1.2.3