summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Util
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-27 10:41:40 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-27 10:41:40 +0200
commitcf3c788f0c44577ac1a5365e8154200b53a36409 (patch)
treed667ea10609e74b69b11309bb59b7e000b240a92 /packages/base/src/Numeric/LinearAlgebra/Util
parent365e2435e71de10ebe849acac5a107b6f43817c4 (diff)
static dimensions, cont.
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs86
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/Static.hs70
2 files changed, 75 insertions, 81 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index 5e2ea84..50372f1 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -3,11 +3,14 @@
3 3
4module Numeric.LinearAlgebra.Util.CG( 4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve, cgSolve', 5 cgSolve, cgSolve',
6 CGMat, CGState(..), R, V 6 CGState(..), R, V
7) where 7) where
8 8
9import Data.Packed.Numeric 9import Data.Packed.Numeric
10import Numeric.Sparse
10import Numeric.Vector() 11import Numeric.Vector()
12import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..))
13import Control.Arrow((***))
11 14
12{- 15{-
13import Util.Misc(debug, debugMat) 16import Util.Misc(debug, debugMat)
@@ -51,7 +54,7 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
51 rdx = norm2 dx / max 1 (norm2 x) 54 rdx = norm2 dx / max 1 (norm2 x)
52 55
53conjugrad 56conjugrad
54 :: (Transposable m, Contraction m V V) 57 :: (Transposable m mt, Contraction m V V, Contraction mt V V)
55 => Bool -> m -> V -> V -> R -> R -> [CGState] 58 => Bool -> m -> V -> V -> R -> R -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b 59conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57 60
@@ -82,27 +85,88 @@ takeUntil q xs = a++ take 1 b
82 where 85 where
83 (a,b) = break q xs 86 (a,b) = break q xs
84 87
85class (Transposable m, Contraction m V V) => CGMat m
86
87cgSolve 88cgSolve
88 :: CGMat m 89 :: Bool -- ^ is symmetric
89 => Bool -- ^ is symmetric 90 -> GMatrix -- ^ coefficient matrix
90 -> m -- ^ coefficient matrix
91 -> Vector Double -- ^ right-hand side 91 -> Vector Double -- ^ right-hand side
92 -> Vector Double -- ^ solution 92 -> Vector Double -- ^ solution
93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
94 where 94 where
95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) 95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
96 96
97cgSolve' 97cgSolve'
98 :: CGMat m 98 :: Bool -- ^ symmetric
99 => Bool -- ^ symmetric
100 -> R -- ^ relative tolerance for the residual (e.g. 1E-4) 99 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
101 -> R -- ^ relative tolerance for δx (e.g. 1E-3) 100 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
102 -> Int -- ^ maximum number of iterations 101 -> Int -- ^ maximum number of iterations
103 -> m -- ^ coefficient matrix 102 -> GMatrix -- ^ coefficient matrix
104 -> V -- ^ initial solution 103 -> V -- ^ initial solution
105 -> V -- ^ right-hand side 104 -> V -- ^ right-hand side
106 -> [CGState] -- ^ solution 105 -> [CGState] -- ^ solution
107cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es 106cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
108 107
108
109--------------------------------------------------------------------------------
110
111instance Testable GMatrix
112 where
113 checkT _ = (ok,info)
114 where
115 sma = convo2 20 3
116 x1 = vect [1..20]
117 x2 = vect [1..40]
118 sm = mkSparse sma
119 dm = toDense sma
120
121 s1 = sm !#> x1
122 d1 = dm #> x1
123
124 s2 = tr sm !#> x2
125 d2 = tr dm #> x2
126
127 sdia = mkDiagR 40 20 (vect [1..10])
128 s3 = sdia !#> x1
129 s4 = tr sdia !#> x2
130 ddia = diagRect 0 (vect [1..10]) 40 20
131 d3 = ddia #> x1
132 d4 = tr ddia #> x2
133
134 v = testb 40
135 s5 = cgSolve False sm v
136 d5 = denseSolve dm v
137
138 info = do
139 print sm
140 disp (toDense sma)
141 print s1; print d1
142 print s2; print d2
143 print s3; print d3
144 print s4; print d4
145 print s5; print d5
146 print $ relativeError Infinity s5 d5
147
148 ok = s1==d1
149 && s2==d2
150 && s3==d3
151 && s4==d4
152 && relativeError Infinity s5 d5 < 1E-10
153
154 disp = putStr . dispf 2
155
156 vect = fromList :: [Double] -> Vector Double
157
158 convomat :: Int -> Int -> AssocMatrix
159 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
160
161 convo2 :: Int -> Int -> AssocMatrix
162 convo2 n k = m1 ++ m2
163 where
164 m1 = convomat n k
165 m2 = map (((+n) *** id) *** id) m1
166
167 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
168
169 denseSolve a = flatten . linearSolveLS a . asColumn
170
171 -- mkDiag v = mkDiagR (dim v) (dim v) v
172
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs
deleted file mode 100644
index a3f8eb0..0000000
--- a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs
+++ /dev/null
@@ -1,70 +0,0 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FlexibleContexts #-}
6{-# LANGUAGE ScopedTypeVariables #-}
7{-# LANGUAGE EmptyDataDecls #-}
8{-# LANGUAGE Rank2Types #-}
9{-# LANGUAGE FlexibleInstances #-}
10{-# LANGUAGE TypeOperators #-}
11
12module Numeric.LinearAlgebra.Util.Static(
13 Static (ddata),
14 R,
15 vect0, sScalar, vect2, vect3, (&)
16) where
17
18
19import GHC.TypeLits
20import Data.Packed.Numeric
21import Numeric.Vector()
22import Numeric.LinearAlgebra.Util(Numeric,ℝ)
23
24lift1F :: (Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t)
25lift1F f (Static v) = Static (f v)
26
27lift2F :: (Vector t -> Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) -> Static n (Vector t)
28lift2F f (Static u) (Static v) = Static (f u v)
29
30newtype Static (n :: Nat) t = Static { ddata :: t } deriving Show
31
32type R n = Static n (Vector ℝ)
33
34
35infixl 4 &
36(&) :: R n -> ℝ -> R (n+1)
37Static v & x = Static (vjoin [v, scalar x])
38
39vect0 :: R 0
40vect0 = Static (fromList[])
41
42sScalar :: ℝ -> R 1
43sScalar = Static . scalar
44
45
46vect2 :: ℝ -> ℝ -> R 2
47vect2 x1 x2 = Static (fromList [x1,x2])
48
49vect3 :: ℝ -> ℝ -> ℝ -> R 3
50vect3 x1 x2 x3 = Static (fromList [x1,x2,x3])
51
52
53
54
55
56
57instance forall n t . (KnownNat n, Num (Vector t), Numeric t )=> Num (Static n (Vector t))
58 where
59 (+) = lift2F add
60 (*) = lift2F mul
61 (-) = lift2F sub
62 abs = lift1F abs
63 signum = lift1F signum
64 negate = lift1F (scale (-1))
65 fromInteger x = Static (konst (fromInteger x) d)
66 where
67 d = fromIntegral . natVal $ (undefined :: Proxy n)
68
69data Proxy :: Nat -> *
70