summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-23 10:51:16 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-23 10:51:16 +0200
commit0a9ef8f5b0088c1ac25175bffca4ed95d9e109a5 (patch)
tree7bd461ebbf140804e470ae03bb3a2e29bdd935e2 /packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
parent109fa7d25779e331356bbe310755c10eddfeb235 (diff)
relativeError, cgSolve'
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util/CG.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs62
1 files changed, 37 insertions, 25 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index 2c782e8..d21602d 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -2,8 +2,8 @@
2{-# LANGUAGE RecordWildCards #-} 2{-# LANGUAGE RecordWildCards #-}
3 3
4module Numeric.LinearAlgebra.Util.CG( 4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve, 5 cgSolve, cgSolve',
6 CGMat 6 CGMat, CGState(..), R, V
7) where 7) where
8 8
9import Numeric.Container 9import Numeric.Container
@@ -16,23 +16,23 @@ import Util.Misc(debug, debugMat)
16infix 0 // -- , /// 16infix 0 // -- , ///
17a // b = debug b id a 17a // b = debug b id a
18 18
19(///) :: DV -> String -> DV 19(///) :: V -> String -> V
20infix 0 /// 20infix 0 ///
21v /// b = debugMat b 2 asRow v 21v /// b = debugMat b 2 asRow v
22-} 22-}
23 23
24 24type R = Double
25type DV = Vector Double 25type V = Vector R
26 26
27data CGState = CGState 27data CGState = CGState
28 { cgp :: DV 28 { cgp :: V -- ^ conjugate gradient
29 , cgr :: DV 29 , cgr :: V -- ^ residual
30 , cgr2 :: Double 30 , cgr2 :: R -- ^ squared norm of residual
31 , cgx :: DV 31 , cgx :: V -- ^ current solution
32 , cgdx :: Double 32 , cgdx :: R -- ^ normalized size of correction
33 } 33 }
34 34
35cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState 35cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx 36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
37 where 37 where
38 ap1 = a p 38 ap1 = a p
@@ -51,16 +51,16 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
51 rdx = norm2 dx / max 1 (norm2 x) 51 rdx = norm2 dx / max 1 (norm2 x)
52 52
53conjugrad 53conjugrad
54 :: (Transposable m, Contraction m DV DV) 54 :: (Transposable m, Contraction m V V)
55 => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] 55 => Bool -> m -> V -> V -> R -> R -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b 56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57 57
58solveG 58solveG
59 :: (DV -> DV) -> (DV -> DV) 59 :: (V -> V) -> (V -> V)
60 -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) 60 -> ((V -> V) -> (V -> V) -> CGState -> CGState)
61 -> DV 61 -> V
62 -> DV 62 -> V
63 -> Double -> Double 63 -> R -> R
64 -> [CGState] 64 -> [CGState]
65solveG mat ma meth rawb x0' ϵb ϵx 65solveG mat ma meth rawb x0' ϵb ϵx
66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
@@ -82,15 +82,27 @@ takeUntil q xs = a++ take 1 b
82 where 82 where
83 (a,b) = break q xs 83 (a,b) = break q xs
84 84
85class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m 85class (Transposable m, Contraction m V V) => CGMat m
86 86
87cgSolve 87cgSolve
88 :: CGMat m 88 :: CGMat m
89 => Bool -- ^ symmetric 89 => Bool -- ^ is symmetric
90 -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) 90 -> m -- ^ coefficient matrix
91 -> Double -- ^ relative tolerance for δx (e.g. 1E-3)
92 -> m -- ^ coefficient matrix
93 -> Vector Double -- ^ right-hand side 91 -> Vector Double -- ^ right-hand side
94 -> Vector Double -- ^ solution 92 -> Vector Double -- ^ solution
95cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es 93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
94 where
95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
96
97cgSolve'
98 :: CGMat m
99 => Bool -- ^ symmetric
100 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
101 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
102 -> Int -- ^ maximum number of iterations
103 -> m -- ^ coefficient matrix
104 -> V -- ^ initial solution
105 -> V -- ^ right-hand side
106 -> [CGState] -- ^ solution
107cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
96 108