summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Sparse.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-05 16:39:32 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-05 16:39:32 +0200
commitf20a94375c03bd6154f67fec1345e530acfc881d (patch)
treeadf818f40e8fd9b06f9445aaf086c98580a1dd4b /packages/base/src/Internal/Sparse.hs
parent8b093ecca2b4e200ff191b84cb0b56a12312867b (diff)
move sparse
Diffstat (limited to 'packages/base/src/Internal/Sparse.hs')
-rw-r--r--packages/base/src/Internal/Sparse.hs217
1 files changed, 217 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs
new file mode 100644
index 0000000..930bc99
--- /dev/null
+++ b/packages/base/src/Internal/Sparse.hs
@@ -0,0 +1,217 @@
1{-# LANGUAGE RecordWildCards #-}
2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-}
4
5module Internal.Sparse(
6 GMatrix(..), CSR(..), mkCSR, fromCSR,
7 mkSparse, mkDiagR, mkDense,
8 AssocMatrix,
9 toDense,
10 gmXv, (!#>)
11)where
12
13import Internal.Vector
14import Internal.Matrix
15import Internal.Numeric
16import Internal.Container
17import Internal.Tools
18import qualified Data.Vector.Storable as V
19import Data.Vector.Storable(fromList)
20import Data.Function(on)
21import Control.Arrow((***))
22import Control.Monad(when)
23import Data.List(groupBy, sort)
24import Foreign.C.Types(CInt(..))
25
26import Internal.Devel
27import System.IO.Unsafe(unsafePerformIO)
28import Foreign(Ptr)
29import Text.Printf(printf)
30
31infixl 0 ~!~
32c ~!~ msg = when c (error msg)
33
34type AssocMatrix = [((Int,Int),Double)]
35
36data CSR = CSR
37 { csrVals :: Vector Double
38 , csrCols :: Vector CInt
39 , csrRows :: Vector CInt
40 , csrNRows :: Int
41 , csrNCols :: Int
42 } deriving Show
43
44data CSC = CSC
45 { cscVals :: Vector Double
46 , cscRows :: Vector CInt
47 , cscCols :: Vector CInt
48 , cscNRows :: Int
49 , cscNCols :: Int
50 } deriving Show
51
52
53mkCSR :: AssocMatrix -> CSR
54mkCSR sm' = CSR{..}
55 where
56 sm = sort sm'
57 rws = map ((fromList *** fromList)
58 . unzip
59 . map ((succ.fi.snd) *** id)
60 )
61 . groupBy ((==) `on` (fst.fst))
62 $ sm
63 rszs = map (fi . dim . fst) rws
64 csrRows = fromList (scanl (+) 1 rszs)
65 csrVals = vjoin (map snd rws)
66 csrCols = vjoin (map fst rws)
67 csrNRows = dim csrRows - 1
68 csrNCols = fromIntegral (V.maximum csrCols)
69
70{- | General matrix with specialized internal representations for
71 dense, sparse, diagonal, banded, and constant elements.
72
73>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
74>>> m
75SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0],
76 csrCols = fromList [1000,2000],
77 csrRows = fromList [1,2,3],
78 csrNRows = 2,
79 csrNCols = 2000},
80 nRows = 2,
81 nCols = 2000}
82
83>>> let m = mkDense (mat 2 [1..4])
84>>> m
85Dense {gmDense = (2><2)
86 [ 1.0, 2.0
87 , 3.0, 4.0 ], nRows = 2, nCols = 2}
88
89-}
90data GMatrix
91 = SparseR
92 { gmCSR :: CSR
93 , nRows :: Int
94 , nCols :: Int
95 }
96 | SparseC
97 { gmCSC :: CSC
98 , nRows :: Int
99 , nCols :: Int
100 }
101 | Diag
102 { diagVals :: Vector Double
103 , nRows :: Int
104 , nCols :: Int
105 }
106 | Dense
107 { gmDense :: Matrix Double
108 , nRows :: Int
109 , nCols :: Int
110 }
111-- | Banded
112 deriving Show
113
114
115mkDense :: Matrix Double -> GMatrix
116mkDense m = Dense{..}
117 where
118 gmDense = m
119 nRows = rows m
120 nCols = cols m
121
122mkSparse :: AssocMatrix -> GMatrix
123mkSparse = fromCSR . mkCSR
124
125fromCSR :: CSR -> GMatrix
126fromCSR csr = SparseR {..}
127 where
128 gmCSR @ CSR {..} = csr
129 nRows = csrNRows
130 nCols = csrNCols
131
132
133mkDiagR r c v
134 | dim v <= min r c = Diag{..}
135 | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
136 where
137 nRows = r
138 nCols = c
139 diagVals = v
140
141
142type IV t = CInt -> Ptr CInt -> t
143type V t = CInt -> Ptr Double -> t
144type SMxV = V (IV (IV (V (V (IO CInt)))))
145
146gmXv :: GMatrix -> Vector Double -> Vector Double
147gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do
148 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
149 r <- createVector nRows
150 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
151 return r
152
153gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do
154 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
155 r <- createVector nRows
156 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
157 return r
158
159gmXv Diag{..} v
160 | dim v == nCols
161 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
162 , konst 0 (nRows - dim diagVals) ]
163 | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
164 nRows nCols (dim diagVals) (dim v)
165
166gmXv Dense{..} v
167 | dim v == nCols
168 = mXv gmDense v
169 | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
170 nRows nCols (dim v)
171
172
173{- | general matrix - vector product
174
175>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
176>>> m !#> vector [1..2000]
177fromList [1000.0,4000.0]
178
179-}
180infixr 8 !#>
181(!#>) :: GMatrix -> Vector Double -> Vector Double
182(!#>) = gmXv
183
184--------------------------------------------------------------------------------
185
186foreign import ccall unsafe "smXv"
187 c_smXv :: SMxV
188
189foreign import ccall unsafe "smTXv"
190 c_smTXv :: SMxV
191
192--------------------------------------------------------------------------------
193
194toDense :: AssocMatrix -> Matrix Double
195toDense asm = assoc (r+1,c+1) 0 asm
196 where
197 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
198
199
200instance Transposable CSR CSC
201 where
202 tr (CSR vs cs rs n m) = CSC vs cs rs m n
203 tr' = tr
204
205instance Transposable CSC CSR
206 where
207 tr (CSC vs rs cs n m) = CSR vs rs cs m n
208 tr' = tr
209
210instance Transposable GMatrix GMatrix
211 where
212 tr (SparseR s n m) = SparseC (tr s) m n
213 tr (SparseC s n m) = SparseR (tr s) m n
214 tr (Diag v n m) = Diag v m n
215 tr (Dense a n m) = Dense (tr a) m n
216 tr' = tr
217