summaryrefslogtreecommitdiff
path: root/lib/Numeric/Container.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/Container.hs')
-rw-r--r--lib/Numeric/Container.hs99
1 files changed, 71 insertions, 28 deletions
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index d1ce588..a71fdfe 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -37,7 +37,8 @@ module Numeric.Container (
37 Container(..), 37 Container(..),
38 -- * Matrix product 38 -- * Matrix product
39 Product(..), 39 Product(..),
40 Contraction(..), 40 Mul(..),
41 Contraction(..), mmul,
41 optimiseMult, 42 optimiseMult,
42 mXm,mXv,vXm,LSDiv(..), cdot, (·), dot, (<.>), 43 mXm,mXv,vXm,LSDiv(..), cdot, (·), dot, (<.>),
43 outer, kronecker, 44 outer, kronecker,
@@ -102,12 +103,19 @@ cdot u v = udot (conj u) v
102 103
103-------------------------------------------------------- 104--------------------------------------------------------
104 105
105class Contraction a b c | a b -> c, a c -> b, b c -> a 106class Contraction a b c | a b -> c, c -> a b
106 where 107 where
107 infixl 7 <> 108 infixr 7 ×
108 {- | Matrix-matrix product, matrix-vector product, and unconjugated dot product 109 {- | Matrix-matrix product, matrix-vector product, and unconjugated dot product
109 110
110>>> let a = (3><4) [1..] :: Matrix Double 111(unicode 0x00d7, multiplication sign)
112
113Examples:
114
115>>> let a = (3><4) [1..] :: Matrix Double
116>>> let v = fromList [1,0,2,-1] :: Vector Double
117>>> let u = fromList [1,2,3] :: Vector Double
118
111>>> a 119>>> a
112(3><4) 120(3><4)
113 [ 1.0, 2.0, 3.0, 4.0 121 [ 1.0, 2.0, 3.0, 4.0
@@ -116,7 +124,7 @@ class Contraction a b c | a b -> c, a c -> b, b c -> a
116 124
117matrix × matrix: 125matrix × matrix:
118 126
119>>> disp 2 (a <> trans a) 127>>> disp 2 (a × trans a)
1203x3 1283x3
121 30 70 110 129 30 70 110
122 70 174 278 130 70 174 278
@@ -124,52 +132,79 @@ matrix × matrix:
124 132
125matrix × vector: 133matrix × vector:
126 134
127>>> a <> fromList [1,0,2,-1::Double] 135>>> a × v
128fromList [3.0,11.0,19.0] 136fromList [3.0,11.0,19.0]
129 137
130vector × matrix:
131
132>>> fromList [1,2,3::Double] <> a
133fromList [38.0,44.0,50.0,56.0]
134
135unconjugated dot product: 138unconjugated dot product:
136 139
137>>> fromList [1,i] <> fromList[2*i+1,3] 140>>> fromList [1,i] × fromList[2*i+1,3]
1381.0 :+ 5.0 1411.0 :+ 5.0
139 142
140-} 143(×) is right associative, so we can write:
141 (<>) :: a -> b -> c
142 144
143instance Product t => Contraction (Vector t) (Vector t) t where 145>>> u × a × v
144 (<>) = udot 14682.0 :: Double
145 147
146instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where 148-}
147 (<>) = mXv 149 (×) :: a -> b -> c
148 150
149instance Product t => Contraction (Vector t) (Matrix t) (Vector t) where 151instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
150 (<>) = vXm 152 (×) = mXv
151 153
152instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where 154instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
155 (×) = mXm
156
157instance Contraction (Vector Double) (Vector Double) Double where
158 (×) = udot
159
160instance Contraction (Vector Float) (Vector Float) Float where
161 (×) = udot
162
163instance Contraction (Vector (Complex Double)) (Vector (Complex Double)) (Complex Double) where
164 (×) = udot
165
166instance Contraction (Vector (Complex Float)) (Vector (Complex Float)) (Complex Float) where
167 (×) = udot
168
169
170-- | alternative function for the matrix product (×)
171mmul :: Contraction a b c => a -> b -> c
172mmul = (×)
173
174--------------------------------------------------------------------------------
175
176class Mul a b c | a b -> c where
177 infixl 7 <>
178 -- | Matrix-matrix, matrix-vector, and vector-matrix products.
179 (<>) :: Product t => a t -> b t -> c t
180
181instance Mul Matrix Matrix Matrix where
153 (<>) = mXm 182 (<>) = mXm
154 183
155-------------------------------------------------------- 184instance Mul Matrix Vector Vector where
185 (<>) m v = flatten $ m <> asColumn v
156 186
157class LSDiv b c | b -> c, c->b where 187instance Mul Vector Matrix Vector where
188 (<>) v m = flatten $ asRow v <> m
189
190--------------------------------------------------------------------------------
191
192class LSDiv c where
158 infixl 7 <\> 193 infixl 7 <\>
159 -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD) 194 -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD)
160 (<\>) :: Field t => Matrix t -> b t -> c t 195 (<\>) :: Field t => Matrix t -> c t -> c t
161 196
162instance LSDiv Vector Vector where 197instance LSDiv Vector where
163 m <\> v = flatten (linearSolveSVD m (reshape 1 v)) 198 m <\> v = flatten (linearSolveSVD m (reshape 1 v))
164 199
165instance LSDiv Matrix Matrix where 200instance LSDiv Matrix where
166 (<\>) = linearSolveSVD 201 (<\>) = linearSolveSVD
167 202
168-------------------------------------------------------- 203--------------------------------------------------------
169 204
170{- | Dot product : @u · v = 'cdot' u v@ 205{- | Dot product : @u · v = 'cdot' u v@
171 206
172 (unicode 0x00b7, Alt-Gr .) 207 (unicode 0x00b7, middle dot, Alt-Gr .)
173 208
174>>> fromList [1,i] · fromList[2*i+1,3] 209>>> fromList [1,i] · fromList[2*i+1,3]
1751.0 :+ (-1.0) 2101.0 :+ (-1.0)
@@ -233,7 +268,15 @@ instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e
233 268
234-------------------------------------------------------------------------------- 269--------------------------------------------------------------------------------
235 270
236-- | Compute mean vector and covariance matrix of the rows of a matrix. 271{- | Compute mean vector and covariance matrix of the rows of a matrix.
272
273>>> meanCov $ gaussianSample 666 1000 (fromList[4,5]) (diagl[2,3])
274(fromList [4.010341078059521,5.0197204699640405],
275(2><2)
276 [ 1.9862461923890056, -1.0127225830525157e-2
277 , -1.0127225830525157e-2, 3.0373954915729318 ])
278
279-}
237meanCov :: Matrix Double -> (Vector Double, Matrix Double) 280meanCov :: Matrix Double -> (Vector Double, Matrix Double)
238meanCov x = (med,cov) where 281meanCov x = (med,cov) where
239 r = rows x 282 r = rows x
@@ -249,7 +292,7 @@ meanCov x = (med,cov) where
249dot :: Product e => Vector e -> Vector e -> e 292dot :: Product e => Vector e -> Vector e -> e
250dot = udot 293dot = udot
251 294
252{-# DEPRECATED (<.>) "use udot or (<>)" #-} 295{-# DEPRECATED (<.>) "use udot or (×)" #-}
253infixl 7 <.> 296infixl 7 <.>
254(<.>) :: Product e => Vector e -> Vector e -> e 297(<.>) :: Product e => Vector e -> Vector e -> e
255(<.>) = udot 298(<.>) = udot