diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-23 10:51:16 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-23 10:51:16 +0200 |
commit | 0a9ef8f5b0088c1ac25175bffca4ed95d9e109a5 (patch) | |
tree | 7bd461ebbf140804e470ae03bb3a2e29bdd935e2 /packages/base/src/Numeric | |
parent | 109fa7d25779e331356bbe310755c10eddfeb235 (diff) |
relativeError, cgSolve'
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 9 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Algorithms.hs | 21 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 62 | ||||
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs | 19 |
4 files changed, 76 insertions, 35 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 9e9151e..242122f 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -69,6 +69,7 @@ module Numeric.LinearAlgebra ( | |||
69 | luSolve, | 69 | luSolve, |
70 | cholSolve, | 70 | cholSolve, |
71 | cgSolve, | 71 | cgSolve, |
72 | cgSolve', | ||
72 | 73 | ||
73 | -- * Inverse and pseudoinverse | 74 | -- * Inverse and pseudoinverse |
74 | inv, pinv, pinvTol, | 75 | inv, pinv, pinvTol, |
@@ -136,8 +137,8 @@ module Numeric.LinearAlgebra ( | |||
136 | RealOf, ComplexOf, SingleOf, DoubleOf, | 137 | RealOf, ComplexOf, SingleOf, DoubleOf, |
137 | IndexOf, | 138 | IndexOf, |
138 | Field, Normed, | 139 | Field, Normed, |
139 | CGMat, Transposable | 140 | CGMat, Transposable, |
140 | 141 | R,V | |
141 | ) where | 142 | ) where |
142 | 143 | ||
143 | import Numeric.LinearAlgebra.Data | 144 | import Numeric.LinearAlgebra.Data |
@@ -149,6 +150,6 @@ import Numeric.LinearAlgebra.Algorithms | |||
149 | import Numeric.LinearAlgebra.Util | 150 | import Numeric.LinearAlgebra.Util |
150 | import Numeric.LinearAlgebra.Random | 151 | import Numeric.LinearAlgebra.Random |
151 | import Numeric.Sparse(smXv) | 152 | import Numeric.Sparse(smXv) |
152 | import Numeric.LinearAlgebra.Util.CG(cgSolve) | 153 | import Numeric.LinearAlgebra.Util.CG |
153 | import Numeric.LinearAlgebra.Util.CG(CGMat) | 154 | |
154 | 155 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs index 063bfc9..c7e7043 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -66,7 +66,7 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
66 | orth, | 66 | orth, |
67 | -- * Norms | 67 | -- * Norms |
68 | Normed(..), NormType(..), | 68 | Normed(..), NormType(..), |
69 | relativeError, | 69 | relativeError', relativeError, |
70 | -- * Misc | 70 | -- * Misc |
71 | eps, peps, i, | 71 | eps, peps, i, |
72 | -- * Util | 72 | -- * Util |
@@ -719,11 +719,26 @@ instance Normed Matrix (Complex Float) where | |||
719 | pnorm Frobenius = pnorm PNorm2 . flatten | 719 | pnorm Frobenius = pnorm PNorm2 . flatten |
720 | 720 | ||
721 | -- | Approximate number of common digits in the maximum element. | 721 | -- | Approximate number of common digits in the maximum element. |
722 | relativeError :: (Normed c t, Container c t) => c t -> c t -> Int | 722 | relativeError' :: (Normed c t, Container c t) => c t -> c t -> Int |
723 | relativeError x y = dig (norm (x `sub` y) / norm x) | 723 | relativeError' x y = dig (norm (x `sub` y) / norm x) |
724 | where norm = pnorm Infinity | 724 | where norm = pnorm Infinity |
725 | dig r = round $ -logBase 10 (realToFrac r :: Double) | 725 | dig r = round $ -logBase 10 (realToFrac r :: Double) |
726 | 726 | ||
727 | |||
728 | relativeError :: (Normed c t, Num (c t)) => NormType -> c t -> c t -> Double | ||
729 | relativeError t a b = realToFrac r | ||
730 | where | ||
731 | norm = pnorm t | ||
732 | na = norm a | ||
733 | nb = norm b | ||
734 | nab = norm (a-b) | ||
735 | mx = max na nb | ||
736 | mn = min na nb | ||
737 | r = if mn < peps | ||
738 | then mx | ||
739 | else nab/mx | ||
740 | |||
741 | |||
727 | ---------------------------------------------------------------------- | 742 | ---------------------------------------------------------------------- |
728 | 743 | ||
729 | -- | Generalized symmetric positive definite eigensystem Av = lBv, | 744 | -- | Generalized symmetric positive definite eigensystem Av = lBv, |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 2c782e8..d21602d 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -2,8 +2,8 @@ | |||
2 | {-# LANGUAGE RecordWildCards #-} | 2 | {-# LANGUAGE RecordWildCards #-} |
3 | 3 | ||
4 | module Numeric.LinearAlgebra.Util.CG( | 4 | module Numeric.LinearAlgebra.Util.CG( |
5 | cgSolve, | 5 | cgSolve, cgSolve', |
6 | CGMat | 6 | CGMat, CGState(..), R, V |
7 | ) where | 7 | ) where |
8 | 8 | ||
9 | import Numeric.Container | 9 | import Numeric.Container |
@@ -16,23 +16,23 @@ import Util.Misc(debug, debugMat) | |||
16 | infix 0 // -- , /// | 16 | infix 0 // -- , /// |
17 | a // b = debug b id a | 17 | a // b = debug b id a |
18 | 18 | ||
19 | (///) :: DV -> String -> DV | 19 | (///) :: V -> String -> V |
20 | infix 0 /// | 20 | infix 0 /// |
21 | v /// b = debugMat b 2 asRow v | 21 | v /// b = debugMat b 2 asRow v |
22 | -} | 22 | -} |
23 | 23 | ||
24 | 24 | type R = Double | |
25 | type DV = Vector Double | 25 | type V = Vector R |
26 | 26 | ||
27 | data CGState = CGState | 27 | data CGState = CGState |
28 | { cgp :: DV | 28 | { cgp :: V -- ^ conjugate gradient |
29 | , cgr :: DV | 29 | , cgr :: V -- ^ residual |
30 | , cgr2 :: Double | 30 | , cgr2 :: R -- ^ squared norm of residual |
31 | , cgx :: DV | 31 | , cgx :: V -- ^ current solution |
32 | , cgdx :: Double | 32 | , cgdx :: R -- ^ normalized size of correction |
33 | } | 33 | } |
34 | 34 | ||
35 | cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState | 35 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState |
36 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | 36 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx |
37 | where | 37 | where |
38 | ap1 = a p | 38 | ap1 = a p |
@@ -51,16 +51,16 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
51 | rdx = norm2 dx / max 1 (norm2 x) | 51 | rdx = norm2 dx / max 1 (norm2 x) |
52 | 52 | ||
53 | conjugrad | 53 | conjugrad |
54 | :: (Transposable m, Contraction m DV DV) | 54 | :: (Transposable m, Contraction m V V) |
55 | => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] | 55 | => Bool -> m -> V -> V -> R -> R -> [CGState] |
56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | 56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b |
57 | 57 | ||
58 | solveG | 58 | solveG |
59 | :: (DV -> DV) -> (DV -> DV) | 59 | :: (V -> V) -> (V -> V) |
60 | -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) | 60 | -> ((V -> V) -> (V -> V) -> CGState -> CGState) |
61 | -> DV | 61 | -> V |
62 | -> DV | 62 | -> V |
63 | -> Double -> Double | 63 | -> R -> R |
64 | -> [CGState] | 64 | -> [CGState] |
65 | solveG mat ma meth rawb x0' ϵb ϵx | 65 | solveG mat ma meth rawb x0' ϵb ϵx |
66 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 | 66 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 |
@@ -82,15 +82,27 @@ takeUntil q xs = a++ take 1 b | |||
82 | where | 82 | where |
83 | (a,b) = break q xs | 83 | (a,b) = break q xs |
84 | 84 | ||
85 | class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m | 85 | class (Transposable m, Contraction m V V) => CGMat m |
86 | 86 | ||
87 | cgSolve | 87 | cgSolve |
88 | :: CGMat m | 88 | :: CGMat m |
89 | => Bool -- ^ symmetric | 89 | => Bool -- ^ is symmetric |
90 | -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) | 90 | -> m -- ^ coefficient matrix |
91 | -> Double -- ^ relative tolerance for δx (e.g. 1E-3) | ||
92 | -> m -- ^ coefficient matrix | ||
93 | -> Vector Double -- ^ right-hand side | 91 | -> Vector Double -- ^ right-hand side |
94 | -> Vector Double -- ^ solution | 92 | -> Vector Double -- ^ solution |
95 | cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es | 93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 |
94 | where | ||
95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | ||
96 | |||
97 | cgSolve' | ||
98 | :: CGMat m | ||
99 | => Bool -- ^ symmetric | ||
100 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | ||
101 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | ||
102 | -> Int -- ^ maximum number of iterations | ||
103 | -> m -- ^ coefficient matrix | ||
104 | -> V -- ^ initial solution | ||
105 | -> V -- ^ right-hand side | ||
106 | -> [CGState] -- ^ solution | ||
107 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | ||
96 | 108 | ||
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 3835590..1957d3a 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs | |||
@@ -17,7 +17,8 @@ import Control.Arrow((***)) | |||
17 | import Control.Monad(when) | 17 | import Control.Monad(when) |
18 | import Data.List(groupBy, sort) | 18 | import Data.List(groupBy, sort) |
19 | import Foreign.C.Types(CInt(..)) | 19 | import Foreign.C.Types(CInt(..)) |
20 | import Numeric.LinearAlgebra.Util.CG(CGMat) | 20 | import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) |
21 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) | ||
21 | import Data.Packed.Development | 22 | import Data.Packed.Development |
22 | import System.IO.Unsafe(unsafePerformIO) | 23 | import System.IO.Unsafe(unsafePerformIO) |
23 | import Foreign(Ptr) | 24 | import Foreign(Ptr) |
@@ -150,12 +151,13 @@ instance Testable SMatrix | |||
150 | x1 = vect [1..20] | 151 | x1 = vect [1..20] |
151 | x2 = vect [1..40] | 152 | x2 = vect [1..40] |
152 | sm = mkCSR sma | 153 | sm = mkCSR sma |
154 | dm = toDense sma | ||
153 | 155 | ||
154 | s1 = sm ◇ x1 | 156 | s1 = sm ◇ x1 |
155 | d1 = toDense sma ◇ x1 | 157 | d1 = dm ◇ x1 |
156 | 158 | ||
157 | s2 = tr sm ◇ x2 | 159 | s2 = tr sm ◇ x2 |
158 | d2 = tr (toDense sma) ◇ x2 | 160 | d2 = tr dm ◇ x2 |
159 | 161 | ||
160 | sdia = mkDiagR 40 20 (vect [1..10]) | 162 | sdia = mkDiagR 40 20 (vect [1..10]) |
161 | s3 = sdia ◇ x1 | 163 | s3 = sdia ◇ x1 |
@@ -164,6 +166,10 @@ instance Testable SMatrix | |||
164 | d3 = ddia ◇ x1 | 166 | d3 = ddia ◇ x1 |
165 | d4 = tr ddia ◇ x2 | 167 | d4 = tr ddia ◇ x2 |
166 | 168 | ||
169 | v = testb 40 | ||
170 | s5 = cgSolve False sm v | ||
171 | d5 = denseSolve dm v | ||
172 | |||
167 | info = do | 173 | info = do |
168 | print sm | 174 | print sm |
169 | disp (toDense sma) | 175 | disp (toDense sma) |
@@ -171,11 +177,14 @@ instance Testable SMatrix | |||
171 | print s2; print d2 | 177 | print s2; print d2 |
172 | print s3; print d3 | 178 | print s3; print d3 |
173 | print s4; print d4 | 179 | print s4; print d4 |
180 | print s5; print d5 | ||
181 | print $ relativeError Infinity s5 d5 | ||
174 | 182 | ||
175 | ok = s1==d1 | 183 | ok = s1==d1 |
176 | && s2==d2 | 184 | && s2==d2 |
177 | && s3==d3 | 185 | && s3==d3 |
178 | && s4==d4 | 186 | && s4==d4 |
187 | && relativeError Infinity s5 d5 < 1E-10 | ||
179 | 188 | ||
180 | disp = putStr . dispf 2 | 189 | disp = putStr . dispf 2 |
181 | 190 | ||
@@ -189,4 +198,8 @@ instance Testable SMatrix | |||
189 | where | 198 | where |
190 | m1 = convomat n k | 199 | m1 = convomat n k |
191 | m2 = map (((+n) *** id) *** id) m1 | 200 | m2 = map (((+n) *** id) *** id) m1 |
201 | |||
202 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
203 | |||
204 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
192 | 205 | ||