From 0a9ef8f5b0088c1ac25175bffca4ed95d9e109a5 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 23 May 2014 10:51:16 +0200 Subject: relativeError, cgSolve' --- packages/base/src/Numeric/LinearAlgebra.hs | 9 ++-- .../base/src/Numeric/LinearAlgebra/Algorithms.hs | 21 ++++++-- packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 62 +++++++++++++--------- packages/base/src/Numeric/Sparse.hs | 19 +++++-- 4 files changed, 76 insertions(+), 35 deletions(-) (limited to 'packages/base/src/Numeric') diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 9e9151e..242122f 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -69,6 +69,7 @@ module Numeric.LinearAlgebra ( luSolve, cholSolve, cgSolve, + cgSolve', -- * Inverse and pseudoinverse inv, pinv, pinvTol, @@ -136,8 +137,8 @@ module Numeric.LinearAlgebra ( RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, Field, Normed, - CGMat, Transposable - + CGMat, Transposable, + R,V ) where import Numeric.LinearAlgebra.Data @@ -149,6 +150,6 @@ import Numeric.LinearAlgebra.Algorithms import Numeric.LinearAlgebra.Util import Numeric.LinearAlgebra.Random import Numeric.Sparse(smXv) -import Numeric.LinearAlgebra.Util.CG(cgSolve) -import Numeric.LinearAlgebra.Util.CG(CGMat) +import Numeric.LinearAlgebra.Util.CG + diff --git a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs index 063bfc9..c7e7043 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs @@ -66,7 +66,7 @@ module Numeric.LinearAlgebra.Algorithms ( orth, -- * Norms Normed(..), NormType(..), - relativeError, + relativeError', relativeError, -- * Misc eps, peps, i, -- * Util @@ -719,11 +719,26 @@ instance Normed Matrix (Complex Float) where pnorm Frobenius = pnorm PNorm2 . flatten -- | Approximate number of common digits in the maximum element. -relativeError :: (Normed c t, Container c t) => c t -> c t -> Int -relativeError x y = dig (norm (x `sub` y) / norm x) +relativeError' :: (Normed c t, Container c t) => c t -> c t -> Int +relativeError' x y = dig (norm (x `sub` y) / norm x) where norm = pnorm Infinity dig r = round $ -logBase 10 (realToFrac r :: Double) + +relativeError :: (Normed c t, Num (c t)) => NormType -> c t -> c t -> Double +relativeError t a b = realToFrac r + where + norm = pnorm t + na = norm a + nb = norm b + nab = norm (a-b) + mx = max na nb + mn = min na nb + r = if mn < peps + then mx + else nab/mx + + ---------------------------------------------------------------------- -- | Generalized symmetric positive definite eigensystem Av = lBv, diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 2c782e8..d21602d 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs @@ -2,8 +2,8 @@ {-# LANGUAGE RecordWildCards #-} module Numeric.LinearAlgebra.Util.CG( - cgSolve, - CGMat + cgSolve, cgSolve', + CGMat, CGState(..), R, V ) where import Numeric.Container @@ -16,23 +16,23 @@ import Util.Misc(debug, debugMat) infix 0 // -- , /// a // b = debug b id a -(///) :: DV -> String -> DV +(///) :: V -> String -> V infix 0 /// v /// b = debugMat b 2 asRow v -} - -type DV = Vector Double +type R = Double +type V = Vector R data CGState = CGState - { cgp :: DV - , cgr :: DV - , cgr2 :: Double - , cgx :: DV - , cgdx :: Double + { cgp :: V -- ^ conjugate gradient + , cgr :: V -- ^ residual + , cgr2 :: R -- ^ squared norm of residual + , cgx :: V -- ^ current solution + , cgdx :: R -- ^ normalized size of correction } -cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState +cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx where ap1 = a p @@ -51,16 +51,16 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx rdx = norm2 dx / max 1 (norm2 x) conjugrad - :: (Transposable m, Contraction m DV DV) - => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] + :: (Transposable m, Contraction m V V) + => Bool -> m -> V -> V -> R -> R -> [CGState] conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b solveG - :: (DV -> DV) -> (DV -> DV) - -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) - -> DV - -> DV - -> Double -> Double + :: (V -> V) -> (V -> V) + -> ((V -> V) -> (V -> V) -> CGState -> CGState) + -> V + -> V + -> R -> R -> [CGState] solveG mat ma meth rawb x0' ϵb ϵx = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 @@ -82,15 +82,27 @@ takeUntil q xs = a++ take 1 b where (a,b) = break q xs -class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m +class (Transposable m, Contraction m V V) => CGMat m cgSolve :: CGMat m - => Bool -- ^ symmetric - -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) - -> Double -- ^ relative tolerance for δx (e.g. 1E-3) - -> m -- ^ coefficient matrix + => Bool -- ^ is symmetric + -> m -- ^ coefficient matrix -> Vector Double -- ^ right-hand side - -> Vector Double -- ^ solution -cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es + -> Vector Double -- ^ solution +cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 + where + n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) + +cgSolve' + :: CGMat m + => Bool -- ^ symmetric + -> R -- ^ relative tolerance for the residual (e.g. 1E-4) + -> R -- ^ relative tolerance for δx (e.g. 1E-3) + -> Int -- ^ maximum number of iterations + -> m -- ^ coefficient matrix + -> V -- ^ initial solution + -> V -- ^ right-hand side + -> [CGState] -- ^ solution +cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 3835590..1957d3a 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs @@ -17,7 +17,8 @@ import Control.Arrow((***)) import Control.Monad(when) import Data.List(groupBy, sort) import Foreign.C.Types(CInt(..)) -import Numeric.LinearAlgebra.Util.CG(CGMat) +import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) +import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) import Data.Packed.Development import System.IO.Unsafe(unsafePerformIO) import Foreign(Ptr) @@ -150,12 +151,13 @@ instance Testable SMatrix x1 = vect [1..20] x2 = vect [1..40] sm = mkCSR sma + dm = toDense sma s1 = sm ◇ x1 - d1 = toDense sma ◇ x1 + d1 = dm ◇ x1 s2 = tr sm ◇ x2 - d2 = tr (toDense sma) ◇ x2 + d2 = tr dm ◇ x2 sdia = mkDiagR 40 20 (vect [1..10]) s3 = sdia ◇ x1 @@ -164,6 +166,10 @@ instance Testable SMatrix d3 = ddia ◇ x1 d4 = tr ddia ◇ x2 + v = testb 40 + s5 = cgSolve False sm v + d5 = denseSolve dm v + info = do print sm disp (toDense sma) @@ -171,11 +177,14 @@ instance Testable SMatrix print s2; print d2 print s3; print d3 print s4; print d4 + print s5; print d5 + print $ relativeError Infinity s5 d5 ok = s1==d1 && s2==d2 && s3==d3 && s4==d4 + && relativeError Infinity s5 d5 < 1E-10 disp = putStr . dispf 2 @@ -189,4 +198,8 @@ instance Testable SMatrix where m1 = convomat n k m2 = map (((+n) *** id) *** id) m1 + + testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) + + denseSolve a = flatten . linearSolveLS a . asColumn -- cgit v1.2.3