diff options
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 23 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Sparse.hs | 191 |
2 files changed, 214 insertions, 0 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 81a8083..3528e96 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -20,6 +20,7 @@ module Data.Packed.Internal.Numeric ( | |||
20 | ident, diag, ctrans, | 20 | ident, diag, ctrans, |
21 | -- * Generic operations | 21 | -- * Generic operations |
22 | Container(..), | 22 | Container(..), |
23 | Transposable(..), Linear(..), Testable(..), | ||
23 | -- * Matrix product and related functions | 24 | -- * Matrix product and related functions |
24 | Product(..), udot, | 25 | Product(..), udot, |
25 | mXm,mXv,vXm, | 26 | mXm,mXv,vXm, |
@@ -605,3 +606,25 @@ condV f a b l e t = f a' b' l' e' t' | |||
605 | where | 606 | where |
606 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] | 607 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] |
607 | 608 | ||
609 | -------------------------------------------------------------------------------- | ||
610 | |||
611 | class Transposable t | ||
612 | where | ||
613 | tr :: t -> t | ||
614 | |||
615 | |||
616 | class Linear t v | ||
617 | where | ||
618 | scalarL :: t -> v | ||
619 | addL :: v -> v -> v | ||
620 | scaleL :: t -> v -> v | ||
621 | |||
622 | |||
623 | class Testable t | ||
624 | where | ||
625 | checkT :: t -> (Bool, IO()) | ||
626 | ioCheckT :: t -> IO (Bool, IO()) | ||
627 | ioCheckT = return . checkT | ||
628 | |||
629 | -------------------------------------------------------------------------------- | ||
630 | |||
diff --git a/packages/base/src/Data/Packed/Internal/Sparse.hs b/packages/base/src/Data/Packed/Internal/Sparse.hs new file mode 100644 index 0000000..544c913 --- /dev/null +++ b/packages/base/src/Data/Packed/Internal/Sparse.hs | |||
@@ -0,0 +1,191 @@ | |||
1 | {-# LANGUAGE RecordWildCards #-} | ||
2 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | |||
5 | |||
6 | |||
7 | module Data.Packed.Internal.Sparse( | ||
8 | SMatrix(..), | ||
9 | mkCSR, mkDiag, | ||
10 | AssocMatrix, | ||
11 | toDense, | ||
12 | smXv | ||
13 | )where | ||
14 | |||
15 | import Numeric.Container | ||
16 | import qualified Data.Vector.Storable as V | ||
17 | import Data.Function(on) | ||
18 | import Control.Arrow((***)) | ||
19 | import Control.Monad(when) | ||
20 | import Data.List(groupBy, sort) | ||
21 | import Foreign.C.Types(CInt(..)) | ||
22 | import Numeric.LinearAlgebra.Devel | ||
23 | import System.IO.Unsafe(unsafePerformIO) | ||
24 | import Foreign(Ptr) | ||
25 | import Text.Printf(printf) | ||
26 | |||
27 | infixl 0 ~!~ | ||
28 | c ~!~ msg = when c (error msg) | ||
29 | |||
30 | type AssocMatrix = [((Int,Int),Double)] | ||
31 | |||
32 | data SMatrix | ||
33 | = CSR | ||
34 | { csrVals :: Vector Double | ||
35 | , csrCols :: Vector CInt | ||
36 | , csrRows :: Vector CInt | ||
37 | , nRows :: Int | ||
38 | , nCols :: Int | ||
39 | } | ||
40 | | CSC | ||
41 | { cscVals :: Vector Double | ||
42 | , cscRows :: Vector CInt | ||
43 | , cscCols :: Vector CInt | ||
44 | , nRows :: Int | ||
45 | , nCols :: Int | ||
46 | } | ||
47 | | Diag | ||
48 | { diagVals :: Vector Double | ||
49 | , nRows :: Int | ||
50 | , nCols :: Int | ||
51 | } | ||
52 | -- | Banded | ||
53 | deriving Show | ||
54 | |||
55 | mkCSR :: AssocMatrix -> SMatrix | ||
56 | mkCSR sm' = CSR{..} | ||
57 | where | ||
58 | sm = sort sm' | ||
59 | rws = map ((fromList *** fromList) | ||
60 | . unzip | ||
61 | . map ((succ.fi.snd) *** id) | ||
62 | ) | ||
63 | . groupBy ((==) `on` (fst.fst)) | ||
64 | $ sm | ||
65 | rszs = map (fi . dim . fst) rws | ||
66 | csrRows = fromList (scanl (+) 1 rszs) | ||
67 | csrVals = vjoin (map snd rws) | ||
68 | csrCols = vjoin (map fst rws) | ||
69 | nRows = dim csrRows - 1 | ||
70 | nCols = fromIntegral (V.maximum csrCols) | ||
71 | |||
72 | |||
73 | mkDiagR r c v | ||
74 | | dim v <= min r c = Diag{..} | ||
75 | | otherwise = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v) | ||
76 | where | ||
77 | nRows = r | ||
78 | nCols = c | ||
79 | diagVals = v | ||
80 | |||
81 | mkDiag v = mkDiagR (dim v) (dim v) v | ||
82 | |||
83 | |||
84 | type IV t = CInt -> Ptr CInt -> t | ||
85 | type V t = CInt -> Ptr Double -> t | ||
86 | type SMxV = V (IV (IV (V (V (IO CInt))))) | ||
87 | |||
88 | smXv :: SMatrix -> Vector Double -> Vector Double | ||
89 | smXv CSR{..} v = unsafePerformIO $ do | ||
90 | dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
91 | r <- createVector nRows | ||
92 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | ||
93 | return r | ||
94 | |||
95 | smXv CSC{..} v = unsafePerformIO $ do | ||
96 | dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | ||
97 | r <- createVector nRows | ||
98 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | ||
99 | return r | ||
100 | |||
101 | smXv Diag{..} v | ||
102 | | dim v == nCols | ||
103 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals | ||
104 | , konst 0 (nRows - dim diagVals) ] | ||
105 | | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" | ||
106 | nRows nCols (dim diagVals) (dim v) | ||
107 | |||
108 | |||
109 | instance Contraction SMatrix (Vector Double) (Vector Double) | ||
110 | where | ||
111 | contraction = smXv | ||
112 | |||
113 | -------------------------------------------------------------------------------- | ||
114 | |||
115 | foreign import ccall unsafe "smXv" | ||
116 | c_smXv :: SMxV | ||
117 | |||
118 | foreign import ccall unsafe "smTXv" | ||
119 | c_smTXv :: SMxV | ||
120 | |||
121 | -------------------------------------------------------------------------------- | ||
122 | |||
123 | toDense :: AssocMatrix -> Matrix Double | ||
124 | toDense asm = assoc (r+1,c+1) 0 asm | ||
125 | where | ||
126 | (r,c) = (maximum *** maximum) . unzip . map fst $ asm | ||
127 | |||
128 | |||
129 | |||
130 | instance Transposable (SMatrix) | ||
131 | where | ||
132 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | ||
133 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | ||
134 | tr (Diag v n m) = Diag v m n | ||
135 | |||
136 | instance Transposable (Matrix Double) | ||
137 | where | ||
138 | tr = trans | ||
139 | |||
140 | |||
141 | |||
142 | -------------------------------------------------------------------------------- | ||
143 | |||
144 | instance Testable SMatrix | ||
145 | where | ||
146 | checkT _ = (ok,info) | ||
147 | where | ||
148 | sma = convo2 20 3 | ||
149 | x1 = vect [1..20] | ||
150 | x2 = vect [1..40] | ||
151 | sm = mkCSR sma | ||
152 | |||
153 | s1 = sm ◇ x1 | ||
154 | d1 = toDense sma ◇ x1 | ||
155 | |||
156 | s2 = tr sm ◇ x2 | ||
157 | d2 = tr (toDense sma) ◇ x2 | ||
158 | |||
159 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
160 | s3 = sdia ◇ x1 | ||
161 | s4 = tr sdia ◇ x2 | ||
162 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
163 | d3 = ddia ◇ x1 | ||
164 | d4 = tr ddia ◇ x2 | ||
165 | |||
166 | info = do | ||
167 | print sm | ||
168 | disp (toDense sma) | ||
169 | print s1; print d1 | ||
170 | print s2; print d2 | ||
171 | print s3; print d3 | ||
172 | print s4; print d4 | ||
173 | |||
174 | ok = s1==d1 | ||
175 | && s2==d2 | ||
176 | && s3==d3 | ||
177 | && s4==d4 | ||
178 | |||
179 | disp = putStr . dispf 2 | ||
180 | |||
181 | vect = fromList :: [Double] -> Vector Double | ||
182 | |||
183 | convomat :: Int -> Int -> AssocMatrix | ||
184 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
185 | |||
186 | convo2 :: Int -> Int -> AssocMatrix | ||
187 | convo2 n k = m1 ++ m2 | ||
188 | where | ||
189 | m1 = convomat n k | ||
190 | m2 = map (((+n) *** id) *** id) m1 | ||
191 | |||