From 379f6a9855a36979c0670a3f89b6c7202836369c Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 5 Jun 2015 16:36:08 +0200 Subject: move cg --- packages/base/src/Internal/CG.hs | 177 +++++++++++++++++++++ packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 171 -------------------- 2 files changed, 177 insertions(+), 171 deletions(-) create mode 100644 packages/base/src/Internal/CG.hs delete mode 100644 packages/base/src/Numeric/LinearAlgebra/Util/CG.hs diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs new file mode 100644 index 0000000..1193b18 --- /dev/null +++ b/packages/base/src/Internal/CG.hs @@ -0,0 +1,177 @@ +{-# LANGUAGE FlexibleContexts, FlexibleInstances #-} +{-# LANGUAGE RecordWildCards #-} + +module Internal.CG( + cgSolve, cgSolve', + CGState(..), R, V +) where + +import Internal.Vector +import Internal.Matrix hiding (mat) +import Internal.Numeric +import Internal.Element +import Internal.IO +import Internal.Container +import Internal.Sparse +import Numeric.Vector() +import Internal.Algorithms(linearSolveLS, relativeError, pnorm, NormType(..)) +import Control.Arrow((***)) +import Data.Vector.Storable(fromList) + +{- +import Util.Misc(debug, debugMat) + +(//) :: Show a => a -> String -> a +infix 0 // -- , /// +a // b = debug b id a + +(///) :: V -> String -> V +infix 0 /// +v /// b = debugMat b 2 asRow v +-} + +type R = Double +type V = Vector R + +data CGState = CGState + { cgp :: V -- ^ conjugate gradient + , cgr :: V -- ^ residual + , cgr2 :: R -- ^ squared norm of residual + , cgx :: V -- ^ current solution + , cgdx :: R -- ^ normalized size of correction + } + +cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState +cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx + where + ap1 = a p + ap | sym = ap1 + | otherwise = at ap1 + pap | sym = p <·> ap1 + | otherwise = norm2 ap1 ** 2 + alpha = r2 / pap + dx = scale alpha p + x' = x + dx + r' = r - scale alpha ap + r'2 = r' <·> r' + beta = r'2 / r2 + p' = r' + scale beta p + + rdx = norm2 dx / max 1 (norm2 x) + +conjugrad + :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] +conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b + +solveG + :: (V -> V) -> (V -> V) + -> ((V -> V) -> (V -> V) -> CGState -> CGState) + -> V + -> V + -> R -> R + -> [CGState] +solveG mat ma meth rawb x0' ϵb ϵx + = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 + where + a = mat . ma + b = mat rawb + x0 = if x0' == 0 then konst 0 (dim b) else x0' + r0 = b - a x0 + r20 = r0 <·> r0 + p0 = r0 + nb2 = b <·> b + ok CGState {..} + = cgr2 Bool) -> [a] -> [a] +takeUntil q xs = a++ take 1 b + where + (a,b) = break q xs + +cgSolve + :: Bool -- ^ is symmetric + -> GMatrix -- ^ coefficient matrix + -> Vector Double -- ^ right-hand side + -> Vector Double -- ^ solution +cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 + where + n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) + +cgSolve' + :: Bool -- ^ symmetric + -> R -- ^ relative tolerance for the residual (e.g. 1E-4) + -> R -- ^ relative tolerance for δx (e.g. 1E-3) + -> Int -- ^ maximum number of iterations + -> GMatrix -- ^ coefficient matrix + -> V -- ^ initial solution + -> V -- ^ right-hand side + -> [CGState] -- ^ solution +cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es + + +-------------------------------------------------------------------------------- + +instance Testable GMatrix + where + checkT _ = (ok,info) + where + sma = convo2 20 3 + x1 = vect [1..20] + x2 = vect [1..40] + sm = mkSparse sma + dm = toDense sma + + s1 = sm !#> x1 + d1 = dm #> x1 + + s2 = tr sm !#> x2 + d2 = tr dm #> x2 + + sdia = mkDiagR 40 20 (vect [1..10]) + s3 = sdia !#> x1 + s4 = tr sdia !#> x2 + ddia = diagRect 0 (vect [1..10]) 40 20 + d3 = ddia #> x1 + d4 = tr ddia #> x2 + + v = testb 40 + s5 = cgSolve False sm v + d5 = denseSolve dm v + + info = do + print sm + disp (toDense sma) + print s1; print d1 + print s2; print d2 + print s3; print d3 + print s4; print d4 + print s5; print d5 + print $ relativeError (pnorm Infinity) s5 d5 + + ok = s1==d1 + && s2==d2 + && s3==d3 + && s4==d4 + && relativeError (pnorm Infinity) s5 d5 < 1E-10 + + disp = putStr . dispf 2 + + vect = fromList :: [Double] -> Vector Double + + convomat :: Int -> Int -> AssocMatrix + convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] + + convo2 :: Int -> Int -> AssocMatrix + convo2 n k = m1 ++ m2 + where + m1 = convomat n k + m2 = map (((+n) *** id) *** id) m1 + + testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) + + denseSolve a = flatten . linearSolveLS a . asColumn + + -- mkDiag v = mkDiagR (dim v) (dim v) v + diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs deleted file mode 100644 index 899a5bf..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ /dev/null @@ -1,171 +0,0 @@ -{-# LANGUAGE FlexibleContexts, FlexibleInstances #-} -{-# LANGUAGE RecordWildCards #-} - -module Numeric.LinearAlgebra.Util.CG( - cgSolve, cgSolve', - CGState(..), R, V -) where - -import Data.Packed.Numeric -import Numeric.Sparse -import Numeric.Vector() -import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, pnorm, NormType(..)) -import Control.Arrow((***)) - -{- -import Util.Misc(debug, debugMat) - -(//) :: Show a => a -> String -> a -infix 0 // -- , /// -a // b = debug b id a - -(///) :: V -> String -> V -infix 0 /// -v /// b = debugMat b 2 asRow v --} - -type R = Double -type V = Vector R - -data CGState = CGState - { cgp :: V -- ^ conjugate gradient - , cgr :: V -- ^ residual - , cgr2 :: R -- ^ squared norm of residual - , cgx :: V -- ^ current solution - , cgdx :: R -- ^ normalized size of correction - } - -cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState -cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx - where - ap1 = a p - ap | sym = ap1 - | otherwise = at ap1 - pap | sym = p <·> ap1 - | otherwise = norm2 ap1 ** 2 - alpha = r2 / pap - dx = scale alpha p - x' = x + dx - r' = r - scale alpha ap - r'2 = r' <·> r' - beta = r'2 / r2 - p' = r' + scale beta p - - rdx = norm2 dx / max 1 (norm2 x) - -conjugrad - :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] -conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b - -solveG - :: (V -> V) -> (V -> V) - -> ((V -> V) -> (V -> V) -> CGState -> CGState) - -> V - -> V - -> R -> R - -> [CGState] -solveG mat ma meth rawb x0' ϵb ϵx - = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 - where - a = mat . ma - b = mat rawb - x0 = if x0' == 0 then konst 0 (dim b) else x0' - r0 = b - a x0 - r20 = r0 <·> r0 - p0 = r0 - nb2 = b <·> b - ok CGState {..} - = cgr2 Bool) -> [a] -> [a] -takeUntil q xs = a++ take 1 b - where - (a,b) = break q xs - -cgSolve - :: Bool -- ^ is symmetric - -> GMatrix -- ^ coefficient matrix - -> Vector Double -- ^ right-hand side - -> Vector Double -- ^ solution -cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 - where - n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) - -cgSolve' - :: Bool -- ^ symmetric - -> R -- ^ relative tolerance for the residual (e.g. 1E-4) - -> R -- ^ relative tolerance for δx (e.g. 1E-3) - -> Int -- ^ maximum number of iterations - -> GMatrix -- ^ coefficient matrix - -> V -- ^ initial solution - -> V -- ^ right-hand side - -> [CGState] -- ^ solution -cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es - - --------------------------------------------------------------------------------- - -instance Testable GMatrix - where - checkT _ = (ok,info) - where - sma = convo2 20 3 - x1 = vect [1..20] - x2 = vect [1..40] - sm = mkSparse sma - dm = toDense sma - - s1 = sm !#> x1 - d1 = dm #> x1 - - s2 = tr sm !#> x2 - d2 = tr dm #> x2 - - sdia = mkDiagR 40 20 (vect [1..10]) - s3 = sdia !#> x1 - s4 = tr sdia !#> x2 - ddia = diagRect 0 (vect [1..10]) 40 20 - d3 = ddia #> x1 - d4 = tr ddia #> x2 - - v = testb 40 - s5 = cgSolve False sm v - d5 = denseSolve dm v - - info = do - print sm - disp (toDense sma) - print s1; print d1 - print s2; print d2 - print s3; print d3 - print s4; print d4 - print s5; print d5 - print $ relativeError (pnorm Infinity) s5 d5 - - ok = s1==d1 - && s2==d2 - && s3==d3 - && s4==d4 - && relativeError (pnorm Infinity) s5 d5 < 1E-10 - - disp = putStr . dispf 2 - - vect = fromList :: [Double] -> Vector Double - - convomat :: Int -> Int -> AssocMatrix - convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] - - convo2 :: Int -> Int -> AssocMatrix - convo2 n k = m1 ++ m2 - where - m1 = convomat n k - m2 = map (((+n) *** id) *** id) m1 - - testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) - - denseSolve a = flatten . linearSolveLS a . asColumn - - -- mkDiag v = mkDiagR (dim v) (dim v) v - -- cgit v1.2.3