diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util.hs | 149 |
1 files changed, 127 insertions, 22 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs index 2f91e18..a7d6946 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs | |||
@@ -1,4 +1,9 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | 1 | {-# LANGUAGE FlexibleContexts #-} |
2 | {-# LANGUAGE FlexibleInstances #-} | ||
3 | {-# LANGUAGE TypeFamilies #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | |||
2 | ----------------------------------------------------------------------------- | 7 | ----------------------------------------------------------------------------- |
3 | {- | | 8 | {- | |
4 | Module : Numeric.LinearAlgebra.Util | 9 | Module : Numeric.LinearAlgebra.Util |
@@ -14,19 +19,24 @@ Stability : provisional | |||
14 | module Numeric.LinearAlgebra.Util( | 19 | module Numeric.LinearAlgebra.Util( |
15 | 20 | ||
16 | -- * Convenience functions | 21 | -- * Convenience functions |
17 | size, disp, | 22 | vect, mat, |
23 | disp, | ||
18 | zeros, ones, | 24 | zeros, ones, |
19 | diagl, | 25 | diagl, |
20 | row, | 26 | row, |
21 | col, | 27 | col, |
22 | (&), (¦), (——), (#), | 28 | (&), (¦), (——), (#), |
23 | (?), (¿), | 29 | (?), (¿), |
30 | Indexable(..), size, | ||
31 | rand, randn, | ||
24 | cross, | 32 | cross, |
25 | norm, | 33 | norm, |
34 | ℕ,ℤ,ℝ,ℂ,ℝn,ℂn,𝑖,i_C, --ℍ | ||
35 | norm_1, norm_2, norm_0, norm_Inf, norm_Frob, norm_nuclear, | ||
36 | mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, | ||
26 | unitary, | 37 | unitary, |
27 | mt, | 38 | mt, |
28 | pairwiseD2, | 39 | pairwiseD2, |
29 | meanCov, | ||
30 | rowOuters, | 40 | rowOuters, |
31 | null1, | 41 | null1, |
32 | null1sym, | 42 | null1sym, |
@@ -48,13 +58,49 @@ module Numeric.LinearAlgebra.Util( | |||
48 | vtrans | 58 | vtrans |
49 | ) where | 59 | ) where |
50 | 60 | ||
51 | import Numeric.Container | 61 | import Data.Packed.Numeric |
52 | import Numeric.LinearAlgebra.Algorithms hiding (i) | 62 | import Numeric.LinearAlgebra.Algorithms hiding (i) |
53 | import Numeric.Matrix() | 63 | import Numeric.Matrix() |
54 | import Numeric.Vector() | 64 | import Numeric.Vector() |
55 | 65 | import Numeric.LinearAlgebra.Random | |
56 | import Numeric.LinearAlgebra.Util.Convolution | 66 | import Numeric.LinearAlgebra.Util.Convolution |
57 | 67 | ||
68 | type ℝ = Double | ||
69 | type ℕ = Int | ||
70 | type ℤ = Int | ||
71 | type ℂ = Complex Double | ||
72 | type ℝn = Vector ℝ | ||
73 | type ℂn = Vector ℂ | ||
74 | --newtype ℍ m = H m | ||
75 | |||
76 | i_C, 𝑖 :: ℂ | ||
77 | 𝑖 = 0:+1 | ||
78 | i_C = 𝑖 | ||
79 | |||
80 | {- | create a real vector | ||
81 | |||
82 | >>> vect [1..5] | ||
83 | fromList [1.0,2.0,3.0,4.0,5.0] | ||
84 | |||
85 | -} | ||
86 | vect :: [ℝ] -> ℝn | ||
87 | vect = fromList | ||
88 | |||
89 | {- | create a real matrix | ||
90 | |||
91 | >>> mat 5 [1..15] | ||
92 | (3><5) | ||
93 | [ 1.0, 2.0, 3.0, 4.0, 5.0 | ||
94 | , 6.0, 7.0, 8.0, 9.0, 10.0 | ||
95 | , 11.0, 12.0, 13.0, 14.0, 15.0 ] | ||
96 | |||
97 | -} | ||
98 | mat | ||
99 | :: Int -- ^ columns | ||
100 | -> [ℝ] -- ^ elements | ||
101 | -> Matrix ℝ | ||
102 | mat c = reshape c . fromList | ||
103 | |||
58 | {- | print a real matrix with given number of digits after the decimal point | 104 | {- | print a real matrix with given number of digits after the decimal point |
59 | 105 | ||
60 | >>> disp 5 $ ident 2 / 3 | 106 | >>> disp 5 $ ident 2 / 3 |
@@ -175,38 +221,97 @@ norm :: Vector Double -> Double | |||
175 | -- ^ 2-norm of real vector | 221 | -- ^ 2-norm of real vector |
176 | norm = pnorm PNorm2 | 222 | norm = pnorm PNorm2 |
177 | 223 | ||
224 | norm_2 :: Normed Vector t => Vector t -> RealOf t | ||
225 | norm_2 = pnorm PNorm2 | ||
226 | |||
227 | norm_1 :: Normed Vector t => Vector t -> RealOf t | ||
228 | norm_1 = pnorm PNorm1 | ||
229 | |||
230 | norm_Inf :: Normed Vector t => Vector t -> RealOf t | ||
231 | norm_Inf = pnorm Infinity | ||
232 | |||
233 | norm_0 :: Vector ℝ -> ℝ | ||
234 | norm_0 v = sumElements (step (abs v - scalar (eps*mx))) | ||
235 | where | ||
236 | mx = norm_Inf v | ||
237 | |||
238 | norm_Frob :: Normed Matrix t => Matrix t -> RealOf t | ||
239 | norm_Frob = pnorm Frobenius | ||
240 | |||
241 | norm_nuclear :: Field t => Matrix t -> ℝ | ||
242 | norm_nuclear = sumElements . singularValues | ||
243 | |||
244 | mnorm_2 :: Normed Matrix t => Matrix t -> RealOf t | ||
245 | mnorm_2 = pnorm PNorm2 | ||
246 | |||
247 | mnorm_1 :: Normed Matrix t => Matrix t -> RealOf t | ||
248 | mnorm_1 = pnorm PNorm1 | ||
249 | |||
250 | mnorm_Inf :: Normed Matrix t => Matrix t -> RealOf t | ||
251 | mnorm_Inf = pnorm Infinity | ||
252 | |||
253 | mnorm_0 :: Matrix ℝ -> ℝ | ||
254 | mnorm_0 = norm_0 . flatten | ||
178 | 255 | ||
179 | -- | Obtains a vector in the same direction with 2-norm=1 | 256 | -- | Obtains a vector in the same direction with 2-norm=1 |
180 | unitary :: Vector Double -> Vector Double | 257 | unitary :: Vector Double -> Vector Double |
181 | unitary v = v / scalar (norm v) | 258 | unitary v = v / scalar (norm v) |
182 | 259 | ||
183 | -- | ('rows' &&& 'cols') | ||
184 | size :: Matrix t -> (Int, Int) | ||
185 | size m = (rows m, cols m) | ||
186 | 260 | ||
187 | -- | trans . inv | 261 | -- | trans . inv |
188 | mt :: Matrix Double -> Matrix Double | 262 | mt :: Matrix Double -> Matrix Double |
189 | mt = trans . inv | 263 | mt = trans . inv |
190 | 264 | ||
191 | -------------------------------------------------------------------------------- | 265 | -------------------------------------------------------------------------------- |
266 | {- | | ||
267 | |||
268 | >>> size $ fromList[1..10::Double] | ||
269 | 10 | ||
270 | >>> size $ (2><5)[1..10::Double] | ||
271 | (2,5) | ||
272 | |||
273 | -} | ||
274 | size :: Container c t => c t -> IndexOf c | ||
275 | size = size' | ||
192 | 276 | ||
193 | {- | Compute mean vector and covariance matrix of the rows of a matrix. | 277 | {- | |
278 | |||
279 | >>> vect [1..10] ! 3 | ||
280 | 4.0 | ||
281 | |||
282 | >>> mat 5 [1..15] ! 1 | ||
283 | fromList [6.0,7.0,8.0,9.0,10.0] | ||
194 | 284 | ||
195 | >>> meanCov $ gaussianSample 666 1000 (fromList[4,5]) (diagl[2,3]) | 285 | >>> mat 5 [1..15] ! 1 ! 3 |
196 | (fromList [4.010341078059521,5.0197204699640405], | 286 | 9.0 |
197 | (2><2) | ||
198 | [ 1.9862461923890056, -1.0127225830525157e-2 | ||
199 | , -1.0127225830525157e-2, 3.0373954915729318 ]) | ||
200 | 287 | ||
201 | -} | 288 | -} |
202 | meanCov :: Matrix Double -> (Vector Double, Matrix Double) | 289 | class Indexable c t | c -> t , t -> c |
203 | meanCov x = (med,cov) where | 290 | where |
204 | r = rows x | 291 | infixl 9 ! |
205 | k = 1 / fromIntegral r | 292 | (!) :: c -> Int -> t |
206 | med = konst k r `vXm` x | 293 | |
207 | meds = konst 1 r `outer` med | 294 | instance Indexable (Vector Double) Double |
208 | xc = x `sub` meds | 295 | where |
209 | cov = scale (recip (fromIntegral (r-1))) (trans xc `mXm` xc) | 296 | (!) = (@>) |
297 | |||
298 | instance Indexable (Vector Float) Float | ||
299 | where | ||
300 | (!) = (@>) | ||
301 | |||
302 | instance Indexable (Vector (Complex Double)) (Complex Double) | ||
303 | where | ||
304 | (!) = (@>) | ||
305 | |||
306 | instance Indexable (Vector (Complex Float)) (Complex Float) | ||
307 | where | ||
308 | (!) = (@>) | ||
309 | |||
310 | instance Element t => Indexable (Matrix t) (Vector t) | ||
311 | where | ||
312 | m!j = subVector (j*c) c (flatten m) | ||
313 | where | ||
314 | c = cols m | ||
210 | 315 | ||
211 | -------------------------------------------------------------------------------- | 316 | -------------------------------------------------------------------------------- |
212 | 317 | ||
@@ -220,7 +325,7 @@ pairwiseD2 x y | ok = x2 `outer` oy + ox `outer` y2 - 2* x <> trans y | |||
220 | ox = one (rows x) | 325 | ox = one (rows x) |
221 | oy = one (rows y) | 326 | oy = one (rows y) |
222 | oc = one (cols x) | 327 | oc = one (cols x) |
223 | one k = constant 1 k | 328 | one k = konst 1 k |
224 | x2 = x * x <> oc | 329 | x2 = x * x <> oc |
225 | y2 = y * y <> oc | 330 | y2 = y * y <> oc |
226 | ok = cols x == cols y | 331 | ok = cols x == cols y |