summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r--packages/base/src/Numeric/Sparse.hs212
1 files changed, 0 insertions, 212 deletions
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
deleted file mode 100644
index d856287..0000000
--- a/packages/base/src/Numeric/Sparse.hs
+++ /dev/null
@@ -1,212 +0,0 @@
1{-# LANGUAGE RecordWildCards #-}
2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-}
4
5module Numeric.Sparse(
6 GMatrix(..), CSR(..), mkCSR, fromCSR,
7 mkSparse, mkDiagR, mkDense,
8 AssocMatrix,
9 toDense,
10 gmXv, (!#>)
11)where
12
13import Data.Packed.Numeric
14import qualified Data.Vector.Storable as V
15import Data.Function(on)
16import Control.Arrow((***))
17import Control.Monad(when)
18import Data.List(groupBy, sort)
19import Foreign.C.Types(CInt(..))
20
21import Data.Packed.Development
22import System.IO.Unsafe(unsafePerformIO)
23import Foreign(Ptr)
24import Text.Printf(printf)
25
26infixl 0 ~!~
27c ~!~ msg = when c (error msg)
28
29type AssocMatrix = [((Int,Int),Double)]
30
31data CSR = CSR
32 { csrVals :: Vector Double
33 , csrCols :: Vector CInt
34 , csrRows :: Vector CInt
35 , csrNRows :: Int
36 , csrNCols :: Int
37 } deriving Show
38
39data CSC = CSC
40 { cscVals :: Vector Double
41 , cscRows :: Vector CInt
42 , cscCols :: Vector CInt
43 , cscNRows :: Int
44 , cscNCols :: Int
45 } deriving Show
46
47
48mkCSR :: AssocMatrix -> CSR
49mkCSR sm' = CSR{..}
50 where
51 sm = sort sm'
52 rws = map ((fromList *** fromList)
53 . unzip
54 . map ((succ.fi.snd) *** id)
55 )
56 . groupBy ((==) `on` (fst.fst))
57 $ sm
58 rszs = map (fi . dim . fst) rws
59 csrRows = fromList (scanl (+) 1 rszs)
60 csrVals = vjoin (map snd rws)
61 csrCols = vjoin (map fst rws)
62 csrNRows = dim csrRows - 1
63 csrNCols = fromIntegral (V.maximum csrCols)
64
65{- | General matrix with specialized internal representations for
66 dense, sparse, diagonal, banded, and constant elements.
67
68>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
69>>> m
70SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0],
71 csrCols = fromList [1000,2000],
72 csrRows = fromList [1,2,3],
73 csrNRows = 2,
74 csrNCols = 2000},
75 nRows = 2,
76 nCols = 2000}
77
78>>> let m = mkDense (mat 2 [1..4])
79>>> m
80Dense {gmDense = (2><2)
81 [ 1.0, 2.0
82 , 3.0, 4.0 ], nRows = 2, nCols = 2}
83
84-}
85data GMatrix
86 = SparseR
87 { gmCSR :: CSR
88 , nRows :: Int
89 , nCols :: Int
90 }
91 | SparseC
92 { gmCSC :: CSC
93 , nRows :: Int
94 , nCols :: Int
95 }
96 | Diag
97 { diagVals :: Vector Double
98 , nRows :: Int
99 , nCols :: Int
100 }
101 | Dense
102 { gmDense :: Matrix Double
103 , nRows :: Int
104 , nCols :: Int
105 }
106-- | Banded
107 deriving Show
108
109
110mkDense :: Matrix Double -> GMatrix
111mkDense m = Dense{..}
112 where
113 gmDense = m
114 nRows = rows m
115 nCols = cols m
116
117mkSparse :: AssocMatrix -> GMatrix
118mkSparse = fromCSR . mkCSR
119
120fromCSR :: CSR -> GMatrix
121fromCSR csr = SparseR {..}
122 where
123 gmCSR @ CSR {..} = csr
124 nRows = csrNRows
125 nCols = csrNCols
126
127
128mkDiagR r c v
129 | dim v <= min r c = Diag{..}
130 | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
131 where
132 nRows = r
133 nCols = c
134 diagVals = v
135
136
137type IV t = CInt -> Ptr CInt -> t
138type V t = CInt -> Ptr Double -> t
139type SMxV = V (IV (IV (V (V (IO CInt)))))
140
141gmXv :: GMatrix -> Vector Double -> Vector Double
142gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do
143 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
144 r <- createVector nRows
145 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
146 return r
147
148gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do
149 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
150 r <- createVector nRows
151 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
152 return r
153
154gmXv Diag{..} v
155 | dim v == nCols
156 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
157 , konst 0 (nRows - dim diagVals) ]
158 | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
159 nRows nCols (dim diagVals) (dim v)
160
161gmXv Dense{..} v
162 | dim v == nCols
163 = mXv gmDense v
164 | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
165 nRows nCols (dim v)
166
167
168{- | general matrix - vector product
169
170>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
171>>> m !#> vector [1..2000]
172fromList [1000.0,4000.0]
173
174-}
175infixr 8 !#>
176(!#>) :: GMatrix -> Vector Double -> Vector Double
177(!#>) = gmXv
178
179--------------------------------------------------------------------------------
180
181foreign import ccall unsafe "smXv"
182 c_smXv :: SMxV
183
184foreign import ccall unsafe "smTXv"
185 c_smTXv :: SMxV
186
187--------------------------------------------------------------------------------
188
189toDense :: AssocMatrix -> Matrix Double
190toDense asm = assoc (r+1,c+1) 0 asm
191 where
192 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
193
194
195instance Transposable CSR CSC
196 where
197 tr (CSR vs cs rs n m) = CSC vs cs rs m n
198 tr' = tr
199
200instance Transposable CSC CSR
201 where
202 tr (CSC vs rs cs n m) = CSR vs rs cs m n
203 tr' = tr
204
205instance Transposable GMatrix GMatrix
206 where
207 tr (SparseR s n m) = SparseC (tr s) m n
208 tr (SparseC s n m) = SparseR (tr s) m n
209 tr (Diag v n m) = Diag v m n
210 tr (Dense a n m) = Dense (tr a) m n
211 tr' = tr
212