diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-22 20:09:41 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-22 20:09:41 +0200 |
commit | 85af0a1d5ba2d1c03f05458f9689195e82f6ae7e (patch) | |
tree | 07fce2a4b912b85c321e8b1175b52efddc1c4fcb /packages/base/src/Numeric/Sparse.hs | |
parent | b5125366953a6ae66ff014b736baf79c0feb47dd (diff) |
cgSolve
Diffstat (limited to 'packages/base/src/Numeric/Sparse.hs')
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs new file mode 100644 index 0000000..3835590 --- /dev/null +++ b/packages/base/src/Numeric/Sparse.hs | |||
@@ -0,0 +1,192 @@ | |||
1 | {-# LANGUAGE RecordWildCards #-} | ||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | |||
5 | module Numeric.Sparse( | ||
6 | SMatrix(..), | ||
7 | mkCSR, mkDiag, | ||
8 | AssocMatrix, | ||
9 | toDense, | ||
10 | smXv | ||
11 | )where | ||
12 | |||
13 | import Numeric.Container | ||
14 | import qualified Data.Vector.Storable as V | ||
15 | import Data.Function(on) | ||
16 | import Control.Arrow((***)) | ||
17 | import Control.Monad(when) | ||
18 | import Data.List(groupBy, sort) | ||
19 | import Foreign.C.Types(CInt(..)) | ||
20 | import Numeric.LinearAlgebra.Util.CG(CGMat) | ||
21 | import Data.Packed.Development | ||
22 | import System.IO.Unsafe(unsafePerformIO) | ||
23 | import Foreign(Ptr) | ||
24 | import Text.Printf(printf) | ||
25 | |||
26 | infixl 0 ~!~ | ||
27 | c ~!~ msg = when c (error msg) | ||
28 | |||
29 | type AssocMatrix = [((Int,Int),Double)] | ||
30 | |||
31 | data SMatrix | ||
32 | = CSR | ||
33 | { csrVals :: Vector Double | ||
34 | , csrCols :: Vector CInt | ||
35 | , csrRows :: Vector CInt | ||
36 | , nRows :: Int | ||
37 | , nCols :: Int | ||
38 | } | ||
39 | | CSC | ||
40 | { cscVals :: Vector Double | ||
41 | , cscRows :: Vector CInt | ||
42 | , cscCols :: Vector CInt | ||
43 | , nRows :: Int | ||
44 | , nCols :: Int | ||
45 | } | ||
46 | | Diag | ||
47 | { diagVals :: Vector Double | ||
48 | , nRows :: Int | ||
49 | , nCols :: Int | ||
50 | } | ||
51 | -- | Banded | ||
52 | deriving Show | ||
53 | |||
54 | mkCSR :: AssocMatrix -> SMatrix | ||
55 | mkCSR sm' = CSR{..} | ||
56 | where | ||
57 | sm = sort sm' | ||
58 | rws = map ((fromList *** fromList) | ||
59 | . unzip | ||
60 | . map ((succ.fi.snd) *** id) | ||
61 | ) | ||
62 | . groupBy ((==) `on` (fst.fst)) | ||
63 | $ sm | ||
64 | rszs = map (fi . dim . fst) rws | ||
65 | csrRows = fromList (scanl (+) 1 rszs) | ||
66 | csrVals = vjoin (map snd rws) | ||
67 | csrCols = vjoin (map fst rws) | ||
68 | nRows = dim csrRows - 1 | ||
69 | nCols = fromIntegral (V.maximum csrCols) | ||
70 | |||
71 | |||
72 | mkDiagR r c v | ||
73 | | dim v <= min r c = Diag{..} | ||
74 | | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v) | ||
75 | where | ||
76 | nRows = r | ||
77 | nCols = c | ||
78 | diagVals = v | ||
79 | |||
80 | mkDiag v = mkDiagR (dim v) (dim v) v | ||
81 | |||
82 | |||
83 | type IV t = CInt -> Ptr CInt -> t | ||
84 | type V t = CInt -> Ptr Double -> t | ||
85 | type SMxV = V (IV (IV (V (V (IO CInt))))) | ||
86 | |||
87 | smXv :: SMatrix -> Vector Double -> Vector Double | ||
88 | smXv CSR{..} v = unsafePerformIO $ do | ||
89 | dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
90 | r <- createVector nRows | ||
91 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | ||
92 | return r | ||
93 | |||
94 | smXv CSC{..} v = unsafePerformIO $ do | ||
95 | dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
96 | r <- createVector nRows | ||
97 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | ||
98 | return r | ||
99 | |||
100 | smXv Diag{..} v | ||
101 | | dim v == nCols | ||
102 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals | ||
103 | , konst 0 (nRows - dim diagVals) ] | ||
104 | | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" | ||
105 | nRows nCols (dim diagVals) (dim v) | ||
106 | |||
107 | |||
108 | instance Contraction SMatrix (Vector Double) (Vector Double) | ||
109 | where | ||
110 | contraction = smXv | ||
111 | |||
112 | -------------------------------------------------------------------------------- | ||
113 | |||
114 | foreign import ccall unsafe "smXv" | ||
115 | c_smXv :: SMxV | ||
116 | |||
117 | foreign import ccall unsafe "smTXv" | ||
118 | c_smTXv :: SMxV | ||
119 | |||
120 | -------------------------------------------------------------------------------- | ||
121 | |||
122 | toDense :: AssocMatrix -> Matrix Double | ||
123 | toDense asm = assoc (r+1,c+1) 0 asm | ||
124 | where | ||
125 | (r,c) = (maximum *** maximum) . unzip . map fst $ asm | ||
126 | |||
127 | |||
128 | |||
129 | instance Transposable SMatrix | ||
130 | where | ||
131 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | ||
132 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | ||
133 | tr (Diag v n m) = Diag v m n | ||
134 | |||
135 | instance Transposable (Matrix Double) | ||
136 | where | ||
137 | tr = trans | ||
138 | |||
139 | |||
140 | instance CGMat SMatrix | ||
141 | instance CGMat (Matrix Double) | ||
142 | |||
143 | -------------------------------------------------------------------------------- | ||
144 | |||
145 | instance Testable SMatrix | ||
146 | where | ||
147 | checkT _ = (ok,info) | ||
148 | where | ||
149 | sma = convo2 20 3 | ||
150 | x1 = vect [1..20] | ||
151 | x2 = vect [1..40] | ||
152 | sm = mkCSR sma | ||
153 | |||
154 | s1 = sm ◇ x1 | ||
155 | d1 = toDense sma ◇ x1 | ||
156 | |||
157 | s2 = tr sm ◇ x2 | ||
158 | d2 = tr (toDense sma) ◇ x2 | ||
159 | |||
160 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
161 | s3 = sdia ◇ x1 | ||
162 | s4 = tr sdia ◇ x2 | ||
163 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
164 | d3 = ddia ◇ x1 | ||
165 | d4 = tr ddia ◇ x2 | ||
166 | |||
167 | info = do | ||
168 | print sm | ||
169 | disp (toDense sma) | ||
170 | print s1; print d1 | ||
171 | print s2; print d2 | ||
172 | print s3; print d3 | ||
173 | print s4; print d4 | ||
174 | |||
175 | ok = s1==d1 | ||
176 | && s2==d2 | ||
177 | && s3==d3 | ||
178 | && s4==d4 | ||
179 | |||
180 | disp = putStr . dispf 2 | ||
181 | |||
182 | vect = fromList :: [Double] -> Vector Double | ||
183 | |||
184 | convomat :: Int -> Int -> AssocMatrix | ||
185 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
186 | |||
187 | convo2 :: Int -> Int -> AssocMatrix | ||
188 | convo2 n k = m1 ++ m2 | ||
189 | where | ||
190 | m1 = convomat n k | ||
191 | m2 = map (((+n) *** id) *** id) m1 | ||
192 | |||