diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/hmatrix.cabal | 4 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Data.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 96 | ||||
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs (renamed from packages/base/src/Data/Packed/Internal/Sparse.hs) | 11 |
4 files changed, 106 insertions, 7 deletions
diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal index 06c1a3c..8f6b04b 100644 --- a/packages/base/hmatrix.cabal +++ b/packages/base/hmatrix.cabal | |||
@@ -48,7 +48,7 @@ library | |||
48 | Numeric.LinearAlgebra.Data | 48 | Numeric.LinearAlgebra.Data |
49 | 49 | ||
50 | Numeric.LinearAlgebra.Compat | 50 | Numeric.LinearAlgebra.Compat |
51 | Data.Packed.Internal.Sparse | 51 | |
52 | 52 | ||
53 | other-modules: Data.Packed.Internal, | 53 | other-modules: Data.Packed.Internal, |
54 | Data.Packed.Internal.Common, | 54 | Data.Packed.Internal.Common, |
@@ -62,8 +62,10 @@ library | |||
62 | Numeric.Matrix | 62 | Numeric.Matrix |
63 | Data.Packed.Internal.Numeric | 63 | Data.Packed.Internal.Numeric |
64 | Numeric.LinearAlgebra.Util.Convolution | 64 | Numeric.LinearAlgebra.Util.Convolution |
65 | Numeric.LinearAlgebra.Util.CG | ||
65 | Numeric.LinearAlgebra.Random | 66 | Numeric.LinearAlgebra.Random |
66 | Numeric.Conversion | 67 | Numeric.Conversion |
68 | Numeric.Sparse | ||
67 | 69 | ||
68 | C-sources: src/C/lapack-aux.c | 70 | C-sources: src/C/lapack-aux.c |
69 | src/C/vector-aux.c | 71 | src/C/vector-aux.c |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 49bc1c0..e3cbe31 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs | |||
@@ -69,5 +69,5 @@ import Data.Packed.Matrix | |||
69 | import Numeric.Container | 69 | import Numeric.Container |
70 | import Numeric.LinearAlgebra.Util | 70 | import Numeric.LinearAlgebra.Util |
71 | import Data.Complex | 71 | import Data.Complex |
72 | import Data.Packed.Internal.Sparse | 72 | import Numeric.Sparse |
73 | 73 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs new file mode 100644 index 0000000..2c782e8 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -0,0 +1,96 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | ||
2 | {-# LANGUAGE RecordWildCards #-} | ||
3 | |||
4 | module Numeric.LinearAlgebra.Util.CG( | ||
5 | cgSolve, | ||
6 | CGMat | ||
7 | ) where | ||
8 | |||
9 | import Numeric.Container | ||
10 | import Numeric.Vector() | ||
11 | |||
12 | {- | ||
13 | import Util.Misc(debug, debugMat) | ||
14 | |||
15 | (//) :: Show a => a -> String -> a | ||
16 | infix 0 // -- , /// | ||
17 | a // b = debug b id a | ||
18 | |||
19 | (///) :: DV -> String -> DV | ||
20 | infix 0 /// | ||
21 | v /// b = debugMat b 2 asRow v | ||
22 | -} | ||
23 | |||
24 | |||
25 | type DV = Vector Double | ||
26 | |||
27 | data CGState = CGState | ||
28 | { cgp :: DV | ||
29 | , cgr :: DV | ||
30 | , cgr2 :: Double | ||
31 | , cgx :: DV | ||
32 | , cgdx :: Double | ||
33 | } | ||
34 | |||
35 | cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState | ||
36 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | ||
37 | where | ||
38 | ap1 = a p | ||
39 | ap | sym = ap1 | ||
40 | | otherwise = at ap1 | ||
41 | pap | sym = p ◇ ap1 | ||
42 | | otherwise = norm2 ap1 ** 2 | ||
43 | alpha = r2 / pap | ||
44 | dx = scale alpha p | ||
45 | x' = x + dx | ||
46 | r' = r - scale alpha ap | ||
47 | r'2 = r' ◇ r' | ||
48 | beta = r'2 / r2 | ||
49 | p' = r' + scale beta p | ||
50 | |||
51 | rdx = norm2 dx / max 1 (norm2 x) | ||
52 | |||
53 | conjugrad | ||
54 | :: (Transposable m, Contraction m DV DV) | ||
55 | => Bool -> m -> DV -> DV -> Double -> Double -> [CGState] | ||
56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | ||
57 | |||
58 | solveG | ||
59 | :: (DV -> DV) -> (DV -> DV) | ||
60 | -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState) | ||
61 | -> DV | ||
62 | -> DV | ||
63 | -> Double -> Double | ||
64 | -> [CGState] | ||
65 | solveG mat ma meth rawb x0' ϵb ϵx | ||
66 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 | ||
67 | where | ||
68 | a = mat . ma | ||
69 | b = mat rawb | ||
70 | x0 = if x0' == 0 then konst 0 (dim b) else x0' | ||
71 | r0 = b - a x0 | ||
72 | r20 = r0 ◇ r0 | ||
73 | p0 = r0 | ||
74 | nb2 = b ◇ b | ||
75 | ok CGState {..} | ||
76 | = cgr2 <nb2*ϵb**2 | ||
77 | || cgdx < ϵx | ||
78 | |||
79 | |||
80 | takeUntil :: (a -> Bool) -> [a] -> [a] | ||
81 | takeUntil q xs = a++ take 1 b | ||
82 | where | ||
83 | (a,b) = break q xs | ||
84 | |||
85 | class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m | ||
86 | |||
87 | cgSolve | ||
88 | :: CGMat m | ||
89 | => Bool -- ^ symmetric | ||
90 | -> Double -- ^ relative tolerance for the residual (e.g. 1E-4) | ||
91 | -> Double -- ^ relative tolerance for δx (e.g. 1E-3) | ||
92 | -> m -- ^ coefficient matrix | ||
93 | -> Vector Double -- ^ right-hand side | ||
94 | -> Vector Double -- ^ solution | ||
95 | cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es | ||
96 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 544c913..3835590 100644 --- a/packages/base/src/Data/Packed/Internal/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs | |||
@@ -2,9 +2,7 @@ | |||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | 2 | {-# LANGUAGE MultiParamTypeClasses #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | 5 | module Numeric.Sparse( | |
6 | |||
7 | module Data.Packed.Internal.Sparse( | ||
8 | SMatrix(..), | 6 | SMatrix(..), |
9 | mkCSR, mkDiag, | 7 | mkCSR, mkDiag, |
10 | AssocMatrix, | 8 | AssocMatrix, |
@@ -19,7 +17,8 @@ import Control.Arrow((***)) | |||
19 | import Control.Monad(when) | 17 | import Control.Monad(when) |
20 | import Data.List(groupBy, sort) | 18 | import Data.List(groupBy, sort) |
21 | import Foreign.C.Types(CInt(..)) | 19 | import Foreign.C.Types(CInt(..)) |
22 | import Numeric.LinearAlgebra.Devel | 20 | import Numeric.LinearAlgebra.Util.CG(CGMat) |
21 | import Data.Packed.Development | ||
23 | import System.IO.Unsafe(unsafePerformIO) | 22 | import System.IO.Unsafe(unsafePerformIO) |
24 | import Foreign(Ptr) | 23 | import Foreign(Ptr) |
25 | import Text.Printf(printf) | 24 | import Text.Printf(printf) |
@@ -127,7 +126,7 @@ toDense asm = assoc (r+1,c+1) 0 asm | |||
127 | 126 | ||
128 | 127 | ||
129 | 128 | ||
130 | instance Transposable (SMatrix) | 129 | instance Transposable SMatrix |
131 | where | 130 | where |
132 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | 131 | tr (CSR vs cs rs n m) = CSC vs cs rs m n |
133 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | 132 | tr (CSC vs rs cs n m) = CSR vs rs cs m n |
@@ -138,6 +137,8 @@ instance Transposable (Matrix Double) | |||
138 | tr = trans | 137 | tr = trans |
139 | 138 | ||
140 | 139 | ||
140 | instance CGMat SMatrix | ||
141 | instance CGMat (Matrix Double) | ||
141 | 142 | ||
142 | -------------------------------------------------------------------------------- | 143 | -------------------------------------------------------------------------------- |
143 | 144 | ||