diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs new file mode 100644 index 0000000..2c782e8 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -0,0 +1,96 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | ||
2 | {-# LANGUAGE RecordWildCards #-} | ||
3 | |||
4 | module Numeric.LinearAlgebra.Util.CG( | ||
5 | cgSolve, | ||
6 | CGMat | ||
7 | ) where | ||
8 | |||
9 | import Numeric.Container | ||
10 | import Numeric.Vector() | ||
11 | |||
12 | {- | ||
13 | import Util.Misc(debug, debugMat) | ||
14 | |||
15 | (//) :: Show a => a -> String -> a | ||
16 | infix 0 // -- , /// | ||
17 | a // b = debug b id a | ||
18 | |||
19 | (///) :: DV -> String -> DV | ||
20 | infix 0 /// | ||
21 | v /// b = debugMat b 2 asRow v | ||
22 | -} | ||
23 | |||
24 | |||
25 | type DV = Vector Double | ||
26 | |||
27 | data CGState = CGState | ||
28 | { cgp :: DV | ||
29 | , cgr :: DV | ||
30 | , cgr2 :: Double | ||
31 | , cgx :: DV | ||
32 | , cgdx :: Double | ||
33 | } | ||
34 | |||
35 | cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState | ||
36 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | ||
37 | where | ||
38 | ap1 = a p | ||
39 | ap | sym = ap1 | ||
40 | | otherwise = at ap1 | ||
41 | pap | sym = p ◇ ap1 | ||
42 | | otherwise = norm2 ap1 ** 2 | ||
43 | alpha = r2 / pap | ||
44 | dx = scale alpha p | ||
45 | x' = x + dx | ||
46 | r' = r - scale alpha ap | ||
47 | r'2 = r' ◇ r' | ||
48 | beta = r'2 / r2 | ||
49 | p' = r' + scale beta p | ||
50 | |||
51 | rdx = norm2 dx / max 1 (norm2 x) | ||
52 | |||
53 | conjugrad | ||
54 | :: (Transposable m, Contraction m DV DV) | ||
55 | => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] | ||
56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | ||
57 | |||
58 | solveG | ||
59 | :: (DV -> DV) -> (DV -> DV) | ||
60 | -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) | ||
61 | -> DV | ||
62 | -> DV | ||
63 | -> Double -> Double | ||
64 | -> [CGState] | ||
65 | solveG mat ma meth rawb x0' ϵb ϵx | ||
66 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 | ||
67 | where | ||
68 | a = mat . ma | ||
69 | b = mat rawb | ||
70 | x0 = if x0' == 0 then konst 0 (dim b) else x0' | ||
71 | r0 = b - a x0 | ||
72 | r20 = r0 ◇ r0 | ||
73 | p0 = r0 | ||
74 | nb2 = b ◇ b | ||
75 | ok CGState {..} | ||
76 | = cgr2 <nb2*ϵb**2 | ||
77 | || cgdx < ϵx | ||
78 | |||
79 | |||
80 | takeUntil :: (a -> Bool) -> [a] -> [a] | ||
81 | takeUntil q xs = a++ take 1 b | ||
82 | where | ||
83 | (a,b) = break q xs | ||
84 | |||
85 | class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m | ||
86 | |||
87 | cgSolve | ||
88 | :: CGMat m | ||
89 | => Bool -- ^ symmetric | ||
90 | -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) | ||
91 | -> Double -- ^ relative tolerance for δx (e.g. 1E-3) | ||
92 | -> m -- ^ coefficient matrix | ||
93 | -> Vector Double -- ^ right-hand side | ||
94 | -> Vector Double -- ^ solution | ||
95 | cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es | ||
96 | |||