summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/Sparse.hs')
-rw-r--r--packages/base/src/Numeric/Sparse.hs127
1 files changed, 42 insertions, 85 deletions
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
index 2df4578..4d05bdc 100644
--- a/packages/base/src/Numeric/Sparse.hs
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -3,11 +3,11 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5module Numeric.Sparse( 5module Numeric.Sparse(
6 SMatrix(..), 6 GMatrix(..),
7 mkCSR, mkDiag, 7 mkSparse, mkDiagR, dense,
8 AssocMatrix, 8 AssocMatrix,
9 toDense, 9 toDense,
10 smXv 10 gmXv, (!#>)
11)where 11)where
12 12
13import Data.Packed.Numeric 13import Data.Packed.Numeric
@@ -17,8 +17,7 @@ import Control.Arrow((***))
17import Control.Monad(when) 17import Control.Monad(when)
18import Data.List(groupBy, sort) 18import Data.List(groupBy, sort)
19import Foreign.C.Types(CInt(..)) 19import Foreign.C.Types(CInt(..))
20import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) 20
21import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..))
22import Data.Packed.Development 21import Data.Packed.Development
23import System.IO.Unsafe(unsafePerformIO) 22import System.IO.Unsafe(unsafePerformIO)
24import Foreign(Ptr) 23import Foreign(Ptr)
@@ -29,7 +28,7 @@ c ~!~ msg = when c (error msg)
29 28
30type AssocMatrix = [((Int,Int),Double)] 29type AssocMatrix = [((Int,Int),Double)]
31 30
32data SMatrix 31data GMatrix
33 = CSR 32 = CSR
34 { csrVals :: Vector Double 33 { csrVals :: Vector Double
35 , csrCols :: Vector CInt 34 , csrCols :: Vector CInt
@@ -46,14 +45,26 @@ data SMatrix
46 } 45 }
47 | Diag 46 | Diag
48 { diagVals :: Vector Double 47 { diagVals :: Vector Double
48 , nRows :: Int
49 , nCols :: Int
50 }
51 | Dense
52 { gmDense :: Matrix Double
49 , nRows :: Int 53 , nRows :: Int
50 , nCols :: Int 54 , nCols :: Int
51 } 55 }
52-- | Banded 56-- | Banded
53 deriving Show 57 deriving Show
54 58
55mkCSR :: AssocMatrix -> SMatrix 59dense :: Matrix Double -> GMatrix
56mkCSR sm' = CSR{..} 60dense m = Dense{..}
61 where
62 gmDense = m
63 nRows = rows m
64 nCols = cols m
65
66mkSparse :: AssocMatrix -> GMatrix
67mkSparse sm' = CSR{..}
57 where 68 where
58 sm = sort sm' 69 sm = sort sm'
59 rws = map ((fromList *** fromList) 70 rws = map ((fromList *** fromList)
@@ -78,37 +89,47 @@ mkDiagR r c v
78 nCols = c 89 nCols = c
79 diagVals = v 90 diagVals = v
80 91
81mkDiag v = mkDiagR (dim v) (dim v) v
82
83 92
84type IV t = CInt -> Ptr CInt -> t 93type IV t = CInt -> Ptr CInt -> t
85type V t = CInt -> Ptr Double -> t 94type V t = CInt -> Ptr Double -> t
86type SMxV = V (IV (IV (V (V (IO CInt))))) 95type SMxV = V (IV (IV (V (V (IO CInt)))))
87 96
88smXv :: SMatrix -> Vector Double -> Vector Double 97gmXv :: GMatrix -> Vector Double -> Vector Double
89smXv CSR{..} v = unsafePerformIO $ do 98gmXv CSR{..} v = unsafePerformIO $ do
90 dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 99 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
91 r <- createVector nRows 100 r <- createVector nRows
92 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" 101 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
93 return r 102 return r
94 103
95smXv CSC{..} v = unsafePerformIO $ do 104gmXv CSC{..} v = unsafePerformIO $ do
96 dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 105 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
97 r <- createVector nRows 106 r <- createVector nRows
98 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" 107 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
99 return r 108 return r
100 109
101smXv Diag{..} v 110gmXv Diag{..} v
102 | dim v == nCols 111 | dim v == nCols
103 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals 112 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
104 , konst 0 (nRows - dim diagVals) ] 113 , konst 0 (nRows - dim diagVals) ]
105 | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" 114 | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
106 nRows nCols (dim diagVals) (dim v) 115 nRows nCols (dim diagVals) (dim v)
107 116
117gmXv Dense{..} v
118 | dim v == nCols
119 = mXv gmDense v
120 | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
121 nRows nCols (dim v)
122
108 123
109instance Contraction SMatrix (Vector Double) (Vector Double) 124-- | general matrix - vector product
125infixr 8 !#>
126(!#>) :: GMatrix -> Vector Double -> Vector Double
127(!#>) = gmXv
128
129
130instance Contraction GMatrix (Vector Double) (Vector Double)
110 where 131 where
111 contraction = smXv 132 contraction = gmXv
112 133
113-------------------------------------------------------------------------------- 134--------------------------------------------------------------------------------
114 135
@@ -127,75 +148,11 @@ toDense asm = assoc (r+1,c+1) 0 asm
127 148
128 149
129 150
130instance Transposable SMatrix 151instance Transposable GMatrix GMatrix
131 where 152 where
132 tr (CSR vs cs rs n m) = CSC vs cs rs m n 153 tr (CSR vs cs rs n m) = CSC vs cs rs m n
133 tr (CSC vs rs cs n m) = CSR vs rs cs m n 154 tr (CSC vs rs cs n m) = CSR vs rs cs m n
134 tr (Diag v n m) = Diag v m n 155 tr (Diag v n m) = Diag v m n
156 tr (Dense a n m) = Dense (tr a) m n
135 157
136 158
137instance CGMat SMatrix
138instance CGMat (Matrix Double)
139
140--------------------------------------------------------------------------------
141
142instance Testable SMatrix
143 where
144 checkT _ = (ok,info)
145 where
146 sma = convo2 20 3
147 x1 = vect [1..20]
148 x2 = vect [1..40]
149 sm = mkCSR sma
150 dm = toDense sma
151
152 s1 = sm ◇ x1
153 d1 = dm ◇ x1
154
155 s2 = tr sm ◇ x2
156 d2 = tr dm ◇ x2
157
158 sdia = mkDiagR 40 20 (vect [1..10])
159 s3 = sdia ◇ x1
160 s4 = tr sdia ◇ x2
161 ddia = diagRect 0 (vect [1..10]) 40 20
162 d3 = ddia ◇ x1
163 d4 = tr ddia ◇ x2
164
165 v = testb 40
166 s5 = cgSolve False sm v
167 d5 = denseSolve dm v
168
169 info = do
170 print sm
171 disp (toDense sma)
172 print s1; print d1
173 print s2; print d2
174 print s3; print d3
175 print s4; print d4
176 print s5; print d5
177 print $ relativeError Infinity s5 d5
178
179 ok = s1==d1
180 && s2==d2
181 && s3==d3
182 && s4==d4
183 && relativeError Infinity s5 d5 < 1E-10
184
185 disp = putStr . dispf 2
186
187 vect = fromList :: [Double] -> Vector Double
188
189 convomat :: Int -> Int -> AssocMatrix
190 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
191
192 convo2 :: Int -> Int -> AssocMatrix
193 convo2 n k = m1 ++ m2
194 where
195 m1 = convomat n k
196 m2 = map (((+n) *** id) *** id) m1
197
198 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
199
200 denseSolve a = flatten . linearSolveLS a . asColumn
201