summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-22 20:09:41 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-22 20:09:41 +0200
commit85af0a1d5ba2d1c03f05458f9689195e82f6ae7e (patch)
tree07fce2a4b912b85c321e8b1175b52efddc1c4fcb /packages/base/src
parentb5125366953a6ae66ff014b736baf79c0feb47dd (diff)
cgSolve
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs2
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs96
-rw-r--r--packages/base/src/Numeric/Sparse.hs (renamed from packages/base/src/Data/Packed/Internal/Sparse.hs)11
3 files changed, 103 insertions, 6 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 49bc1c0..e3cbe31 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -69,5 +69,5 @@ import Data.Packed.Matrix
69import Numeric.Container 69import Numeric.Container
70import Numeric.LinearAlgebra.Util 70import Numeric.LinearAlgebra.Util
71import Data.Complex 71import Data.Complex
72import Data.Packed.Internal.Sparse 72import Numeric.Sparse
73 73
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
new file mode 100644
index 0000000..2c782e8
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -0,0 +1,96 @@
1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
2{-# LANGUAGE RecordWildCards #-}
3
4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve,
6 CGMat
7) where
8
9import Numeric.Container
10import Numeric.Vector()
11
12{-
13import Util.Misc(debug, debugMat)
14
15(//) :: Show a => a -> String -> a
16infix 0 // -- , ///
17a // b = debug b id a
18
19(///) :: DV -> String -> DV
20infix 0 ///
21v /// b = debugMat b 2 asRow v
22-}
23
24
25type DV = Vector Double
26
27data CGState = CGState
28 { cgp :: DV
29 , cgr :: DV
30 , cgr2 :: Double
31 , cgx :: DV
32 , cgdx :: Double
33 }
34
35cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState
36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
37 where
38 ap1 = a p
39 ap | sym = ap1
40 | otherwise = at ap1
41 pap | sym = p ◇ ap1
42 | otherwise = norm2 ap1 ** 2
43 alpha = r2 / pap
44 dx = scale alpha p
45 x' = x + dx
46 r' = r - scale alpha ap
47 r'2 = r' ◇ r'
48 beta = r'2 / r2
49 p' = r' + scale beta p
50
51 rdx = norm2 dx / max 1 (norm2 x)
52
53conjugrad
54 :: (Transposable m, Contraction m DV DV)
55 => Bool -> m -> DV -> DV -> Double -> Double -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57
58solveG
59 :: (DV -> DV) -> (DV -> DV)
60 -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState)
61 -> DV
62 -> DV
63 -> Double -> Double
64 -> [CGState]
65solveG mat ma meth rawb x0' ϵb ϵx
66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
67 where
68 a = mat . ma
69 b = mat rawb
70 x0 = if x0' == 0 then konst 0 (dim b) else x0'
71 r0 = b - a x0
72 r20 = r0 ◇ r0
73 p0 = r0
74 nb2 = b ◇ b
75 ok CGState {..}
76 = cgr2 <nb2*ϵb**2
77 || cgdx < ϵx
78
79
80takeUntil :: (a -> Bool) -> [a] -> [a]
81takeUntil q xs = a++ take 1 b
82 where
83 (a,b) = break q xs
84
85class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m
86
87cgSolve
88 :: CGMat m
89 => Bool -- ^ symmetric
90 -> Double -- ^ relative tolerance for the residual (e.g. 1E-4)
91 -> Double -- ^ relative tolerance for δx (e.g. 1E-3)
92 -> m -- ^ coefficient matrix
93 -> Vector Double -- ^ right-hand side
94 -> Vector Double -- ^ solution
95cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es
96
diff --git a/packages/base/src/Data/Packed/Internal/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
index 544c913..3835590 100644
--- a/packages/base/src/Data/Packed/Internal/Sparse.hs
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -2,9 +2,7 @@
2{-# LANGUAGE MultiParamTypeClasses #-} 2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5 5module Numeric.Sparse(
6
7module Data.Packed.Internal.Sparse(
8 SMatrix(..), 6 SMatrix(..),
9 mkCSR, mkDiag, 7 mkCSR, mkDiag,
10 AssocMatrix, 8 AssocMatrix,
@@ -19,7 +17,8 @@ import Control.Arrow((***))
19import Control.Monad(when) 17import Control.Monad(when)
20import Data.List(groupBy, sort) 18import Data.List(groupBy, sort)
21import Foreign.C.Types(CInt(..)) 19import Foreign.C.Types(CInt(..))
22import Numeric.LinearAlgebra.Devel 20import Numeric.LinearAlgebra.Util.CG(CGMat)
21import Data.Packed.Development
23import System.IO.Unsafe(unsafePerformIO) 22import System.IO.Unsafe(unsafePerformIO)
24import Foreign(Ptr) 23import Foreign(Ptr)
25import Text.Printf(printf) 24import Text.Printf(printf)
@@ -127,7 +126,7 @@ toDense asm = assoc (r+1,c+1) 0 asm
127 126
128 127
129 128
130instance Transposable (SMatrix) 129instance Transposable SMatrix
131 where 130 where
132 tr (CSR vs cs rs n m) = CSC vs cs rs m n 131 tr (CSR vs cs rs n m) = CSC vs cs rs m n
133 tr (CSC vs rs cs n m) = CSR vs rs cs m n 132 tr (CSC vs rs cs n m) = CSR vs rs cs m n
@@ -138,6 +137,8 @@ instance Transposable (Matrix Double)
138 tr = trans 137 tr = trans
139 138
140 139
140instance CGMat SMatrix
141instance CGMat (Matrix Double)
141 142
142-------------------------------------------------------------------------------- 143--------------------------------------------------------------------------------
143 144