diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Algorithms.hs | 21 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 62 |
2 files changed, 55 insertions, 28 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs index 063bfc9..c7e7043 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -66,7 +66,7 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
66 | orth, | 66 | orth, |
67 | -- * Norms | 67 | -- * Norms |
68 | Normed(..), NormType(..), | 68 | Normed(..), NormType(..), |
69 | relativeError, | 69 | relativeError', relativeError, |
70 | -- * Misc | 70 | -- * Misc |
71 | eps, peps, i, | 71 | eps, peps, i, |
72 | -- * Util | 72 | -- * Util |
@@ -719,11 +719,26 @@ instance Normed Matrix (Complex Float) where | |||
719 | pnorm Frobenius = pnorm PNorm2 . flatten | 719 | pnorm Frobenius = pnorm PNorm2 . flatten |
720 | 720 | ||
721 | -- | Approximate number of common digits in the maximum element. | 721 | -- | Approximate number of common digits in the maximum element. |
722 | relativeError :: (Normed c t, Container c t) => c t -> c t -> Int | 722 | relativeError' :: (Normed c t, Container c t) => c t -> c t -> Int |
723 | relativeError x y = dig (norm (x `sub` y) / norm x) | 723 | relativeError' x y = dig (norm (x `sub` y) / norm x) |
724 | where norm = pnorm Infinity | 724 | where norm = pnorm Infinity |
725 | dig r = round $ -logBase 10 (realToFrac r :: Double) | 725 | dig r = round $ -logBase 10 (realToFrac r :: Double) |
726 | 726 | ||
727 | |||
728 | relativeError :: (Normed c t, Num (c t)) => NormType -> c t -> c t -> Double | ||
729 | relativeError t a b = realToFrac r | ||
730 | where | ||
731 | norm = pnorm t | ||
732 | na = norm a | ||
733 | nb = norm b | ||
734 | nab = norm (a-b) | ||
735 | mx = max na nb | ||
736 | mn = min na nb | ||
737 | r = if mn < peps | ||
738 | then mx | ||
739 | else nab/mx | ||
740 | |||
741 | |||
727 | ---------------------------------------------------------------------- | 742 | ---------------------------------------------------------------------- |
728 | 743 | ||
729 | -- | Generalized symmetric positive definite eigensystem Av = lBv, | 744 | -- | Generalized symmetric positive definite eigensystem Av = lBv, |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 2c782e8..d21602d 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -2,8 +2,8 @@ | |||
2 | {-# LANGUAGE RecordWildCards #-} | 2 | {-# LANGUAGE RecordWildCards #-} |
3 | 3 | ||
4 | module Numeric.LinearAlgebra.Util.CG( | 4 | module Numeric.LinearAlgebra.Util.CG( |
5 | cgSolve, | 5 | cgSolve, cgSolve', |
6 | CGMat | 6 | CGMat, CGState(..), R, V |
7 | ) where | 7 | ) where |
8 | 8 | ||
9 | import Numeric.Container | 9 | import Numeric.Container |
@@ -16,23 +16,23 @@ import Util.Misc(debug, debugMat) | |||
16 | infix 0 // -- , /// | 16 | infix 0 // -- , /// |
17 | a // b = debug b id a | 17 | a // b = debug b id a |
18 | 18 | ||
19 | (///) :: DV -> String -> DV | 19 | (///) :: V -> String -> V |
20 | infix 0 /// | 20 | infix 0 /// |
21 | v /// b = debugMat b 2 asRow v | 21 | v /// b = debugMat b 2 asRow v |
22 | -} | 22 | -} |
23 | 23 | ||
24 | 24 | type R = Double | |
25 | type DV = Vector Double | 25 | type V = Vector R |
26 | 26 | ||
27 | data CGState = CGState | 27 | data CGState = CGState |
28 | { cgp :: DV | 28 | { cgp :: V -- ^ conjugate gradient |
29 | , cgr :: DV | 29 | , cgr :: V -- ^ residual |
30 | , cgr2 :: Double | 30 | , cgr2 :: R -- ^ squared norm of residual |
31 | , cgx :: DV | 31 | , cgx :: V -- ^ current solution |
32 | , cgdx :: Double | 32 | , cgdx :: R -- ^ normalized size of correction |
33 | } | 33 | } |
34 | 34 | ||
35 | cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState | 35 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState |
36 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | 36 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx |
37 | where | 37 | where |
38 | ap1 = a p | 38 | ap1 = a p |
@@ -51,16 +51,16 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
51 | rdx = norm2 dx / max 1 (norm2 x) | 51 | rdx = norm2 dx / max 1 (norm2 x) |
52 | 52 | ||
53 | conjugrad | 53 | conjugrad |
54 | :: (Transposable m, Contraction m DV DV) | 54 | :: (Transposable m, Contraction m V V) |
55 | => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] | 55 | => Bool -> m -> V -> V -> R -> R -> [CGState] |
56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | 56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b |
57 | 57 | ||
58 | solveG | 58 | solveG |
59 | :: (DV -> DV) -> (DV -> DV) | 59 | :: (V -> V) -> (V -> V) |
60 | -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) | 60 | -> ((V -> V) -> (V -> V) -> CGState -> CGState) |
61 | -> DV | 61 | -> V |
62 | -> DV | 62 | -> V |
63 | -> Double -> Double | 63 | -> R -> R |
64 | -> [CGState] | 64 | -> [CGState] |
65 | solveG mat ma meth rawb x0' ϵb ϵx | 65 | solveG mat ma meth rawb x0' ϵb ϵx |
66 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 | 66 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 |
@@ -82,15 +82,27 @@ takeUntil q xs = a++ take 1 b | |||
82 | where | 82 | where |
83 | (a,b) = break q xs | 83 | (a,b) = break q xs |
84 | 84 | ||
85 | class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m | 85 | class (Transposable m, Contraction m V V) => CGMat m |
86 | 86 | ||
87 | cgSolve | 87 | cgSolve |
88 | :: CGMat m | 88 | :: CGMat m |
89 | => Bool -- ^ symmetric | 89 | => Bool -- ^ is symmetric |
90 | -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) | 90 | -> m -- ^ coefficient matrix |
91 | -> Double -- ^ relative tolerance for δx (e.g. 1E-3) | ||
92 | -> m -- ^ coefficient matrix | ||
93 | -> Vector Double -- ^ right-hand side | 91 | -> Vector Double -- ^ right-hand side |
94 | -> Vector Double -- ^ solution | 92 | -> Vector Double -- ^ solution |
95 | cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es | 93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 |
94 | where | ||
95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | ||
96 | |||
97 | cgSolve' | ||
98 | :: CGMat m | ||
99 | => Bool -- ^ symmetric | ||
100 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | ||
101 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | ||
102 | -> Int -- ^ maximum number of iterations | ||
103 | -> m -- ^ coefficient matrix | ||
104 | -> V -- ^ initial solution | ||
105 | -> V -- ^ right-hand side | ||
106 | -> [CGState] -- ^ solution | ||
107 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | ||
96 | 108 | ||