summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs9
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Algorithms.hs21
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs62
-rw-r--r--packages/base/src/Numeric/Sparse.hs19
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs12
5 files changed, 77 insertions, 46 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 9e9151e..242122f 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -69,6 +69,7 @@ module Numeric.LinearAlgebra (
69 luSolve, 69 luSolve,
70 cholSolve, 70 cholSolve,
71 cgSolve, 71 cgSolve,
72 cgSolve',
72 73
73 -- * Inverse and pseudoinverse 74 -- * Inverse and pseudoinverse
74 inv, pinv, pinvTol, 75 inv, pinv, pinvTol,
@@ -136,8 +137,8 @@ module Numeric.LinearAlgebra (
136 RealOf, ComplexOf, SingleOf, DoubleOf, 137 RealOf, ComplexOf, SingleOf, DoubleOf,
137 IndexOf, 138 IndexOf,
138 Field, Normed, 139 Field, Normed,
139 CGMat, Transposable 140 CGMat, Transposable,
140 141 R,V
141) where 142) where
142 143
143import Numeric.LinearAlgebra.Data 144import Numeric.LinearAlgebra.Data
@@ -149,6 +150,6 @@ import Numeric.LinearAlgebra.Algorithms
149import Numeric.LinearAlgebra.Util 150import Numeric.LinearAlgebra.Util
150import Numeric.LinearAlgebra.Random 151import Numeric.LinearAlgebra.Random
151import Numeric.Sparse(smXv) 152import Numeric.Sparse(smXv)
152import Numeric.LinearAlgebra.Util.CG(cgSolve) 153import Numeric.LinearAlgebra.Util.CG
153import Numeric.LinearAlgebra.Util.CG(CGMat) 154
154 155
diff --git a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs
index 063bfc9..c7e7043 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Algorithms.hs
@@ -66,7 +66,7 @@ module Numeric.LinearAlgebra.Algorithms (
66 orth, 66 orth,
67-- * Norms 67-- * Norms
68 Normed(..), NormType(..), 68 Normed(..), NormType(..),
69 relativeError, 69 relativeError', relativeError,
70-- * Misc 70-- * Misc
71 eps, peps, i, 71 eps, peps, i,
72-- * Util 72-- * Util
@@ -719,11 +719,26 @@ instance Normed Matrix (Complex Float) where
719 pnorm Frobenius = pnorm PNorm2 . flatten 719 pnorm Frobenius = pnorm PNorm2 . flatten
720 720
721-- | Approximate number of common digits in the maximum element. 721-- | Approximate number of common digits in the maximum element.
722relativeError :: (Normed c t, Container c t) => c t -> c t -> Int 722relativeError' :: (Normed c t, Container c t) => c t -> c t -> Int
723relativeError x y = dig (norm (x `sub` y) / norm x) 723relativeError' x y = dig (norm (x `sub` y) / norm x)
724 where norm = pnorm Infinity 724 where norm = pnorm Infinity
725 dig r = round $ -logBase 10 (realToFrac r :: Double) 725 dig r = round $ -logBase 10 (realToFrac r :: Double)
726 726
727
728relativeError :: (Normed c t, Num (c t)) => NormType -> c t -> c t -> Double
729relativeError t a b = realToFrac r
730 where
731 norm = pnorm t
732 na = norm a
733 nb = norm b
734 nab = norm (a-b)
735 mx = max na nb
736 mn = min na nb
737 r = if mn < peps
738 then mx
739 else nab/mx
740
741
727---------------------------------------------------------------------- 742----------------------------------------------------------------------
728 743
729-- | Generalized symmetric positive definite eigensystem Av = lBv, 744-- | Generalized symmetric positive definite eigensystem Av = lBv,
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index 2c782e8..d21602d 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -2,8 +2,8 @@
2{-# LANGUAGE RecordWildCards #-} 2{-# LANGUAGE RecordWildCards #-}
3 3
4module Numeric.LinearAlgebra.Util.CG( 4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve, 5 cgSolve, cgSolve',
6 CGMat 6 CGMat, CGState(..), R, V
7) where 7) where
8 8
9import Numeric.Container 9import Numeric.Container
@@ -16,23 +16,23 @@ import Util.Misc(debug, debugMat)
16infix 0 // -- , /// 16infix 0 // -- , ///
17a // b = debug b id a 17a // b = debug b id a
18 18
19(///) :: DV -> String -> DV 19(///) :: V -> String -> V
20infix 0 /// 20infix 0 ///
21v /// b = debugMat b 2 asRow v 21v /// b = debugMat b 2 asRow v
22-} 22-}
23 23
24 24type R = Double
25type DV = Vector Double 25type V = Vector R
26 26
27data CGState = CGState 27data CGState = CGState
28 { cgp :: DV 28 { cgp :: V -- ^ conjugate gradient
29 , cgr :: DV 29 , cgr :: V -- ^ residual
30 , cgr2 :: Double 30 , cgr2 :: R -- ^ squared norm of residual
31 , cgx :: DV 31 , cgx :: V -- ^ current solution
32 , cgdx :: Double 32 , cgdx :: R -- ^ normalized size of correction
33 } 33 }
34 34
35cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState 35cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx 36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
37 where 37 where
38 ap1 = a p 38 ap1 = a p
@@ -51,16 +51,16 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
51 rdx = norm2 dx / max 1 (norm2 x) 51 rdx = norm2 dx / max 1 (norm2 x)
52 52
53conjugrad 53conjugrad
54 :: (Transposable m, Contraction m DV DV) 54 :: (Transposable m, Contraction m V V)
55 => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] 55 => Bool -> m -> V -> V -> R -> R -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b 56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57 57
58solveG 58solveG
59 :: (DV -> DV) -> (DV -> DV) 59 :: (V -> V) -> (V -> V)
60 -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) 60 -> ((V -> V) -> (V -> V) -> CGState -> CGState)
61 -> DV 61 -> V
62 -> DV 62 -> V
63 -> Double -> Double 63 -> R -> R
64 -> [CGState] 64 -> [CGState]
65solveG mat ma meth rawb x0' ϵb ϵx 65solveG mat ma meth rawb x0' ϵb ϵx
66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
@@ -82,15 +82,27 @@ takeUntil q xs = a++ take 1 b
82 where 82 where
83 (a,b) = break q xs 83 (a,b) = break q xs
84 84
85class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m 85class (Transposable m, Contraction m V V) => CGMat m
86 86
87cgSolve 87cgSolve
88 :: CGMat m 88 :: CGMat m
89 => Bool -- ^ symmetric 89 => Bool -- ^ is symmetric
90 -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) 90 -> m -- ^ coefficient matrix
91 -> Double -- ^ relative tolerance for δx (e.g. 1E-3)
92 -> m -- ^ coefficient matrix
93 -> Vector Double -- ^ right-hand side 91 -> Vector Double -- ^ right-hand side
94 -> Vector Double -- ^ solution 92 -> Vector Double -- ^ solution
95cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es 93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
94 where
95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
96
97cgSolve'
98 :: CGMat m
99 => Bool -- ^ symmetric
100 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
101 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
102 -> Int -- ^ maximum number of iterations
103 -> m -- ^ coefficient matrix
104 -> V -- ^ initial solution
105 -> V -- ^ right-hand side
106 -> [CGState] -- ^ solution
107cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
96 108
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
index 3835590..1957d3a 100644
--- a/packages/base/src/Numeric/Sparse.hs
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -17,7 +17,8 @@ import Control.Arrow((***))
17import Control.Monad(when) 17import Control.Monad(when)
18import Data.List(groupBy, sort) 18import Data.List(groupBy, sort)
19import Foreign.C.Types(CInt(..)) 19import Foreign.C.Types(CInt(..))
20import Numeric.LinearAlgebra.Util.CG(CGMat) 20import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve)
21import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..))
21import Data.Packed.Development 22import Data.Packed.Development
22import System.IO.Unsafe(unsafePerformIO) 23import System.IO.Unsafe(unsafePerformIO)
23import Foreign(Ptr) 24import Foreign(Ptr)
@@ -150,12 +151,13 @@ instance Testable SMatrix
150 x1 = vect [1..20] 151 x1 = vect [1..20]
151 x2 = vect [1..40] 152 x2 = vect [1..40]
152 sm = mkCSR sma 153 sm = mkCSR sma
154 dm = toDense sma
153 155
154 s1 = sm ◇ x1 156 s1 = sm ◇ x1
155 d1 = toDense sma ◇ x1 157 d1 = dm ◇ x1
156 158
157 s2 = tr sm ◇ x2 159 s2 = tr sm ◇ x2
158 d2 = tr (toDense sma) ◇ x2 160 d2 = tr dm ◇ x2
159 161
160 sdia = mkDiagR 40 20 (vect [1..10]) 162 sdia = mkDiagR 40 20 (vect [1..10])
161 s3 = sdia ◇ x1 163 s3 = sdia ◇ x1
@@ -164,6 +166,10 @@ instance Testable SMatrix
164 d3 = ddia ◇ x1 166 d3 = ddia ◇ x1
165 d4 = tr ddia ◇ x2 167 d4 = tr ddia ◇ x2
166 168
169 v = testb 40
170 s5 = cgSolve False sm v
171 d5 = denseSolve dm v
172
167 info = do 173 info = do
168 print sm 174 print sm
169 disp (toDense sma) 175 disp (toDense sma)
@@ -171,11 +177,14 @@ instance Testable SMatrix
171 print s2; print d2 177 print s2; print d2
172 print s3; print d3 178 print s3; print d3
173 print s4; print d4 179 print s4; print d4
180 print s5; print d5
181 print $ relativeError Infinity s5 d5
174 182
175 ok = s1==d1 183 ok = s1==d1
176 && s2==d2 184 && s2==d2
177 && s3==d3 185 && s3==d3
178 && s4==d4 186 && s4==d4
187 && relativeError Infinity s5 d5 < 1E-10
179 188
180 disp = putStr . dispf 2 189 disp = putStr . dispf 2
181 190
@@ -189,4 +198,8 @@ instance Testable SMatrix
189 where 198 where
190 m1 = convomat n k 199 m1 = convomat n k
191 m2 = map (((+n) *** id) *** id) m1 200 m2 = map (((+n) *** id) *** id) m1
201
202 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
203
204 denseSolve a = flatten . linearSolveLS a . asColumn
192 205
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
index 657689a..423edaa 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -54,19 +54,9 @@ import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector
54trivial :: Testable a => Bool -> a -> Property 54trivial :: Testable a => Bool -> a -> Property
55trivial = (`classify` "trivial") 55trivial = (`classify` "trivial")
56 56
57
58-- relative error 57-- relative error
59dist :: (Normed c t, Num (c t)) => c t -> c t -> Double 58dist :: (Normed c t, Num (c t)) => c t -> c t -> Double
60dist a b = realToFrac r 59dist = relativeError Infinity
61 where norm = pnorm Infinity
62 na = norm a
63 nb = norm b
64 nab = norm (a-b)
65 mx = max na nb
66 mn = min na nb
67 r = if mn < peps
68 then mx
69 else nab/mx
70 60
71infixl 4 |~| 61infixl 4 |~|
72a |~| b = a :~10~: b 62a |~| b = a :~10~: b