summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/Sparse.hs')
-rw-r--r--packages/base/src/Numeric/Sparse.hs192
1 files changed, 192 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
new file mode 100644
index 0000000..3835590
--- /dev/null
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -0,0 +1,192 @@
1{-# LANGUAGE RecordWildCards #-}
2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-}
4
5module Numeric.Sparse(
6 SMatrix(..),
7 mkCSR, mkDiag,
8 AssocMatrix,
9 toDense,
10 smXv
11)where
12
13import Numeric.Container
14import qualified Data.Vector.Storable as V
15import Data.Function(on)
16import Control.Arrow((***))
17import Control.Monad(when)
18import Data.List(groupBy, sort)
19import Foreign.C.Types(CInt(..))
20import Numeric.LinearAlgebra.Util.CG(CGMat)
21import Data.Packed.Development
22import System.IO.Unsafe(unsafePerformIO)
23import Foreign(Ptr)
24import Text.Printf(printf)
25
26infixl 0 ~!~
27c ~!~ msg = when c (error msg)
28
29type AssocMatrix = [((Int,Int),Double)]
30
31data SMatrix
32 = CSR
33 { csrVals :: Vector Double
34 , csrCols :: Vector CInt
35 , csrRows :: Vector CInt
36 , nRows :: Int
37 , nCols :: Int
38 }
39 | CSC
40 { cscVals :: Vector Double
41 , cscRows :: Vector CInt
42 , cscCols :: Vector CInt
43 , nRows :: Int
44 , nCols :: Int
45 }
46 | Diag
47 { diagVals :: Vector Double
48 , nRows :: Int
49 , nCols :: Int
50 }
51-- | Banded
52 deriving Show
53
54mkCSR :: AssocMatrix -> SMatrix
55mkCSR sm' = CSR{..}
56 where
57 sm = sort sm'
58 rws = map ((fromList *** fromList)
59 . unzip
60 . map ((succ.fi.snd) *** id)
61 )
62 . groupBy ((==) `on` (fst.fst))
63 $ sm
64 rszs = map (fi . dim . fst) rws
65 csrRows = fromList (scanl (+) 1 rszs)
66 csrVals = vjoin (map snd rws)
67 csrCols = vjoin (map fst rws)
68 nRows = dim csrRows - 1
69 nCols = fromIntegral (V.maximum csrCols)
70
71
72mkDiagR r c v
73 | dim v <= min r c = Diag{..}
74 | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
75 where
76 nRows = r
77 nCols = c
78 diagVals = v
79
80mkDiag v = mkDiagR (dim v) (dim v) v
81
82
83type IV t = CInt -> Ptr CInt -> t
84type V t = CInt -> Ptr Double -> t
85type SMxV = V (IV (IV (V (V (IO CInt)))))
86
87smXv :: SMatrix -> Vector Double -> Vector Double
88smXv CSR{..} v = unsafePerformIO $ do
89 dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
90 r <- createVector nRows
91 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
92 return r
93
94smXv CSC{..} v = unsafePerformIO $ do
95 dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
96 r <- createVector nRows
97 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
98 return r
99
100smXv Diag{..} v
101 | dim v == nCols
102 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
103 , konst 0 (nRows - dim diagVals) ]
104 | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
105 nRows nCols (dim diagVals) (dim v)
106
107
108instance Contraction SMatrix (Vector Double) (Vector Double)
109 where
110 contraction = smXv
111
112--------------------------------------------------------------------------------
113
114foreign import ccall unsafe "smXv"
115 c_smXv :: SMxV
116
117foreign import ccall unsafe "smTXv"
118 c_smTXv :: SMxV
119
120--------------------------------------------------------------------------------
121
122toDense :: AssocMatrix -> Matrix Double
123toDense asm = assoc (r+1,c+1) 0 asm
124 where
125 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
126
127
128
129instance Transposable SMatrix
130 where
131 tr (CSR vs cs rs n m) = CSC vs cs rs m n
132 tr (CSC vs rs cs n m) = CSR vs rs cs m n
133 tr (Diag v n m) = Diag v m n
134
135instance Transposable (Matrix Double)
136 where
137 tr = trans
138
139
140instance CGMat SMatrix
141instance CGMat (Matrix Double)
142
143--------------------------------------------------------------------------------
144
145instance Testable SMatrix
146 where
147 checkT _ = (ok,info)
148 where
149 sma = convo2 20 3
150 x1 = vect [1..20]
151 x2 = vect [1..40]
152 sm = mkCSR sma
153
154 s1 = sm ◇ x1
155 d1 = toDense sma ◇ x1
156
157 s2 = tr sm ◇ x2
158 d2 = tr (toDense sma) ◇ x2
159
160 sdia = mkDiagR 40 20 (vect [1..10])
161 s3 = sdia ◇ x1
162 s4 = tr sdia ◇ x2
163 ddia = diagRect 0 (vect [1..10]) 40 20
164 d3 = ddia ◇ x1
165 d4 = tr ddia ◇ x2
166
167 info = do
168 print sm
169 disp (toDense sma)
170 print s1; print d1
171 print s2; print d2
172 print s3; print d3
173 print s4; print d4
174
175 ok = s1==d1
176 && s2==d2
177 && s3==d3
178 && s4==d4
179
180 disp = putStr . dispf 2
181
182 vect = fromList :: [Double] -> Vector Double
183
184 convomat :: Int -> Int -> AssocMatrix
185 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
186
187 convo2 :: Int -> Int -> AssocMatrix
188 convo2 n k = m1 ++ m2
189 where
190 m1 = convomat n k
191 m2 = map (((+n) *** id) *** id) m1
192