diff options
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs | 212 |
1 files changed, 0 insertions, 212 deletions
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs deleted file mode 100644 index d856287..0000000 --- a/packages/base/src/Numeric/Sparse.hs +++ /dev/null | |||
@@ -1,212 +0,0 @@ | |||
1 | {-# LANGUAGE RecordWildCards #-} | ||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | |||
5 | module Numeric.Sparse( | ||
6 | GMatrix(..), CSR(..), mkCSR, fromCSR, | ||
7 | mkSparse, mkDiagR, mkDense, | ||
8 | AssocMatrix, | ||
9 | toDense, | ||
10 | gmXv, (!#>) | ||
11 | )where | ||
12 | |||
13 | import Data.Packed.Numeric | ||
14 | import qualified Data.Vector.Storable as V | ||
15 | import Data.Function(on) | ||
16 | import Control.Arrow((***)) | ||
17 | import Control.Monad(when) | ||
18 | import Data.List(groupBy, sort) | ||
19 | import Foreign.C.Types(CInt(..)) | ||
20 | |||
21 | import Data.Packed.Development | ||
22 | import System.IO.Unsafe(unsafePerformIO) | ||
23 | import Foreign(Ptr) | ||
24 | import Text.Printf(printf) | ||
25 | |||
26 | infixl 0 ~!~ | ||
27 | c ~!~ msg = when c (error msg) | ||
28 | |||
29 | type AssocMatrix = [((Int,Int),Double)] | ||
30 | |||
31 | data CSR = CSR | ||
32 | { csrVals :: Vector Double | ||
33 | , csrCols :: Vector CInt | ||
34 | , csrRows :: Vector CInt | ||
35 | , csrNRows :: Int | ||
36 | , csrNCols :: Int | ||
37 | } deriving Show | ||
38 | |||
39 | data CSC = CSC | ||
40 | { cscVals :: Vector Double | ||
41 | , cscRows :: Vector CInt | ||
42 | , cscCols :: Vector CInt | ||
43 | , cscNRows :: Int | ||
44 | , cscNCols :: Int | ||
45 | } deriving Show | ||
46 | |||
47 | |||
48 | mkCSR :: AssocMatrix -> CSR | ||
49 | mkCSR sm' = CSR{..} | ||
50 | where | ||
51 | sm = sort sm' | ||
52 | rws = map ((fromList *** fromList) | ||
53 | . unzip | ||
54 | . map ((succ.fi.snd) *** id) | ||
55 | ) | ||
56 | . groupBy ((==) `on` (fst.fst)) | ||
57 | $ sm | ||
58 | rszs = map (fi . dim . fst) rws | ||
59 | csrRows = fromList (scanl (+) 1 rszs) | ||
60 | csrVals = vjoin (map snd rws) | ||
61 | csrCols = vjoin (map fst rws) | ||
62 | csrNRows = dim csrRows - 1 | ||
63 | csrNCols = fromIntegral (V.maximum csrCols) | ||
64 | |||
65 | {- | General matrix with specialized internal representations for | ||
66 | dense, sparse, diagonal, banded, and constant elements. | ||
67 | |||
68 | >>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)] | ||
69 | >>> m | ||
70 | SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0], | ||
71 | csrCols = fromList [1000,2000], | ||
72 | csrRows = fromList [1,2,3], | ||
73 | csrNRows = 2, | ||
74 | csrNCols = 2000}, | ||
75 | nRows = 2, | ||
76 | nCols = 2000} | ||
77 | |||
78 | >>> let m = mkDense (mat 2 [1..4]) | ||
79 | >>> m | ||
80 | Dense {gmDense = (2><2) | ||
81 | [ 1.0, 2.0 | ||
82 | , 3.0, 4.0 ], nRows = 2, nCols = 2} | ||
83 | |||
84 | -} | ||
85 | data GMatrix | ||
86 | = SparseR | ||
87 | { gmCSR :: CSR | ||
88 | , nRows :: Int | ||
89 | , nCols :: Int | ||
90 | } | ||
91 | | SparseC | ||
92 | { gmCSC :: CSC | ||
93 | , nRows :: Int | ||
94 | , nCols :: Int | ||
95 | } | ||
96 | | Diag | ||
97 | { diagVals :: Vector Double | ||
98 | , nRows :: Int | ||
99 | , nCols :: Int | ||
100 | } | ||
101 | | Dense | ||
102 | { gmDense :: Matrix Double | ||
103 | , nRows :: Int | ||
104 | , nCols :: Int | ||
105 | } | ||
106 | -- | Banded | ||
107 | deriving Show | ||
108 | |||
109 | |||
110 | mkDense :: Matrix Double -> GMatrix | ||
111 | mkDense m = Dense{..} | ||
112 | where | ||
113 | gmDense = m | ||
114 | nRows = rows m | ||
115 | nCols = cols m | ||
116 | |||
117 | mkSparse :: AssocMatrix -> GMatrix | ||
118 | mkSparse = fromCSR . mkCSR | ||
119 | |||
120 | fromCSR :: CSR -> GMatrix | ||
121 | fromCSR csr = SparseR {..} | ||
122 | where | ||
123 | gmCSR @ CSR {..} = csr | ||
124 | nRows = csrNRows | ||
125 | nCols = csrNCols | ||
126 | |||
127 | |||
128 | mkDiagR r c v | ||
129 | | dim v <= min r c = Diag{..} | ||
130 | | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v) | ||
131 | where | ||
132 | nRows = r | ||
133 | nCols = c | ||
134 | diagVals = v | ||
135 | |||
136 | |||
137 | type IV t = CInt -> Ptr CInt -> t | ||
138 | type V t = CInt -> Ptr Double -> t | ||
139 | type SMxV = V (IV (IV (V (V (IO CInt))))) | ||
140 | |||
141 | gmXv :: GMatrix -> Vector Double -> Vector Double | ||
142 | gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do | ||
143 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
144 | r <- createVector nRows | ||
145 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | ||
146 | return r | ||
147 | |||
148 | gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do | ||
149 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
150 | r <- createVector nRows | ||
151 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | ||
152 | return r | ||
153 | |||
154 | gmXv Diag{..} v | ||
155 | | dim v == nCols | ||
156 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals | ||
157 | , konst 0 (nRows - dim diagVals) ] | ||
158 | | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" | ||
159 | nRows nCols (dim diagVals) (dim v) | ||
160 | |||
161 | gmXv Dense{..} v | ||
162 | | dim v == nCols | ||
163 | = mXv gmDense v | ||
164 | | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d" | ||
165 | nRows nCols (dim v) | ||
166 | |||
167 | |||
168 | {- | general matrix - vector product | ||
169 | |||
170 | >>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)] | ||
171 | >>> m !#> vector [1..2000] | ||
172 | fromList [1000.0,4000.0] | ||
173 | |||
174 | -} | ||
175 | infixr 8 !#> | ||
176 | (!#>) :: GMatrix -> Vector Double -> Vector Double | ||
177 | (!#>) = gmXv | ||
178 | |||
179 | -------------------------------------------------------------------------------- | ||
180 | |||
181 | foreign import ccall unsafe "smXv" | ||
182 | c_smXv :: SMxV | ||
183 | |||
184 | foreign import ccall unsafe "smTXv" | ||
185 | c_smTXv :: SMxV | ||
186 | |||
187 | -------------------------------------------------------------------------------- | ||
188 | |||
189 | toDense :: AssocMatrix -> Matrix Double | ||
190 | toDense asm = assoc (r+1,c+1) 0 asm | ||
191 | where | ||
192 | (r,c) = (maximum *** maximum) . unzip . map fst $ asm | ||
193 | |||
194 | |||
195 | instance Transposable CSR CSC | ||
196 | where | ||
197 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | ||
198 | tr' = tr | ||
199 | |||
200 | instance Transposable CSC CSR | ||
201 | where | ||
202 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | ||
203 | tr' = tr | ||
204 | |||
205 | instance Transposable GMatrix GMatrix | ||
206 | where | ||
207 | tr (SparseR s n m) = SparseC (tr s) m n | ||
208 | tr (SparseC s n m) = SparseR (tr s) m n | ||
209 | tr (Diag v n m) = Diag v m n | ||
210 | tr (Dense a n m) = Dense (tr a) m n | ||
211 | tr' = tr | ||
212 | |||