diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-27 10:41:40 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-27 10:41:40 +0200 |
commit | cf3c788f0c44577ac1a5365e8154200b53a36409 (patch) | |
tree | d667ea10609e74b69b11309bb59b7e000b240a92 /packages/base/src/Numeric/LinearAlgebra/Util | |
parent | 365e2435e71de10ebe849acac5a107b6f43817c4 (diff) |
static dimensions, cont.
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 86 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/Static.hs | 70 |
2 files changed, 75 insertions, 81 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 5e2ea84..50372f1 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -3,11 +3,14 @@ | |||
3 | 3 | ||
4 | module Numeric.LinearAlgebra.Util.CG( | 4 | module Numeric.LinearAlgebra.Util.CG( |
5 | cgSolve, cgSolve', | 5 | cgSolve, cgSolve', |
6 | CGMat, CGState(..), R, V | 6 | CGState(..), R, V |
7 | ) where | 7 | ) where |
8 | 8 | ||
9 | import Data.Packed.Numeric | 9 | import Data.Packed.Numeric |
10 | import Numeric.Sparse | ||
10 | import Numeric.Vector() | 11 | import Numeric.Vector() |
12 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) | ||
13 | import Control.Arrow((***)) | ||
11 | 14 | ||
12 | {- | 15 | {- |
13 | import Util.Misc(debug, debugMat) | 16 | import Util.Misc(debug, debugMat) |
@@ -51,7 +54,7 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
51 | rdx = norm2 dx / max 1 (norm2 x) | 54 | rdx = norm2 dx / max 1 (norm2 x) |
52 | 55 | ||
53 | conjugrad | 56 | conjugrad |
54 | :: (Transposable m, Contraction m V V) | 57 | :: (Transposable m mt, Contraction m V V, Contraction mt V V) |
55 | => Bool -> m -> V -> V -> R -> R -> [CGState] | 58 | => Bool -> m -> V -> V -> R -> R -> [CGState] |
56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | 59 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b |
57 | 60 | ||
@@ -82,27 +85,88 @@ takeUntil q xs = a++ take 1 b | |||
82 | where | 85 | where |
83 | (a,b) = break q xs | 86 | (a,b) = break q xs |
84 | 87 | ||
85 | class (Transposable m, Contraction m V V) => CGMat m | ||
86 | |||
87 | cgSolve | 88 | cgSolve |
88 | :: CGMat m | 89 | :: Bool -- ^ is symmetric |
89 | => Bool -- ^ is symmetric | 90 | -> GMatrix -- ^ coefficient matrix |
90 | -> m -- ^ coefficient matrix | ||
91 | -> Vector Double -- ^ right-hand side | 91 | -> Vector Double -- ^ right-hand side |
92 | -> Vector Double -- ^ solution | 92 | -> Vector Double -- ^ solution |
93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 | 93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 |
94 | where | 94 | where |
95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | 95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) |
96 | 96 | ||
97 | cgSolve' | 97 | cgSolve' |
98 | :: CGMat m | 98 | :: Bool -- ^ symmetric |
99 | => Bool -- ^ symmetric | ||
100 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | 99 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) |
101 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | 100 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) |
102 | -> Int -- ^ maximum number of iterations | 101 | -> Int -- ^ maximum number of iterations |
103 | -> m -- ^ coefficient matrix | 102 | -> GMatrix -- ^ coefficient matrix |
104 | -> V -- ^ initial solution | 103 | -> V -- ^ initial solution |
105 | -> V -- ^ right-hand side | 104 | -> V -- ^ right-hand side |
106 | -> [CGState] -- ^ solution | 105 | -> [CGState] -- ^ solution |
107 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | 106 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es |
108 | 107 | ||
108 | |||
109 | -------------------------------------------------------------------------------- | ||
110 | |||
111 | instance Testable GMatrix | ||
112 | where | ||
113 | checkT _ = (ok,info) | ||
114 | where | ||
115 | sma = convo2 20 3 | ||
116 | x1 = vect [1..20] | ||
117 | x2 = vect [1..40] | ||
118 | sm = mkSparse sma | ||
119 | dm = toDense sma | ||
120 | |||
121 | s1 = sm !#> x1 | ||
122 | d1 = dm #> x1 | ||
123 | |||
124 | s2 = tr sm !#> x2 | ||
125 | d2 = tr dm #> x2 | ||
126 | |||
127 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
128 | s3 = sdia !#> x1 | ||
129 | s4 = tr sdia !#> x2 | ||
130 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
131 | d3 = ddia #> x1 | ||
132 | d4 = tr ddia #> x2 | ||
133 | |||
134 | v = testb 40 | ||
135 | s5 = cgSolve False sm v | ||
136 | d5 = denseSolve dm v | ||
137 | |||
138 | info = do | ||
139 | print sm | ||
140 | disp (toDense sma) | ||
141 | print s1; print d1 | ||
142 | print s2; print d2 | ||
143 | print s3; print d3 | ||
144 | print s4; print d4 | ||
145 | print s5; print d5 | ||
146 | print $ relativeError Infinity s5 d5 | ||
147 | |||
148 | ok = s1==d1 | ||
149 | && s2==d2 | ||
150 | && s3==d3 | ||
151 | && s4==d4 | ||
152 | && relativeError Infinity s5 d5 < 1E-10 | ||
153 | |||
154 | disp = putStr . dispf 2 | ||
155 | |||
156 | vect = fromList :: [Double] -> Vector Double | ||
157 | |||
158 | convomat :: Int -> Int -> AssocMatrix | ||
159 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
160 | |||
161 | convo2 :: Int -> Int -> AssocMatrix | ||
162 | convo2 n k = m1 ++ m2 | ||
163 | where | ||
164 | m1 = convomat n k | ||
165 | m2 = map (((+n) *** id) *** id) m1 | ||
166 | |||
167 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
168 | |||
169 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
170 | |||
171 | -- mkDiag v = mkDiagR (dim v) (dim v) v | ||
172 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs deleted file mode 100644 index a3f8eb0..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs +++ /dev/null | |||
@@ -1,70 +0,0 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FlexibleContexts #-} | ||
6 | {-# LANGUAGE ScopedTypeVariables #-} | ||
7 | {-# LANGUAGE EmptyDataDecls #-} | ||
8 | {-# LANGUAGE Rank2Types #-} | ||
9 | {-# LANGUAGE FlexibleInstances #-} | ||
10 | {-# LANGUAGE TypeOperators #-} | ||
11 | |||
12 | module Numeric.LinearAlgebra.Util.Static( | ||
13 | Static (ddata), | ||
14 | R, | ||
15 | vect0, sScalar, vect2, vect3, (&) | ||
16 | ) where | ||
17 | |||
18 | |||
19 | import GHC.TypeLits | ||
20 | import Data.Packed.Numeric | ||
21 | import Numeric.Vector() | ||
22 | import Numeric.LinearAlgebra.Util(Numeric,ℝ) | ||
23 | |||
24 | lift1F :: (Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) | ||
25 | lift1F f (Static v) = Static (f v) | ||
26 | |||
27 | lift2F :: (Vector t -> Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) -> Static n (Vector t) | ||
28 | lift2F f (Static u) (Static v) = Static (f u v) | ||
29 | |||
30 | newtype Static (n :: Nat) t = Static { ddata :: t } deriving Show | ||
31 | |||
32 | type R n = Static n (Vector ℝ) | ||
33 | |||
34 | |||
35 | infixl 4 & | ||
36 | (&) :: R n -> ℝ -> R (n+1) | ||
37 | Static v & x = Static (vjoin [v, scalar x]) | ||
38 | |||
39 | vect0 :: R 0 | ||
40 | vect0 = Static (fromList[]) | ||
41 | |||
42 | sScalar :: ℝ -> R 1 | ||
43 | sScalar = Static . scalar | ||
44 | |||
45 | |||
46 | vect2 :: ℝ -> ℝ -> R 2 | ||
47 | vect2 x1 x2 = Static (fromList [x1,x2]) | ||
48 | |||
49 | vect3 :: ℝ -> ℝ -> ℝ -> R 3 | ||
50 | vect3 x1 x2 x3 = Static (fromList [x1,x2,x3]) | ||
51 | |||
52 | |||
53 | |||
54 | |||
55 | |||
56 | |||
57 | instance forall n t . (KnownNat n, Num (Vector t), Numeric t )=> Num (Static n (Vector t)) | ||
58 | where | ||
59 | (+) = lift2F add | ||
60 | (*) = lift2F mul | ||
61 | (-) = lift2F sub | ||
62 | abs = lift1F abs | ||
63 | signum = lift1F signum | ||
64 | negate = lift1F (scale (-1)) | ||
65 | fromInteger x = Static (konst (fromInteger x) d) | ||
66 | where | ||
67 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
68 | |||
69 | data Proxy :: Nat -> * | ||
70 | |||