summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-22 20:09:41 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-22 20:09:41 +0200
commit85af0a1d5ba2d1c03f05458f9689195e82f6ae7e (patch)
tree07fce2a4b912b85c321e8b1175b52efddc1c4fcb /packages/base/src/Numeric
parentb5125366953a6ae66ff014b736baf79c0feb47dd (diff)
cgSolve
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs2
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs96
-rw-r--r--packages/base/src/Numeric/Sparse.hs192
3 files changed, 289 insertions, 1 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 49bc1c0..e3cbe31 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -69,5 +69,5 @@ import Data.Packed.Matrix
69import Numeric.Container 69import Numeric.Container
70import Numeric.LinearAlgebra.Util 70import Numeric.LinearAlgebra.Util
71import Data.Complex 71import Data.Complex
72import Data.Packed.Internal.Sparse 72import Numeric.Sparse
73 73
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
new file mode 100644
index 0000000..2c782e8
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -0,0 +1,96 @@
1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
2{-# LANGUAGE RecordWildCards #-}
3
4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve,
6 CGMat
7) where
8
9import Numeric.Container
10import Numeric.Vector()
11
12{-
13import Util.Misc(debug, debugMat)
14
15(//) :: Show a => a -> String -> a
16infix 0 // -- , ///
17a // b = debug b id a
18
19(///) :: DV -> String -> DV
20infix 0 ///
21v /// b = debugMat b 2 asRow v
22-}
23
24
25type DV = Vector Double
26
27data CGState = CGState
28 { cgp :: DV
29 , cgr :: DV
30 , cgr2 :: Double
31 , cgx :: DV
32 , cgdx :: Double
33 }
34
35cg :: Bool -> (DV -> DV) -> (DV -> DV) -> CGState -> CGState
36cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
37 where
38 ap1 = a p
39 ap | sym = ap1
40 | otherwise = at ap1
41 pap | sym = p ◇ ap1
42 | otherwise = norm2 ap1 ** 2
43 alpha = r2 / pap
44 dx = scale alpha p
45 x' = x + dx
46 r' = r - scale alpha ap
47 r'2 = r' ◇ r'
48 beta = r'2 / r2
49 p' = r' + scale beta p
50
51 rdx = norm2 dx / max 1 (norm2 x)
52
53conjugrad
54 :: (Transposable m, Contraction m DV DV)
55 => Bool -> m -> DV -> DV -> Double -> Double -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57
58solveG
59 :: (DV -> DV) -> (DV -> DV)
60 -> ((DV -> DV) -> (DV -> DV) -> CGState -> CGState)
61 -> DV
62 -> DV
63 -> Double -> Double
64 -> [CGState]
65solveG mat ma meth rawb x0' ϵb ϵx
66 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
67 where
68 a = mat . ma
69 b = mat rawb
70 x0 = if x0' == 0 then konst 0 (dim b) else x0'
71 r0 = b - a x0
72 r20 = r0 ◇ r0
73 p0 = r0
74 nb2 = b ◇ b
75 ok CGState {..}
76 = cgr2 <nb2*ϵb**2
77 || cgdx < ϵx
78
79
80takeUntil :: (a -> Bool) -> [a] -> [a]
81takeUntil q xs = a++ take 1 b
82 where
83 (a,b) = break q xs
84
85class (Transposable m, Contraction m (Vector Double) (Vector Double)) => CGMat m
86
87cgSolve
88 :: CGMat m
89 => Bool -- ^ symmetric
90 -> Double -- ^ relative tolerance for the residual (e.g. 1E-4)
91 -> Double -- ^ relative tolerance for δx (e.g. 1E-3)
92 -> m -- ^ coefficient matrix
93 -> Vector Double -- ^ right-hand side
94 -> Vector Double -- ^ solution
95cgSolve sym er es a b = cgx $ last $ conjugrad sym a b 0 er es
96
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
new file mode 100644
index 0000000..3835590
--- /dev/null
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -0,0 +1,192 @@
1{-# LANGUAGE RecordWildCards #-}
2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-}
4
5module Numeric.Sparse(
6 SMatrix(..),
7 mkCSR, mkDiag,
8 AssocMatrix,
9 toDense,
10 smXv
11)where
12
13import Numeric.Container
14import qualified Data.Vector.Storable as V
15import Data.Function(on)
16import Control.Arrow((***))
17import Control.Monad(when)
18import Data.List(groupBy, sort)
19import Foreign.C.Types(CInt(..))
20import Numeric.LinearAlgebra.Util.CG(CGMat)
21import Data.Packed.Development
22import System.IO.Unsafe(unsafePerformIO)
23import Foreign(Ptr)
24import Text.Printf(printf)
25
26infixl 0 ~!~
27c ~!~ msg = when c (error msg)
28
29type AssocMatrix = [((Int,Int),Double)]
30
31data SMatrix
32 = CSR
33 { csrVals :: Vector Double
34 , csrCols :: Vector CInt
35 , csrRows :: Vector CInt
36 , nRows :: Int
37 , nCols :: Int
38 }
39 | CSC
40 { cscVals :: Vector Double
41 , cscRows :: Vector CInt
42 , cscCols :: Vector CInt
43 , nRows :: Int
44 , nCols :: Int
45 }
46 | Diag
47 { diagVals :: Vector Double
48 , nRows :: Int
49 , nCols :: Int
50 }
51-- | Banded
52 deriving Show
53
54mkCSR :: AssocMatrix -> SMatrix
55mkCSR sm' = CSR{..}
56 where
57 sm = sort sm'
58 rws = map ((fromList *** fromList)
59 . unzip
60 . map ((succ.fi.snd) *** id)
61 )
62 . groupBy ((==) `on` (fst.fst))
63 $ sm
64 rszs = map (fi . dim . fst) rws
65 csrRows = fromList (scanl (+) 1 rszs)
66 csrVals = vjoin (map snd rws)
67 csrCols = vjoin (map fst rws)
68 nRows = dim csrRows - 1
69 nCols = fromIntegral (V.maximum csrCols)
70
71
72mkDiagR r c v
73 | dim v <= min r c = Diag{..}
74 | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
75 where
76 nRows = r
77 nCols = c
78 diagVals = v
79
80mkDiag v = mkDiagR (dim v) (dim v) v
81
82
83type IV t = CInt -> Ptr CInt -> t
84type V t = CInt -> Ptr Double -> t
85type SMxV = V (IV (IV (V (V (IO CInt)))))
86
87smXv :: SMatrix -> Vector Double -> Vector Double
88smXv CSR{..} v = unsafePerformIO $ do
89 dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
90 r <- createVector nRows
91 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
92 return r
93
94smXv CSC{..} v = unsafePerformIO $ do
95 dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
96 r <- createVector nRows
97 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
98 return r
99
100smXv Diag{..} v
101 | dim v == nCols
102 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
103 , konst 0 (nRows - dim diagVals) ]
104 | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
105 nRows nCols (dim diagVals) (dim v)
106
107
108instance Contraction SMatrix (Vector Double) (Vector Double)
109 where
110 contraction = smXv
111
112--------------------------------------------------------------------------------
113
114foreign import ccall unsafe "smXv"
115 c_smXv :: SMxV
116
117foreign import ccall unsafe "smTXv"
118 c_smTXv :: SMxV
119
120--------------------------------------------------------------------------------
121
122toDense :: AssocMatrix -> Matrix Double
123toDense asm = assoc (r+1,c+1) 0 asm
124 where
125 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
126
127
128
129instance Transposable SMatrix
130 where
131 tr (CSR vs cs rs n m) = CSC vs cs rs m n
132 tr (CSC vs rs cs n m) = CSR vs rs cs m n
133 tr (Diag v n m) = Diag v m n
134
135instance Transposable (Matrix Double)
136 where
137 tr = trans
138
139
140instance CGMat SMatrix
141instance CGMat (Matrix Double)
142
143--------------------------------------------------------------------------------
144
145instance Testable SMatrix
146 where
147 checkT _ = (ok,info)
148 where
149 sma = convo2 20 3
150 x1 = vect [1..20]
151 x2 = vect [1..40]
152 sm = mkCSR sma
153
154 s1 = sm ◇ x1
155 d1 = toDense sma ◇ x1
156
157 s2 = tr sm ◇ x2
158 d2 = tr (toDense sma) ◇ x2
159
160 sdia = mkDiagR 40 20 (vect [1..10])
161 s3 = sdia ◇ x1
162 s4 = tr sdia ◇ x2
163 ddia = diagRect 0 (vect [1..10]) 40 20
164 d3 = ddia ◇ x1
165 d4 = tr ddia ◇ x2
166
167 info = do
168 print sm
169 disp (toDense sma)
170 print s1; print d1
171 print s2; print d2
172 print s3; print d3
173 print s4; print d4
174
175 ok = s1==d1
176 && s2==d2
177 && s3==d3
178 && s4==d4
179
180 disp = putStr . dispf 2
181
182 vect = fromList :: [Double] -> Vector Double
183
184 convomat :: Int -> Int -> AssocMatrix
185 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
186
187 convo2 :: Int -> Int -> AssocMatrix
188 convo2 n k = m1 ++ m2
189 where
190 m1 = convomat n k
191 m2 = map (((+n) *** id) *** id) m1
192