summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/Sparse.hs')
-rw-r--r--packages/base/src/Numeric/Sparse.hs92
1 files changed, 61 insertions, 31 deletions
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
index 4d05bdc..3c19c93 100644
--- a/packages/base/src/Numeric/Sparse.hs
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -3,8 +3,8 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5module Numeric.Sparse( 5module Numeric.Sparse(
6 GMatrix(..), 6 GMatrix, CSR(..), mkCSR,
7 mkSparse, mkDiagR, dense, 7 mkSparse, mkDiagR, mkDense,
8 AssocMatrix, 8 AssocMatrix,
9 toDense, 9 toDense,
10 gmXv, (!#>) 10 gmXv, (!#>)
@@ -28,18 +28,49 @@ c ~!~ msg = when c (error msg)
28 28
29type AssocMatrix = [((Int,Int),Double)] 29type AssocMatrix = [((Int,Int),Double)]
30 30
31data CSR = CSR
32 { csrVals :: Vector Double
33 , csrCols :: Vector CInt
34 , csrRows :: Vector CInt
35 , csrNRows :: Int
36 , csrNCols :: Int
37 } deriving Show
38
39data CSC = CSC
40 { cscVals :: Vector Double
41 , cscRows :: Vector CInt
42 , cscCols :: Vector CInt
43 , cscNRows :: Int
44 , cscNCols :: Int
45 } deriving Show
46
47
48mkCSR :: AssocMatrix -> CSR
49mkCSR sm' = CSR{..}
50 where
51 sm = sort sm'
52 rws = map ((fromList *** fromList)
53 . unzip
54 . map ((succ.fi.snd) *** id)
55 )
56 . groupBy ((==) `on` (fst.fst))
57 $ sm
58 rszs = map (fi . dim . fst) rws
59 csrRows = fromList (scanl (+) 1 rszs)
60 csrVals = vjoin (map snd rws)
61 csrCols = vjoin (map fst rws)
62 csrNRows = dim csrRows - 1
63 csrNCols = fromIntegral (V.maximum csrCols)
64
65
31data GMatrix 66data GMatrix
32 = CSR 67 = SparseR
33 { csrVals :: Vector Double 68 { gmCSR :: CSR
34 , csrCols :: Vector CInt
35 , csrRows :: Vector CInt
36 , nRows :: Int 69 , nRows :: Int
37 , nCols :: Int 70 , nCols :: Int
38 } 71 }
39 | CSC 72 | SparseC
40 { cscVals :: Vector Double 73 { gmCSC :: CSC
41 , cscRows :: Vector CInt
42 , cscCols :: Vector CInt
43 , nRows :: Int 74 , nRows :: Int
44 , nCols :: Int 75 , nCols :: Int
45 } 76 }
@@ -56,29 +87,21 @@ data GMatrix
56-- | Banded 87-- | Banded
57 deriving Show 88 deriving Show
58 89
59dense :: Matrix Double -> GMatrix 90
60dense m = Dense{..} 91mkDense :: Matrix Double -> GMatrix
92mkDense m = Dense{..}
61 where 93 where
62 gmDense = m 94 gmDense = m
63 nRows = rows m 95 nRows = rows m
64 nCols = cols m 96 nCols = cols m
65 97
66mkSparse :: AssocMatrix -> GMatrix 98
67mkSparse sm' = CSR{..} 99mkSparse :: CSR -> GMatrix
100mkSparse csr = SparseR {..}
68 where 101 where
69 sm = sort sm' 102 gmCSR @ CSR {..} = csr
70 rws = map ((fromList *** fromList) 103 nRows = csrNRows
71 . unzip 104 nCols = csrNCols
72 . map ((succ.fi.snd) *** id)
73 )
74 . groupBy ((==) `on` (fst.fst))
75 $ sm
76 rszs = map (fi . dim . fst) rws
77 csrRows = fromList (scanl (+) 1 rszs)
78 csrVals = vjoin (map snd rws)
79 csrCols = vjoin (map fst rws)
80 nRows = dim csrRows - 1
81 nCols = fromIntegral (V.maximum csrCols)
82 105
83 106
84mkDiagR r c v 107mkDiagR r c v
@@ -95,13 +118,13 @@ type V t = CInt -> Ptr Double -> t
95type SMxV = V (IV (IV (V (V (IO CInt))))) 118type SMxV = V (IV (IV (V (V (IO CInt)))))
96 119
97gmXv :: GMatrix -> Vector Double -> Vector Double 120gmXv :: GMatrix -> Vector Double -> Vector Double
98gmXv CSR{..} v = unsafePerformIO $ do 121gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do
99 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 122 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
100 r <- createVector nRows 123 r <- createVector nRows
101 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" 124 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
102 return r 125 return r
103 126
104gmXv CSC{..} v = unsafePerformIO $ do 127gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do
105 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 128 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
106 r <- createVector nRows 129 r <- createVector nRows
107 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" 130 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
@@ -147,11 +170,18 @@ toDense asm = assoc (r+1,c+1) 0 asm
147 (r,c) = (maximum *** maximum) . unzip . map fst $ asm 170 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
148 171
149 172
150 173instance Transposable CSR CSC
151instance Transposable GMatrix GMatrix
152 where 174 where
153 tr (CSR vs cs rs n m) = CSC vs cs rs m n 175 tr (CSR vs cs rs n m) = CSC vs cs rs m n
176
177instance Transposable CSC CSR
178 where
154 tr (CSC vs rs cs n m) = CSR vs rs cs m n 179 tr (CSC vs rs cs n m) = CSR vs rs cs m n
180
181instance Transposable GMatrix GMatrix
182 where
183 tr (SparseR s n m) = SparseC (tr s) m n
184 tr (SparseC s n m) = SparseR (tr s) m n
155 tr (Diag v n m) = Diag v m n 185 tr (Diag v n m) = Diag v m n
156 tr (Dense a n m) = Dense (tr a) m n 186 tr (Dense a n m) = Dense (tr a) m n
157 187