summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed/Internal/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed/Internal/Sparse.hs')
-rw-r--r--packages/base/src/Data/Packed/Internal/Sparse.hs191
1 files changed, 191 insertions, 0 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Sparse.hs b/packages/base/src/Data/Packed/Internal/Sparse.hs
new file mode 100644
index 0000000..544c913
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Sparse.hs
@@ -0,0 +1,191 @@
1{-# LANGUAGE RecordWildCards #-}
2{-# LANGUAGE MultiParamTypeClasses #-}
3{-# LANGUAGE FlexibleInstances #-}
4
5
6
7module Data.Packed.Internal.Sparse(
8 SMatrix(..),
9 mkCSR, mkDiag,
10 AssocMatrix,
11 toDense,
12 smXv
13)where
14
15import Numeric.Container
16import qualified Data.Vector.Storable as V
17import Data.Function(on)
18import Control.Arrow((***))
19import Control.Monad(when)
20import Data.List(groupBy, sort)
21import Foreign.C.Types(CInt(..))
22import Numeric.LinearAlgebra.Devel
23import System.IO.Unsafe(unsafePerformIO)
24import Foreign(Ptr)
25import Text.Printf(printf)
26
27infixl 0 ~!~
28c ~!~ msg = when c (error msg)
29
30type AssocMatrix = [((Int,Int),Double)]
31
32data SMatrix
33 = CSR
34 { csrVals :: Vector Double
35 , csrCols :: Vector CInt
36 , csrRows :: Vector CInt
37 , nRows :: Int
38 , nCols :: Int
39 }
40 | CSC
41 { cscVals :: Vector Double
42 , cscRows :: Vector CInt
43 , cscCols :: Vector CInt
44 , nRows :: Int
45 , nCols :: Int
46 }
47 | Diag
48 { diagVals :: Vector Double
49 , nRows :: Int
50 , nCols :: Int
51 }
52-- | Banded
53 deriving Show
54
55mkCSR :: AssocMatrix -> SMatrix
56mkCSR sm' = CSR{..}
57 where
58 sm = sort sm'
59 rws = map ((fromList *** fromList)
60 . unzip
61 . map ((succ.fi.snd) *** id)
62 )
63 . groupBy ((==) `on` (fst.fst))
64 $ sm
65 rszs = map (fi . dim . fst) rws
66 csrRows = fromList (scanl (+) 1 rszs)
67 csrVals = vjoin (map snd rws)
68 csrCols = vjoin (map fst rws)
69 nRows = dim csrRows - 1
70 nCols = fromIntegral (V.maximum csrCols)
71
72
73mkDiagR r c v
74 | dim v <= min r c = Diag{..}
75 | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
76 where
77 nRows = r
78 nCols = c
79 diagVals = v
80
81mkDiag v = mkDiagR (dim v) (dim v) v
82
83
84type IV t = CInt -> Ptr CInt -> t
85type V t = CInt -> Ptr Double -> t
86type SMxV = V (IV (IV (V (V (IO CInt)))))
87
88smXv :: SMatrix -> Vector Double -> Vector Double
89smXv CSR{..} v = unsafePerformIO $ do
90 dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
91 r <- createVector nRows
92 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
93 return r
94
95smXv CSC{..} v = unsafePerformIO $ do
96 dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
97 r <- createVector nRows
98 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
99 return r
100
101smXv Diag{..} v
102 | dim v == nCols
103 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
104 , konst 0 (nRows - dim diagVals) ]
105 | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
106 nRows nCols (dim diagVals) (dim v)
107
108
109instance Contraction SMatrix (Vector Double) (Vector Double)
110 where
111 contraction = smXv
112
113--------------------------------------------------------------------------------
114
115foreign import ccall unsafe "smXv"
116 c_smXv :: SMxV
117
118foreign import ccall unsafe "smTXv"
119 c_smTXv :: SMxV
120
121--------------------------------------------------------------------------------
122
123toDense :: AssocMatrix -> Matrix Double
124toDense asm = assoc (r+1,c+1) 0 asm
125 where
126 (r,c) = (maximum *** maximum) . unzip . map fst $ asm
127
128
129
130instance Transposable (SMatrix)
131 where
132 tr (CSR vs cs rs n m) = CSC vs cs rs m n
133 tr (CSC vs rs cs n m) = CSR vs rs cs m n
134 tr (Diag v n m) = Diag v m n
135
136instance Transposable (Matrix Double)
137 where
138 tr = trans
139
140
141
142--------------------------------------------------------------------------------
143
144instance Testable SMatrix
145 where
146 checkT _ = (ok,info)
147 where
148 sma = convo2 20 3
149 x1 = vect [1..20]
150 x2 = vect [1..40]
151 sm = mkCSR sma
152
153 s1 = sm ◇ x1
154 d1 = toDense sma ◇ x1
155
156 s2 = tr sm ◇ x2
157 d2 = tr (toDense sma) ◇ x2
158
159 sdia = mkDiagR 40 20 (vect [1..10])
160 s3 = sdia ◇ x1
161 s4 = tr sdia ◇ x2
162 ddia = diagRect 0 (vect [1..10]) 40 20
163 d3 = ddia ◇ x1
164 d4 = tr ddia ◇ x2
165
166 info = do
167 print sm
168 disp (toDense sma)
169 print s1; print d1
170 print s2; print d2
171 print s3; print d3
172 print s4; print d4
173
174 ok = s1==d1
175 && s2==d2
176 && s3==d3
177 && s4==d4
178
179 disp = putStr . dispf 2
180
181 vect = fromList :: [Double] -> Vector Double
182
183 convomat :: Int -> Int -> AssocMatrix
184 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
185
186 convo2 :: Int -> Int -> AssocMatrix
187 convo2 n k = m1 ++ m2
188 where
189 m1 = convomat n k
190 m2 = map (((+n) *** id) *** id) m1
191