diff options
Diffstat (limited to 'packages/base/src/Numeric/Sparse.hs')
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs | 127 |
1 files changed, 42 insertions, 85 deletions
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 2df4578..4d05bdc 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs | |||
@@ -3,11 +3,11 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | module Numeric.Sparse( | 5 | module Numeric.Sparse( |
6 | SMatrix(..), | 6 | GMatrix(..), |
7 | mkCSR, mkDiag, | 7 | mkSparse, mkDiagR, dense, |
8 | AssocMatrix, | 8 | AssocMatrix, |
9 | toDense, | 9 | toDense, |
10 | smXv | 10 | gmXv, (!#>) |
11 | )where | 11 | )where |
12 | 12 | ||
13 | import Data.Packed.Numeric | 13 | import Data.Packed.Numeric |
@@ -17,8 +17,7 @@ import Control.Arrow((***)) | |||
17 | import Control.Monad(when) | 17 | import Control.Monad(when) |
18 | import Data.List(groupBy, sort) | 18 | import Data.List(groupBy, sort) |
19 | import Foreign.C.Types(CInt(..)) | 19 | import Foreign.C.Types(CInt(..)) |
20 | import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) | 20 | |
21 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) | ||
22 | import Data.Packed.Development | 21 | import Data.Packed.Development |
23 | import System.IO.Unsafe(unsafePerformIO) | 22 | import System.IO.Unsafe(unsafePerformIO) |
24 | import Foreign(Ptr) | 23 | import Foreign(Ptr) |
@@ -29,7 +28,7 @@ c ~!~ msg = when c (error msg) | |||
29 | 28 | ||
30 | type AssocMatrix = [((Int,Int),Double)] | 29 | type AssocMatrix = [((Int,Int),Double)] |
31 | 30 | ||
32 | data SMatrix | 31 | data GMatrix |
33 | = CSR | 32 | = CSR |
34 | { csrVals :: Vector Double | 33 | { csrVals :: Vector Double |
35 | , csrCols :: Vector CInt | 34 | , csrCols :: Vector CInt |
@@ -46,14 +45,26 @@ data SMatrix | |||
46 | } | 45 | } |
47 | | Diag | 46 | | Diag |
48 | { diagVals :: Vector Double | 47 | { diagVals :: Vector Double |
48 | , nRows :: Int | ||
49 | , nCols :: Int | ||
50 | } | ||
51 | | Dense | ||
52 | { gmDense :: Matrix Double | ||
49 | , nRows :: Int | 53 | , nRows :: Int |
50 | , nCols :: Int | 54 | , nCols :: Int |
51 | } | 55 | } |
52 | -- | Banded | 56 | -- | Banded |
53 | deriving Show | 57 | deriving Show |
54 | 58 | ||
55 | mkCSR :: AssocMatrix -> SMatrix | 59 | dense :: Matrix Double -> GMatrix |
56 | mkCSR sm' = CSR{..} | 60 | dense m = Dense{..} |
61 | where | ||
62 | gmDense = m | ||
63 | nRows = rows m | ||
64 | nCols = cols m | ||
65 | |||
66 | mkSparse :: AssocMatrix -> GMatrix | ||
67 | mkSparse sm' = CSR{..} | ||
57 | where | 68 | where |
58 | sm = sort sm' | 69 | sm = sort sm' |
59 | rws = map ((fromList *** fromList) | 70 | rws = map ((fromList *** fromList) |
@@ -78,37 +89,47 @@ mkDiagR r c v | |||
78 | nCols = c | 89 | nCols = c |
79 | diagVals = v | 90 | diagVals = v |
80 | 91 | ||
81 | mkDiag v = mkDiagR (dim v) (dim v) v | ||
82 | |||
83 | 92 | ||
84 | type IV t = CInt -> Ptr CInt -> t | 93 | type IV t = CInt -> Ptr CInt -> t |
85 | type V t = CInt -> Ptr Double -> t | 94 | type V t = CInt -> Ptr Double -> t |
86 | type SMxV = V (IV (IV (V (V (IO CInt))))) | 95 | type SMxV = V (IV (IV (V (V (IO CInt))))) |
87 | 96 | ||
88 | smXv :: SMatrix -> Vector Double -> Vector Double | 97 | gmXv :: GMatrix -> Vector Double -> Vector Double |
89 | smXv CSR{..} v = unsafePerformIO $ do | 98 | gmXv CSR{..} v = unsafePerformIO $ do |
90 | dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 99 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
91 | r <- createVector nRows | 100 | r <- createVector nRows |
92 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | 101 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" |
93 | return r | 102 | return r |
94 | 103 | ||
95 | smXv CSC{..} v = unsafePerformIO $ do | 104 | gmXv CSC{..} v = unsafePerformIO $ do |
96 | dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 105 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
97 | r <- createVector nRows | 106 | r <- createVector nRows |
98 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | 107 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" |
99 | return r | 108 | return r |
100 | 109 | ||
101 | smXv Diag{..} v | 110 | gmXv Diag{..} v |
102 | | dim v == nCols | 111 | | dim v == nCols |
103 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals | 112 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals |
104 | , konst 0 (nRows - dim diagVals) ] | 113 | , konst 0 (nRows - dim diagVals) ] |
105 | | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" | 114 | | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" |
106 | nRows nCols (dim diagVals) (dim v) | 115 | nRows nCols (dim diagVals) (dim v) |
107 | 116 | ||
117 | gmXv Dense{..} v | ||
118 | | dim v == nCols | ||
119 | = mXv gmDense v | ||
120 | | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d" | ||
121 | nRows nCols (dim v) | ||
122 | |||
108 | 123 | ||
109 | instance Contraction SMatrix (Vector Double) (Vector Double) | 124 | -- | general matrix - vector product |
125 | infixr 8 !#> | ||
126 | (!#>) :: GMatrix -> Vector Double -> Vector Double | ||
127 | (!#>) = gmXv | ||
128 | |||
129 | |||
130 | instance Contraction GMatrix (Vector Double) (Vector Double) | ||
110 | where | 131 | where |
111 | contraction = smXv | 132 | contraction = gmXv |
112 | 133 | ||
113 | -------------------------------------------------------------------------------- | 134 | -------------------------------------------------------------------------------- |
114 | 135 | ||
@@ -127,75 +148,11 @@ toDense asm = assoc (r+1,c+1) 0 asm | |||
127 | 148 | ||
128 | 149 | ||
129 | 150 | ||
130 | instance Transposable SMatrix | 151 | instance Transposable GMatrix GMatrix |
131 | where | 152 | where |
132 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | 153 | tr (CSR vs cs rs n m) = CSC vs cs rs m n |
133 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | 154 | tr (CSC vs rs cs n m) = CSR vs rs cs m n |
134 | tr (Diag v n m) = Diag v m n | 155 | tr (Diag v n m) = Diag v m n |
156 | tr (Dense a n m) = Dense (tr a) m n | ||
135 | 157 | ||
136 | 158 | ||
137 | instance CGMat SMatrix | ||
138 | instance CGMat (Matrix Double) | ||
139 | |||
140 | -------------------------------------------------------------------------------- | ||
141 | |||
142 | instance Testable SMatrix | ||
143 | where | ||
144 | checkT _ = (ok,info) | ||
145 | where | ||
146 | sma = convo2 20 3 | ||
147 | x1 = vect [1..20] | ||
148 | x2 = vect [1..40] | ||
149 | sm = mkCSR sma | ||
150 | dm = toDense sma | ||
151 | |||
152 | s1 = sm ◇ x1 | ||
153 | d1 = dm ◇ x1 | ||
154 | |||
155 | s2 = tr sm ◇ x2 | ||
156 | d2 = tr dm ◇ x2 | ||
157 | |||
158 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
159 | s3 = sdia ◇ x1 | ||
160 | s4 = tr sdia ◇ x2 | ||
161 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
162 | d3 = ddia ◇ x1 | ||
163 | d4 = tr ddia ◇ x2 | ||
164 | |||
165 | v = testb 40 | ||
166 | s5 = cgSolve False sm v | ||
167 | d5 = denseSolve dm v | ||
168 | |||
169 | info = do | ||
170 | print sm | ||
171 | disp (toDense sma) | ||
172 | print s1; print d1 | ||
173 | print s2; print d2 | ||
174 | print s3; print d3 | ||
175 | print s4; print d4 | ||
176 | print s5; print d5 | ||
177 | print $ relativeError Infinity s5 d5 | ||
178 | |||
179 | ok = s1==d1 | ||
180 | && s2==d2 | ||
181 | && s3==d3 | ||
182 | && s4==d4 | ||
183 | && relativeError Infinity s5 d5 < 1E-10 | ||
184 | |||
185 | disp = putStr . dispf 2 | ||
186 | |||
187 | vect = fromList :: [Double] -> Vector Double | ||
188 | |||
189 | convomat :: Int -> Int -> AssocMatrix | ||
190 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
191 | |||
192 | convo2 :: Int -> Int -> AssocMatrix | ||
193 | convo2 n k = m1 ++ m2 | ||
194 | where | ||
195 | m1 = convomat n k | ||
196 | m2 = map (((+n) *** id) *** id) m1 | ||
197 | |||
198 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
199 | |||
200 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
201 | |||