summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-27 20:21:47 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-27 20:21:47 +0200
commit53559833d2166010eed754027484fb8d5525e710 (patch)
treed1cd40e45e6062ef6bece255b20424f90091a910 /packages/base/src/Numeric
parentcf3c788f0c44577ac1a5365e8154200b53a36409 (diff)
expose CSR
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs6
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs10
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util.hs7
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs2
-rw-r--r--packages/base/src/Numeric/Sparse.hs92
5 files changed, 82 insertions, 35 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 3417a5e..33a2c9a 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -49,8 +49,8 @@ module Numeric.LinearAlgebra.Data(
49 find, maxIndex, minIndex, maxElement, minElement, atIndex, 49 find, maxIndex, minIndex, maxElement, minElement, atIndex,
50 50
51 -- * Sparse 51 -- * Sparse
52 GMatrix, AssocMatrix, mkSparse, toDense, 52 AssocMatrix, toDense,
53 mkDiagR, dense, 53 mkSparse, mkDiagR, mkDense,
54 54
55 -- * IO 55 -- * IO
56 disp, 56 disp,
@@ -68,7 +68,7 @@ module Numeric.LinearAlgebra.Data(
68 68
69 module Data.Complex, 69 module Data.Complex,
70 70
71 Vector, Matrix 71 Vector, Matrix, GMatrix, CSR(..), mkCSR
72 72
73) where 73) where
74 74
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index db15705..1e8b544 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -141,6 +141,16 @@ instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
141 negate = (lift1F . lift1F) negate 141 negate = (lift1F . lift1F) negate
142 fromInteger x = Dim (Dim (fromInteger x)) 142 fromInteger x = Dim (Dim (fromInteger x))
143 143
144instance Fractional (Dim n (Vector Double))
145 where
146 fromRational x = Dim (fromRational x)
147 (/) = lift2F (/)
148
149instance Fractional (Dim m (Dim n (Matrix Double)))
150 where
151 fromRational x = Dim (Dim (fromRational x))
152 (/) = (lift2F.lift2F) (/)
153
144-------------------------------------------------------------------------------- 154--------------------------------------------------------------------------------
145 155
146class Konst t 156class Konst t
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs
index 47b1090..aee21b8 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs
@@ -37,6 +37,7 @@ module Numeric.LinearAlgebra.Util(
37 mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, 37 mnorm_1, mnorm_2, mnorm_0, mnorm_Inf,
38 unitary, 38 unitary,
39 mt, 39 mt,
40 (~!~),
40 pairwiseD2, 41 pairwiseD2,
41 rowOuters, 42 rowOuters,
42 null1, 43 null1,
@@ -65,6 +66,7 @@ import Numeric.Matrix()
65import Numeric.Vector() 66import Numeric.Vector()
66import Numeric.LinearAlgebra.Random 67import Numeric.LinearAlgebra.Random
67import Numeric.LinearAlgebra.Util.Convolution 68import Numeric.LinearAlgebra.Util.Convolution
69import Control.Monad(when)
68 70
69type ℝ = Double 71type ℝ = Double
70type ℕ = Int 72type ℕ = Int
@@ -385,3 +387,8 @@ vtrans p m | r == 0 = fromBlocks . map (map asColumn . takesV (replicate q p)) .
385 where 387 where
386 (q,r) = divMod (rows m) p 388 (q,r) = divMod (rows m) p
387 389
390--------------------------------------------------------------------------------
391
392infixl 0 ~!~
393c ~!~ msg = when c (error msg)
394
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index 50372f1..f821b57 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -115,7 +115,7 @@ instance Testable GMatrix
115 sma = convo2 20 3 115 sma = convo2 20 3
116 x1 = vect [1..20] 116 x1 = vect [1..20]
117 x2 = vect [1..40] 117 x2 = vect [1..40]
118 sm = mkSparse sma 118 sm = (mkSparse . mkCSR) sma
119 dm = toDense sma 119 dm = toDense sma
120 120
121 s1 = sm !#> x1 121 s1 = sm !#> x1
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
index 4d05bdc..3c19c93 100644
--- a/packages/base/src/Numeric/Sparse.hs
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -3,8 +3,8 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5module Numeric.Sparse( 5module Numeric.Sparse(
6 GMatrix(..), 6 GMatrix, CSR(..), mkCSR,
7 mkSparse, mkDiagR, dense, 7 mkSparse, mkDiagR, mkDense,
8 AssocMatrix, 8 AssocMatrix,
9 toDense, 9 toDense,
10 gmXv, (!#>) 10 gmXv, (!#>)
@@ -28,18 +28,49 @@ c ~!~ msg = when c (error msg)
28 28
29type AssocMatrix = [((Int,Int),Double)] 29type AssocMatrix = [((Int,Int),Double)]
30 30
31data CSR = CSR
32 { csrVals :: Vector Double
33 , csrCols :: Vector CInt
34 , csrRows :: Vector CInt
35 , csrNRows :: Int
36 , csrNCols :: Int
37 } deriving Show
38
39data CSC = CSC
40 { cscVals :: Vector Double
41 , cscRows :: Vector CInt
42 , cscCols :: Vector CInt
43 , cscNRows :: Int
44 , cscNCols :: Int
45 } deriving Show
46
47
48mkCSR :: AssocMatrix -> CSR
49mkCSR sm' = CSR{..}
50 where
51 sm = sort sm'
52 rws = map ((fromList *** fromList)
53 . unzip
54 . map ((succ.fi.snd) *** id)
55 )
56 . groupBy ((==) `on` (fst.fst))
57 $ sm
58 rszs = map (fi . dim . fst) rws
59 csrRows = fromList (scanl (+) 1 rszs)
60 csrVals = vjoin (map snd rws)
61 csrCols = vjoin (map fst rws)
62 csrNRows = dim csrRows - 1
63 csrNCols = fromIntegral (V.maximum csrCols)
64
65
31data GMatrix 66data GMatrix
32 = CSR 67 = SparseR
33 { csrVals :: Vector Double 68 { gmCSR :: CSR
34 , csrCols :: Vector CInt
35 , csrRows :: Vector CInt
36 , nRows :: Int 69 , nRows :: Int
37 , nCols :: Int 70 , nCols :: Int
38 } 71 }
39 | CSC 72 | SparseC
40 { cscVals :: Vector Double 73 { gmCSC :: CSC
41 , cscRows :: Vector CInt
42 , cscCols :: Vector CInt
43 , nRows :: Int 74 , nRows :: Int
44 , nCols :: Int 75 , nCols :: Int
45 } 76 }
@@ -56,29 +87,21 @@ data GMatrix
56-- | Banded 87-- | Banded
57 deriving Show 88 deriving Show
58 89
59dense :: Matrix Double -> GMatrix 90
60dense m = Dense{..} 91mkDense :: Matrix Double -> GMatrix
92mkDense m = Dense{..}
61 where 93 where
62 gmDense = m 94 gmDense = m
63 nRows = rows m 95 nRows = rows m
64 nCols = cols m 96 nCols = cols m
65 97
66mkSparse :: AssocMatrix -> GMatrix 98
67mkSparse sm' = CSR{..} 99mkSparse :: CSR -> GMatrix
100mkSparse csr = SparseR {..}
68 where 101 where
69 sm = sort sm' 102 gmCSR @ CSR {..} = csr
70 rws = map ((fromList *** fromList) 103 nRows = csrNRows
71 . unzip 104 nCols = csrNCols
72 . map ((succ.fi.snd) *** id)
73 )
74 . groupBy ((==) `on` (fst.fst))
75 $ sm
76 rszs = map (fi . dim . fst) rws
77 csrRows = fromList (scanl (+) 1 rszs)
78 csrVals = vjoin (map snd rws)
79 csrCols = vjoin (map fst rws)
80 nRows = dim csrRows - 1
81 nCols = fromIntegral (V.maximum csrCols)
82 105
83 106
84mkDiagR r c v 107mkDiagR r c v
@@ -95,13 +118,13 @@ type V t = CInt -> Ptr Double -> t
95type SMxV = V (IV (IV (V (V (IO CInt))))) 118type SMxV = V (IV (IV (V (V (IO CInt)))))
96 119
97gmXv :: GMatrix -> Vector Double -> Vector Double 120gmXv :: GMatrix -> Vector Double -> Vector Double
98gmXv CSR{..} v = unsafePerformIO $ do 121gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do
99 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 122 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
100 r <- createVector nRows 123 r <- createVector nRows
101 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" 124 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
102 return r 125 return r
103 126
104gmXv CSC{..} v = unsafePerformIO $ do 127gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do
105 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 128 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
106 r <- createVector nRows 129 r <- createVector nRows
107 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" 130 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
@@ -147,11 +170,18 @@ toDense asm = assoc (r+1,c+1) 0 asm
147 (r,c) = (maximum *** maximum) . unzip . map fst $ asm 170 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
148 171
149 172
150 173instance Transposable CSR CSC
151instance Transposable GMatrix GMatrix
152 where 174 where
153 tr (CSR vs cs rs n m) = CSC vs cs rs m n 175 tr (CSR vs cs rs n m) = CSC vs cs rs m n
176
177instance Transposable CSC CSR
178 where
154 tr (CSC vs rs cs n m) = CSR vs rs cs m n 179 tr (CSC vs rs cs n m) = CSR vs rs cs m n
180
181instance Transposable GMatrix GMatrix
182 where
183 tr (SparseR s n m) = SparseC (tr s) m n
184 tr (SparseC s n m) = SparseR (tr s) m n
155 tr (Diag v n m) = Diag v m n 185 tr (Diag v n m) = Diag v m n
156 tr (Dense a n m) = Dense (tr a) m n 186 tr (Dense a n m) = Dense (tr a) m n
157 187