diff options
Diffstat (limited to 'lib/Numeric/Container.hs')
-rw-r--r-- | lib/Numeric/Container.hs | 99 |
1 files changed, 71 insertions, 28 deletions
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index d1ce588..a71fdfe 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs | |||
@@ -37,7 +37,8 @@ module Numeric.Container ( | |||
37 | Container(..), | 37 | Container(..), |
38 | -- * Matrix product | 38 | -- * Matrix product |
39 | Product(..), | 39 | Product(..), |
40 | Contraction(..), | 40 | Mul(..), |
41 | Contraction(..), mmul, | ||
41 | optimiseMult, | 42 | optimiseMult, |
42 | mXm,mXv,vXm,LSDiv(..), cdot, (·), dot, (<.>), | 43 | mXm,mXv,vXm,LSDiv(..), cdot, (·), dot, (<.>), |
43 | outer, kronecker, | 44 | outer, kronecker, |
@@ -102,12 +103,19 @@ cdot u v = udot (conj u) v | |||
102 | 103 | ||
103 | -------------------------------------------------------- | 104 | -------------------------------------------------------- |
104 | 105 | ||
105 | class Contraction a b c | a b -> c, a c -> b, b c -> a | 106 | class Contraction a b c | a b -> c, c -> a b |
106 | where | 107 | where |
107 | infixl 7 <> | 108 | infixr 7 × |
108 | {- | Matrix-matrix product, matrix-vector product, and unconjugated dot product | 109 | {- | Matrix-matrix product, matrix-vector product, and unconjugated dot product |
109 | 110 | ||
110 | >>> let a = (3><4) [1..] :: Matrix Double | 111 | (unicode 0x00d7, multiplication sign) |
112 | |||
113 | Examples: | ||
114 | |||
115 | >>> let a = (3><4) [1..] :: Matrix Double | ||
116 | >>> let v = fromList [1,0,2,-1] :: Vector Double | ||
117 | >>> let u = fromList [1,2,3] :: Vector Double | ||
118 | |||
111 | >>> a | 119 | >>> a |
112 | (3><4) | 120 | (3><4) |
113 | [ 1.0, 2.0, 3.0, 4.0 | 121 | [ 1.0, 2.0, 3.0, 4.0 |
@@ -116,7 +124,7 @@ class Contraction a b c | a b -> c, a c -> b, b c -> a | |||
116 | 124 | ||
117 | matrix × matrix: | 125 | matrix × matrix: |
118 | 126 | ||
119 | >>> disp 2 (a <> trans a) | 127 | >>> disp 2 (a × trans a) |
120 | 3x3 | 128 | 3x3 |
121 | 30 70 110 | 129 | 30 70 110 |
122 | 70 174 278 | 130 | 70 174 278 |
@@ -124,52 +132,79 @@ matrix × matrix: | |||
124 | 132 | ||
125 | matrix × vector: | 133 | matrix × vector: |
126 | 134 | ||
127 | >>> a <> fromList [1,0,2,-1::Double] | 135 | >>> a × v |
128 | fromList [3.0,11.0,19.0] | 136 | fromList [3.0,11.0,19.0] |
129 | 137 | ||
130 | vector × matrix: | ||
131 | |||
132 | >>> fromList [1,2,3::Double] <> a | ||
133 | fromList [38.0,44.0,50.0,56.0] | ||
134 | |||
135 | unconjugated dot product: | 138 | unconjugated dot product: |
136 | 139 | ||
137 | >>> fromList [1,i] <> fromList[2*i+1,3] | 140 | >>> fromList [1,i] × fromList[2*i+1,3] |
138 | 1.0 :+ 5.0 | 141 | 1.0 :+ 5.0 |
139 | 142 | ||
140 | -} | 143 | (×) is right associative, so we can write: |
141 | (<>) :: a -> b -> c | ||
142 | 144 | ||
143 | instance Product t => Contraction (Vector t) (Vector t) t where | 145 | >>> u × a × v |
144 | (<>) = udot | 146 | 82.0 :: Double |
145 | 147 | ||
146 | instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where | 148 | -} |
147 | (<>) = mXv | 149 | (×) :: a -> b -> c |
148 | 150 | ||
149 | instance Product t => Contraction (Vector t) (Matrix t) (Vector t) where | 151 | instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where |
150 | (<>) = vXm | 152 | (×) = mXv |
151 | 153 | ||
152 | instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where | 154 | instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where |
155 | (×) = mXm | ||
156 | |||
157 | instance Contraction (Vector Double) (Vector Double) Double where | ||
158 | (×) = udot | ||
159 | |||
160 | instance Contraction (Vector Float) (Vector Float) Float where | ||
161 | (×) = udot | ||
162 | |||
163 | instance Contraction (Vector (Complex Double)) (Vector (Complex Double)) (Complex Double) where | ||
164 | (×) = udot | ||
165 | |||
166 | instance Contraction (Vector (Complex Float)) (Vector (Complex Float)) (Complex Float) where | ||
167 | (×) = udot | ||
168 | |||
169 | |||
170 | -- | alternative function for the matrix product (×) | ||
171 | mmul :: Contraction a b c => a -> b -> c | ||
172 | mmul = (×) | ||
173 | |||
174 | -------------------------------------------------------------------------------- | ||
175 | |||
176 | class Mul a b c | a b -> c where | ||
177 | infixl 7 <> | ||
178 | -- | Matrix-matrix, matrix-vector, and vector-matrix products. | ||
179 | (<>) :: Product t => a t -> b t -> c t | ||
180 | |||
181 | instance Mul Matrix Matrix Matrix where | ||
153 | (<>) = mXm | 182 | (<>) = mXm |
154 | 183 | ||
155 | -------------------------------------------------------- | 184 | instance Mul Matrix Vector Vector where |
185 | (<>) m v = flatten $ m <> asColumn v | ||
156 | 186 | ||
157 | class LSDiv b c | b -> c, c->b where | 187 | instance Mul Vector Matrix Vector where |
188 | (<>) v m = flatten $ asRow v <> m | ||
189 | |||
190 | -------------------------------------------------------------------------------- | ||
191 | |||
192 | class LSDiv c where | ||
158 | infixl 7 <\> | 193 | infixl 7 <\> |
159 | -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD) | 194 | -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD) |
160 | (<\>) :: Field t => Matrix t -> b t -> c t | 195 | (<\>) :: Field t => Matrix t -> c t -> c t |
161 | 196 | ||
162 | instance LSDiv Vector Vector where | 197 | instance LSDiv Vector where |
163 | m <\> v = flatten (linearSolveSVD m (reshape 1 v)) | 198 | m <\> v = flatten (linearSolveSVD m (reshape 1 v)) |
164 | 199 | ||
165 | instance LSDiv Matrix Matrix where | 200 | instance LSDiv Matrix where |
166 | (<\>) = linearSolveSVD | 201 | (<\>) = linearSolveSVD |
167 | 202 | ||
168 | -------------------------------------------------------- | 203 | -------------------------------------------------------- |
169 | 204 | ||
170 | {- | Dot product : @u · v = 'cdot' u v@ | 205 | {- | Dot product : @u · v = 'cdot' u v@ |
171 | 206 | ||
172 | (unicode 0x00b7, Alt-Gr .) | 207 | (unicode 0x00b7, middle dot, Alt-Gr .) |
173 | 208 | ||
174 | >>> fromList [1,i] · fromList[2*i+1,3] | 209 | >>> fromList [1,i] · fromList[2*i+1,3] |
175 | 1.0 :+ (-1.0) | 210 | 1.0 :+ (-1.0) |
@@ -233,7 +268,15 @@ instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e | |||
233 | 268 | ||
234 | -------------------------------------------------------------------------------- | 269 | -------------------------------------------------------------------------------- |
235 | 270 | ||
236 | -- | Compute mean vector and covariance matrix of the rows of a matrix. | 271 | {- | Compute mean vector and covariance matrix of the rows of a matrix. |
272 | |||
273 | >>> meanCov $ gaussianSample 666 1000 (fromList[4,5]) (diagl[2,3]) | ||
274 | (fromList [4.010341078059521,5.0197204699640405], | ||
275 | (2><2) | ||
276 | [ 1.9862461923890056, -1.0127225830525157e-2 | ||
277 | , -1.0127225830525157e-2, 3.0373954915729318 ]) | ||
278 | |||
279 | -} | ||
237 | meanCov :: Matrix Double -> (Vector Double, Matrix Double) | 280 | meanCov :: Matrix Double -> (Vector Double, Matrix Double) |
238 | meanCov x = (med,cov) where | 281 | meanCov x = (med,cov) where |
239 | r = rows x | 282 | r = rows x |
@@ -249,7 +292,7 @@ meanCov x = (med,cov) where | |||
249 | dot :: Product e => Vector e -> Vector e -> e | 292 | dot :: Product e => Vector e -> Vector e -> e |
250 | dot = udot | 293 | dot = udot |
251 | 294 | ||
252 | {-# DEPRECATED (<.>) "use udot or (<>)" #-} | 295 | {-# DEPRECATED (<.>) "use udot or (×)" #-} |
253 | infixl 7 <.> | 296 | infixl 7 <.> |
254 | (<.>) :: Product e => Vector e -> Vector e -> e | 297 | (<.>) :: Product e => Vector e -> Vector e -> e |
255 | (<.>) = udot | 298 | (<.>) = udot |