diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-05 16:39:32 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-05 16:39:32 +0200 |
commit | f20a94375c03bd6154f67fec1345e530acfc881d (patch) | |
tree | adf818f40e8fd9b06f9445aaf086c98580a1dd4b /packages/base/src/Internal | |
parent | 8b093ecca2b4e200ff191b84cb0b56a12312867b (diff) |
move sparse
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/Sparse.hs | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs new file mode 100644 index 0000000..930bc99 --- /dev/null +++ b/packages/base/src/Internal/Sparse.hs | |||
@@ -0,0 +1,217 @@ | |||
1 | {-# LANGUAGE RecordWildCards #-} | ||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | |||
5 | module Internal.Sparse( | ||
6 | GMatrix(..), CSR(..), mkCSR, fromCSR, | ||
7 | mkSparse, mkDiagR, mkDense, | ||
8 | AssocMatrix, | ||
9 | toDense, | ||
10 | gmXv, (!#>) | ||
11 | )where | ||
12 | |||
13 | import Internal.Vector | ||
14 | import Internal.Matrix | ||
15 | import Internal.Numeric | ||
16 | import Internal.Container | ||
17 | import Internal.Tools | ||
18 | import qualified Data.Vector.Storable as V | ||
19 | import Data.Vector.Storable(fromList) | ||
20 | import Data.Function(on) | ||
21 | import Control.Arrow((***)) | ||
22 | import Control.Monad(when) | ||
23 | import Data.List(groupBy, sort) | ||
24 | import Foreign.C.Types(CInt(..)) | ||
25 | |||
26 | import Internal.Devel | ||
27 | import System.IO.Unsafe(unsafePerformIO) | ||
28 | import Foreign(Ptr) | ||
29 | import Text.Printf(printf) | ||
30 | |||
31 | infixl 0 ~!~ | ||
32 | c ~!~ msg = when c (error msg) | ||
33 | |||
34 | type AssocMatrix = [((Int,Int),Double)] | ||
35 | |||
36 | data CSR = CSR | ||
37 | { csrVals :: Vector Double | ||
38 | , csrCols :: Vector CInt | ||
39 | , csrRows :: Vector CInt | ||
40 | , csrNRows :: Int | ||
41 | , csrNCols :: Int | ||
42 | } deriving Show | ||
43 | |||
44 | data CSC = CSC | ||
45 | { cscVals :: Vector Double | ||
46 | , cscRows :: Vector CInt | ||
47 | , cscCols :: Vector CInt | ||
48 | , cscNRows :: Int | ||
49 | , cscNCols :: Int | ||
50 | } deriving Show | ||
51 | |||
52 | |||
53 | mkCSR :: AssocMatrix -> CSR | ||
54 | mkCSR sm' = CSR{..} | ||
55 | where | ||
56 | sm = sort sm' | ||
57 | rws = map ((fromList *** fromList) | ||
58 | . unzip | ||
59 | . map ((succ.fi.snd) *** id) | ||
60 | ) | ||
61 | . groupBy ((==) `on` (fst.fst)) | ||
62 | $ sm | ||
63 | rszs = map (fi . dim . fst) rws | ||
64 | csrRows = fromList (scanl (+) 1 rszs) | ||
65 | csrVals = vjoin (map snd rws) | ||
66 | csrCols = vjoin (map fst rws) | ||
67 | csrNRows = dim csrRows - 1 | ||
68 | csrNCols = fromIntegral (V.maximum csrCols) | ||
69 | |||
70 | {- | General matrix with specialized internal representations for | ||
71 | dense, sparse, diagonal, banded, and constant elements. | ||
72 | |||
73 | >>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)] | ||
74 | >>> m | ||
75 | SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0], | ||
76 | csrCols = fromList [1000,2000], | ||
77 | csrRows = fromList [1,2,3], | ||
78 | csrNRows = 2, | ||
79 | csrNCols = 2000}, | ||
80 | nRows = 2, | ||
81 | nCols = 2000} | ||
82 | |||
83 | >>> let m = mkDense (mat 2 [1..4]) | ||
84 | >>> m | ||
85 | Dense {gmDense = (2><2) | ||
86 | [ 1.0, 2.0 | ||
87 | , 3.0, 4.0 ], nRows = 2, nCols = 2} | ||
88 | |||
89 | -} | ||
90 | data GMatrix | ||
91 | = SparseR | ||
92 | { gmCSR :: CSR | ||
93 | , nRows :: Int | ||
94 | , nCols :: Int | ||
95 | } | ||
96 | | SparseC | ||
97 | { gmCSC :: CSC | ||
98 | , nRows :: Int | ||
99 | , nCols :: Int | ||
100 | } | ||
101 | | Diag | ||
102 | { diagVals :: Vector Double | ||
103 | , nRows :: Int | ||
104 | , nCols :: Int | ||
105 | } | ||
106 | | Dense | ||
107 | { gmDense :: Matrix Double | ||
108 | , nRows :: Int | ||
109 | , nCols :: Int | ||
110 | } | ||
111 | -- | Banded | ||
112 | deriving Show | ||
113 | |||
114 | |||
115 | mkDense :: Matrix Double -> GMatrix | ||
116 | mkDense m = Dense{..} | ||
117 | where | ||
118 | gmDense = m | ||
119 | nRows = rows m | ||
120 | nCols = cols m | ||
121 | |||
122 | mkSparse :: AssocMatrix -> GMatrix | ||
123 | mkSparse = fromCSR . mkCSR | ||
124 | |||
125 | fromCSR :: CSR -> GMatrix | ||
126 | fromCSR csr = SparseR {..} | ||
127 | where | ||
128 | gmCSR @ CSR {..} = csr | ||
129 | nRows = csrNRows | ||
130 | nCols = csrNCols | ||
131 | |||
132 | |||
133 | mkDiagR r c v | ||
134 | | dim v <= min r c = Diag{..} | ||
135 | | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v) | ||
136 | where | ||
137 | nRows = r | ||
138 | nCols = c | ||
139 | diagVals = v | ||
140 | |||
141 | |||
142 | type IV t = CInt -> Ptr CInt -> t | ||
143 | type V t = CInt -> Ptr Double -> t | ||
144 | type SMxV = V (IV (IV (V (V (IO CInt))))) | ||
145 | |||
146 | gmXv :: GMatrix -> Vector Double -> Vector Double | ||
147 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do | ||
148 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
149 | r <- createVector nRows | ||
150 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | ||
151 | return r | ||
152 | |||
153 | gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do | ||
154 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
155 | r <- createVector nRows | ||
156 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | ||
157 | return r | ||
158 | |||
159 | gmXv Diag{..} v | ||
160 | | dim v == nCols | ||
161 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals | ||
162 | , konst 0 (nRows - dim diagVals) ] | ||
163 | | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" | ||
164 | nRows nCols (dim diagVals) (dim v) | ||
165 | |||
166 | gmXv Dense{..} v | ||
167 | | dim v == nCols | ||
168 | = mXv gmDense v | ||
169 | | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d" | ||
170 | nRows nCols (dim v) | ||
171 | |||
172 | |||
173 | {- | general matrix - vector product | ||
174 | |||
175 | >>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)] | ||
176 | >>> m !#> vector [1..2000] | ||
177 | fromList [1000.0,4000.0] | ||
178 | |||
179 | -} | ||
180 | infixr 8 !#> | ||
181 | (!#>) :: GMatrix -> Vector Double -> Vector Double | ||
182 | (!#>) = gmXv | ||
183 | |||
184 | -------------------------------------------------------------------------------- | ||
185 | |||
186 | foreign import ccall unsafe "smXv" | ||
187 | c_smXv :: SMxV | ||
188 | |||
189 | foreign import ccall unsafe "smTXv" | ||
190 | c_smTXv :: SMxV | ||
191 | |||
192 | -------------------------------------------------------------------------------- | ||
193 | |||
194 | toDense :: AssocMatrix -> Matrix Double | ||
195 | toDense asm = assoc (r+1,c+1) 0 asm | ||
196 | where | ||
197 | (r,c) = (maximum *** maximum) . unzip . map fst $ asm | ||
198 | |||
199 | |||
200 | instance Transposable CSR CSC | ||
201 | where | ||
202 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | ||
203 | tr' = tr | ||
204 | |||
205 | instance Transposable CSC CSR | ||
206 | where | ||
207 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | ||
208 | tr' = tr | ||
209 | |||
210 | instance Transposable GMatrix GMatrix | ||
211 | where | ||
212 | tr (SparseR s n m) = SparseC (tr s) m n | ||
213 | tr (SparseC s n m) = SparseR (tr s) m n | ||
214 | tr (Diag v n m) = Diag v m n | ||
215 | tr (Dense a n m) = Dense (tr a) m n | ||
216 | tr' = tr | ||
217 | |||