summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Algorithms.hs21
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs62
2 files changed, 55 insertions, 28 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs
index 063bfc9..c7e7043 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs
@@ -66,7 +66,7 @@ module Numeric.LinearAlgebra.Algorithms (
66 orth, 66 orth,
67-- * Norms 67-- * Norms
68 Normed(..), NormType(..), 68 Normed(..), NormType(..),
69 relativeError, 69 relativeError', relativeError,
70-- * Misc 70-- * Misc
71 eps, peps, i, 71 eps, peps, i,
72-- * Util 72-- * Util
@@ -719,11 +719,26 @@ instance Normed Matrix (Complex Float) where
719 pnorm Frobenius = pnorm PNorm2 . flatten 719 pnorm Frobenius = pnorm PNorm2 . flatten
720 720
721-- | Approximate number of common digits in the maximum element. 721-- | Approximate number of common digits in the maximum element.
722relativeError :: (Normed c t, Container c t) => c t -> c t -> Int 722relativeError' :: (Normed c t, Container c t) => c t -> c t -> Int
723relativeError x y = dig (norm (x `sub` y) / norm x) 723relativeError' x y = dig (norm (x `sub` y) / norm x)
724 where norm = pnorm Infinity 724 where norm = pnorm Infinity
725 dig r = round $ -logBase 10 (realToFrac r :: Double) 725 dig r = round $ -logBase 10 (realToFrac r :: Double)
726 726
727
728relativeError :: (Normed c t, Num (c t)) => NormType -> c t -> c t -> Double
729relativeError t a b = realToFrac r
730 where
731 norm = pnorm t
732 na = norm a
733 nb = norm b
734 nab = norm (a-b)
735 mx = max na nb
736 mn = min na nb
737 r = if mn < peps
738 then mx
739 else nab/mx
740
741
727---------------------------------------------------------------------- 742----------------------------------------------------------------------
728 743
729-- | Generalized symmetric positive definite eigensystem Av = lBv, 744-- | Generalized symmetric positive definite eigensystem Av = lBv,
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index 2c782e8..d21602d 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -2,8 +2,8 @@
2{-# LANGUAGE RecordWildCards #-} 2{-# LANGUAGE RecordWildCards #-}
3 3
4module Numeric.LinearAlgebra.Util.CG( 4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve, 5 cgSolve, cgSolve',
6 CGMat 6 CGMat, CGState(..), R, V
7) where 7) where
8 8
9import Numeric.Container 9import Numeric.Container
@@ -16,23 +16,23 @@ import Util.Misc(debug, debugMat)
16infix 0 // -- , /// 16infix 0 // -- , ///
17a // b = debug b id a 17a // b = debug b id a
18 18
19(///) :: DV -> String -> DV 19(///) :: V -> String -> V
20infix 0 /// 20infix 0 ///
21v /// b = debugMat b 2 asRow v 21v /// b = debugMat b 2 asRow v
22-} 22-}
23 23
24 24type R = Double
25type DV = Vector Double 25type V = Vector R
26 26
27data CGState = CGState 27data CGState = CGState
28 { cgp :: DV 28 { cgp :: V -- ^ conjugate gradient
29 , cgr :: DV 29 , cgr :: V -- ^ residual
30 , cgr2 :: Double 30 , cgr2 :: R -- ^ squared norm of residual
31 , cgx :: DV 31 , cgx :: V -- ^ current solution
32 , cgdx :: Double 32 , cgdx :: R -- ^ normalized size of correction
33 } 33 }
34 34
35cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState 35cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx 36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
37 where 37 where
38 ap1 = a p 38 ap1 = a p
@@ -51,16 +51,16 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
51 rdx = norm2 dx / max 1 (norm2 x) 51 rdx = norm2 dx / max 1 (norm2 x)
52 52
53conjugrad 53conjugrad
54 :: (Transposable m, Contraction m DV DV) 54 :: (Transposable m, Contraction m V V)
55 => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] 55 => Bool -> m -> V -> V -> R -> R -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b 56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57 57
58solveG 58solveG
59 :: (DV -> DV) -> (DV -> DV) 59 :: (V -> V) -> (V -> V)
60 -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) 60 -> ((V -> V) -> (V -> V) -> CGState -> CGState)
61 -> DV 61 -> V
62 -> DV 62 -> V
63 -> Double -> Double 63 -> R -> R
64 -> [CGState] 64 -> [CGState]
65solveG mat ma meth rawb x0' ϵb ϵx 65solveG mat ma meth rawb x0' ϵb ϵx
66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
@@ -82,15 +82,27 @@ takeUntil q xs = a++ take 1 b
82 where 82 where
83 (a,b) = break q xs 83 (a,b) = break q xs
84 84
85class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m 85class (Transposable m, Contraction m V V) => CGMat m
86 86
87cgSolve 87cgSolve
88 :: CGMat m 88 :: CGMat m
89 => Bool -- ^ symmetric 89 => Bool -- ^ is symmetric
90 -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) 90 -> m -- ^ coefficient matrix
91 -> Double -- ^ relative tolerance for δx (e.g. 1E-3)
92 -> m -- ^ coefficient matrix
93 -> Vector Double -- ^ right-hand side 91 -> Vector Double -- ^ right-hand side
94 -> Vector Double -- ^ solution 92 -> Vector Double -- ^ solution
95cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es 93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
94 where
95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
96
97cgSolve'
98 :: CGMat m
99 => Bool -- ^ symmetric
100 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
101 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
102 -> Int -- ^ maximum number of iterations
103 -> m -- ^ coefficient matrix
104 -> V -- ^ initial solution
105 -> V -- ^ right-hand side
106 -> [CGState] -- ^ solution
107cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
96 108