summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util.hs560
1 files changed, 0 insertions, 560 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs
deleted file mode 100644
index 37cdc0f..0000000
--- a/packages/base/src/Numeric/LinearAlgebra/Util.hs
+++ /dev/null
@@ -1,560 +0,0 @@
1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE FlexibleInstances #-}
3{-# LANGUAGE TypeFamilies #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE ViewPatterns #-}
7
8
9-----------------------------------------------------------------------------
10{- |
11Module : Numeric.LinearAlgebra.Util
12Copyright : (c) Alberto Ruiz 2013
13License : BSD3
14Maintainer : Alberto Ruiz
15Stability : provisional
16
17-}
18-----------------------------------------------------------------------------
19
20module Numeric.LinearAlgebra.Util(
21
22 -- * Convenience functions
23 vector, matrix,
24 disp,
25 formatSparse,
26 approxInt,
27 dispDots,
28 dispBlanks,
29 formatShort,
30 dispShort,
31 zeros, ones,
32 diagl,
33 row,
34 col,
35 (&), (¦), (|||), (——), (===), (#),
36 (?), (¿),
37 Indexable(..), size,
38 Numeric,
39 rand, randn,
40 cross,
41 norm,
42 ℕ,ℤ,ℝ,ℂ,iC,
43 Normed(..), norm_Frob, norm_nuclear,
44 unitary,
45 mt,
46 (~!~),
47 pairwiseD2,
48 rowOuters,
49 null1,
50 null1sym,
51 -- * Convolution
52 -- ** 1D
53 corr, conv, corrMin,
54 -- ** 2D
55 corr2, conv2, separable,
56 gaussElim
57) where
58
59import Data.Packed.Numeric
60import Numeric.LinearAlgebra.Algorithms hiding (i,Normed)
61--import qualified Numeric.LinearAlgebra.Algorithms as A
62import Numeric.Matrix()
63import Numeric.Vector()
64import Numeric.LinearAlgebra.Random
65import Numeric.LinearAlgebra.Util.Convolution
66import Control.Monad(when)
67import Text.Printf
68import Data.List.Split(splitOn)
69import Data.List(intercalate,sortBy)
70import Data.Function(on)
71import Control.Arrow((&&&))
72
73type ℝ = Double
74type ℕ = Int
75type ℤ = Int
76type ℂ = Complex Double
77
78-- | imaginary unit
79iC :: ℂ
80iC = 0:+1
81
82{- | Create a real vector.
83
84>>> vector [1..5]
85fromList [1.0,2.0,3.0,4.0,5.0]
86
87-}
88vector :: [ℝ] -> Vector ℝ
89vector = fromList
90
91{- | Create a real matrix.
92
93>>> matrix 5 [1..15]
94(3><5)
95 [ 1.0, 2.0, 3.0, 4.0, 5.0
96 , 6.0, 7.0, 8.0, 9.0, 10.0
97 , 11.0, 12.0, 13.0, 14.0, 15.0 ]
98
99-}
100matrix
101 :: Int -- ^ number of columns
102 -> [ℝ] -- ^ elements in row order
103 -> Matrix ℝ
104matrix c = reshape c . fromList
105
106
107{- | print a real matrix with given number of digits after the decimal point
108
109>>> disp 5 $ ident 2 / 3
1102x2
1110.33333 0.00000
1120.00000 0.33333
113
114-}
115disp :: Int -> Matrix Double -> IO ()
116
117disp n = putStr . dispf n
118
119
120{- | create a real diagonal matrix from a list
121
122>>> diagl [1,2,3]
123(3><3)
124 [ 1.0, 0.0, 0.0
125 , 0.0, 2.0, 0.0
126 , 0.0, 0.0, 3.0 ]
127
128-}
129diagl :: [Double] -> Matrix Double
130diagl = diag . fromList
131
132-- | a real matrix of zeros
133zeros :: Int -- ^ rows
134 -> Int -- ^ columns
135 -> Matrix Double
136zeros r c = konst 0 (r,c)
137
138-- | a real matrix of ones
139ones :: Int -- ^ rows
140 -> Int -- ^ columns
141 -> Matrix Double
142ones r c = konst 1 (r,c)
143
144-- | concatenation of real vectors
145infixl 3 &
146(&) :: Vector Double -> Vector Double -> Vector Double
147a & b = vjoin [a,b]
148
149{- | horizontal concatenation of real matrices
150
151>>> ident 3 ||| konst 7 (3,4)
152(3><7)
153 [ 1.0, 0.0, 0.0, 7.0, 7.0, 7.0, 7.0
154 , 0.0, 1.0, 0.0, 7.0, 7.0, 7.0, 7.0
155 , 0.0, 0.0, 1.0, 7.0, 7.0, 7.0, 7.0 ]
156
157-}
158infixl 3 |||
159(|||) :: Matrix Double -> Matrix Double -> Matrix Double
160a ||| b = fromBlocks [[a,b]]
161
162-- | a synonym for ('|||') (unicode 0x00a6, broken bar)
163infixl 3 ¦
164(¦) :: Matrix Double -> Matrix Double -> Matrix Double
165(¦) = (|||)
166
167
168-- | vertical concatenation of real matrices
169--
170(===) :: Matrix Double -> Matrix Double -> Matrix Double
171infixl 2 ===
172a === b = fromBlocks [[a],[b]]
173
174-- | a synonym for ('===') (unicode 0x2014, em dash)
175(——) :: Matrix Double -> Matrix Double -> Matrix Double
176infixl 2 ——
177(——) = (===)
178
179
180(#) :: Matrix Double -> Matrix Double -> Matrix Double
181infixl 2 #
182a # b = fromBlocks [[a],[b]]
183
184-- | create a single row real matrix from a list
185--
186-- >>> row [2,3,1,8]
187-- (1><4)
188-- [ 2.0, 3.0, 1.0, 8.0 ]
189--
190row :: [Double] -> Matrix Double
191row = asRow . fromList
192
193-- | create a single column real matrix from a list
194--
195-- >>> col [7,-2,4]
196-- (3><1)
197-- [ 7.0
198-- , -2.0
199-- , 4.0 ]
200--
201col :: [Double] -> Matrix Double
202col = asColumn . fromList
203
204{- | extract rows
205
206>>> (20><4) [1..] ? [2,1,1]
207(3><4)
208 [ 9.0, 10.0, 11.0, 12.0
209 , 5.0, 6.0, 7.0, 8.0
210 , 5.0, 6.0, 7.0, 8.0 ]
211
212-}
213infixl 9 ?
214(?) :: Element t => Matrix t -> [Int] -> Matrix t
215(?) = flip extractRows
216
217{- | extract columns
218
219(unicode 0x00bf, inverted question mark, Alt-Gr ?)
220
221>>> (3><4) [1..] ¿ [3,0]
222(3><2)
223 [ 4.0, 1.0
224 , 8.0, 5.0
225 , 12.0, 9.0 ]
226
227-}
228infixl 9 ¿
229(¿) :: Element t => Matrix t -> [Int] -> Matrix t
230(¿)= flip extractColumns
231
232
233cross :: Product t => Vector t -> Vector t -> Vector t
234-- ^ cross product (for three-element vectors)
235cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3]
236 | otherwise = error $ "the cross product requires 3-element vectors (sizes given: "
237 ++show (dim x)++" and "++show (dim y)++")"
238 where
239 [x1,x2,x3] = toList x
240 [y1,y2,y3] = toList y
241 z1 = x2*y3-x3*y2
242 z2 = x3*y1-x1*y3
243 z3 = x1*y2-x2*y1
244
245{-# SPECIALIZE cross :: Vector Double -> Vector Double -> Vector Double #-}
246{-# SPECIALIZE cross :: Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) #-}
247
248norm :: Vector Double -> Double
249-- ^ 2-norm of real vector
250norm = pnorm PNorm2
251
252class Normed a
253 where
254 norm_0 :: a -> ℝ
255 norm_1 :: a -> ℝ
256 norm_2 :: a -> ℝ
257 norm_Inf :: a -> ℝ
258
259
260instance Normed (Vector ℝ)
261 where
262 norm_0 v = sumElements (step (abs v - scalar (eps*normInf v)))
263 norm_1 = pnorm PNorm1
264 norm_2 = pnorm PNorm2
265 norm_Inf = pnorm Infinity
266
267instance Normed (Vector ℂ)
268 where
269 norm_0 v = sumElements (step (fst (fromComplex (abs v)) - scalar (eps*normInf v)))
270 norm_1 = pnorm PNorm1
271 norm_2 = pnorm PNorm2
272 norm_Inf = pnorm Infinity
273
274instance Normed (Matrix ℝ)
275 where
276 norm_0 = norm_0 . flatten
277 norm_1 = pnorm PNorm1
278 norm_2 = pnorm PNorm2
279 norm_Inf = pnorm Infinity
280
281instance Normed (Matrix ℂ)
282 where
283 norm_0 = norm_0 . flatten
284 norm_1 = pnorm PNorm1
285 norm_2 = pnorm PNorm2
286 norm_Inf = pnorm Infinity
287
288instance Normed (Vector I)
289 where
290 norm_0 = fromIntegral . sumElements . step . abs
291 norm_1 = fromIntegral . norm1
292 norm_2 v = sqrt . fromIntegral $ dot v v
293 norm_Inf = fromIntegral . normInf
294
295
296
297norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ
298norm_Frob = norm_2 . flatten
299
300norm_nuclear :: Field t => Matrix t -> ℝ
301norm_nuclear = sumElements . singularValues
302
303
304-- | Obtains a vector in the same direction with 2-norm=1
305unitary :: Vector Double -> Vector Double
306unitary v = v / scalar (norm v)
307
308
309-- | trans . inv
310mt :: Matrix Double -> Matrix Double
311mt = trans . inv
312
313--------------------------------------------------------------------------------
314{- |
315
316>>> size $ vector [1..10]
31710
318>>> size $ (2><5)[1..10::Double]
319(2,5)
320
321-}
322size :: Container c t => c t -> IndexOf c
323size = size'
324
325{- | Alternative indexing function.
326
327>>> vector [1..10] ! 3
3284.0
329
330On a matrix it gets the k-th row as a vector:
331
332>>> matrix 5 [1..15] ! 1
333fromList [6.0,7.0,8.0,9.0,10.0]
334
335>>> matrix 5 [1..15] ! 1 ! 3
3369.0
337
338-}
339class Indexable c t | c -> t , t -> c
340 where
341 infixl 9 !
342 (!) :: c -> Int -> t
343
344instance Indexable (Vector Double) Double
345 where
346 (!) = (@>)
347
348instance Indexable (Vector Float) Float
349 where
350 (!) = (@>)
351
352instance Indexable (Vector I) I
353 where
354 (!) = (@>)
355
356instance Indexable (Vector (Complex Double)) (Complex Double)
357 where
358 (!) = (@>)
359
360instance Indexable (Vector (Complex Float)) (Complex Float)
361 where
362 (!) = (@>)
363
364instance Element t => Indexable (Matrix t) (Vector t)
365 where
366 m!j = subVector (j*c) c (flatten m)
367 where
368 c = cols m
369
370--------------------------------------------------------------------------------
371
372-- | Matrix of pairwise squared distances of row vectors
373-- (using the matrix product trick in blog.smola.org)
374pairwiseD2 :: Matrix Double -> Matrix Double -> Matrix Double
375pairwiseD2 x y | ok = x2 `outer` oy + ox `outer` y2 - 2* x <> trans y
376 | otherwise = error $ "pairwiseD2 with different number of columns: "
377 ++ show (size x) ++ ", " ++ show (size y)
378 where
379 ox = one (rows x)
380 oy = one (rows y)
381 oc = one (cols x)
382 one k = konst 1 k
383 x2 = x * x <> oc
384 y2 = y * y <> oc
385 ok = cols x == cols y
386
387--------------------------------------------------------------------------------
388
389{- | outer products of rows
390
391>>> a
392(3><2)
393 [ 1.0, 2.0
394 , 10.0, 20.0
395 , 100.0, 200.0 ]
396>>> b
397(3><3)
398 [ 1.0, 2.0, 3.0
399 , 4.0, 5.0, 6.0
400 , 7.0, 8.0, 9.0 ]
401
402>>> rowOuters a (b ||| 1)
403(3><8)
404 [ 1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 6.0, 2.0
405 , 40.0, 50.0, 60.0, 10.0, 80.0, 100.0, 120.0, 20.0
406 , 700.0, 800.0, 900.0, 100.0, 1400.0, 1600.0, 1800.0, 200.0 ]
407
408-}
409rowOuters :: Matrix Double -> Matrix Double -> Matrix Double
410rowOuters a b = a' * b'
411 where
412 a' = kronecker a (ones 1 (cols b))
413 b' = kronecker (ones 1 (cols a)) b
414
415--------------------------------------------------------------------------------
416
417-- | solution of overconstrained homogeneous linear system
418null1 :: Matrix Double -> Vector Double
419null1 = last . toColumns . snd . rightSV
420
421-- | solution of overconstrained homogeneous symmetric linear system
422null1sym :: Matrix Double -> Vector Double
423null1sym = last . toColumns . snd . eigSH'
424
425--------------------------------------------------------------------------------
426
427infixl 0 ~!~
428c ~!~ msg = when c (error msg)
429
430--------------------------------------------------------------------------------
431
432formatSparse :: String -> String -> String -> Int -> Matrix Double -> String
433
434formatSparse zeroI _zeroF sep _ (approxInt -> Just m) = format sep f m
435 where
436 f 0 = zeroI
437 f x = printf "%.0f" x
438
439formatSparse zeroI zeroF sep n m = format sep f m
440 where
441 f x | abs (x::Double) < 2*peps = zeroI++zeroF
442 | abs (fromIntegral (round x::Int) - x) / abs x < 2*peps
443 = printf ("%.0f."++replicate n ' ') x
444 | otherwise = printf ("%."++show n++"f") x
445
446approxInt m
447 | norm_Inf (v - vi) < 2*peps * norm_Inf v = Just (reshape (cols m) vi)
448 | otherwise = Nothing
449 where
450 v = flatten m
451 vi = roundVector v
452
453dispDots n = putStr . formatSparse "." (replicate n ' ') " " n
454
455dispBlanks n = putStr . formatSparse "" "" " " n
456
457formatShort sep fmt maxr maxc m = auxm4
458 where
459 (rm,cm) = size m
460 (r1,r2,r3)
461 | rm <= maxr = (rm,0,0)
462 | otherwise = (maxr-3,rm-maxr+1,2)
463 (c1,c2,c3)
464 | cm <= maxc = (cm,0,0)
465 | otherwise = (maxc-3,cm-maxc+1,2)
466 [ [a,_,b]
467 ,[_,_,_]
468 ,[c,_,d]] = toBlocks [r1,r2,r3]
469 [c1,c2,c3] m
470 auxm = fromBlocks [[a,b],[c,d]]
471 auxm2
472 | cm > maxc = format "|" fmt auxm
473 | otherwise = format sep fmt auxm
474 auxm3
475 | cm > maxc = map (f . splitOn "|") (lines auxm2)
476 | otherwise = (lines auxm2)
477 f items = intercalate sep (take (maxc-3) items) ++ " .. " ++
478 intercalate sep (drop (maxc-3) items)
479 auxm4
480 | rm > maxr = unlines (take (maxr-3) auxm3 ++ vsep : drop (maxr-3) auxm3)
481 | otherwise = unlines auxm3
482 vsep = map g (head auxm3)
483 g '.' = ':'
484 g _ = ' '
485
486
487dispShort :: Int -> Int -> Int -> Matrix Double -> IO ()
488dispShort maxr maxc dec m =
489 printf "%dx%d\n%s" (rows m) (cols m) (formatShort " " fmt maxr maxc m)
490 where
491 fmt = printf ("%."++show dec ++"f")
492
493--------------------------------------------------------------------------------
494
495-- | generic reference implementation of gaussian elimination
496--
497-- @a <> gauss a b = b@
498--
499gaussElim
500 :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t)
501 => Matrix t -> Matrix t -> Matrix t
502
503gaussElim x y = dropColumns (rows x) (flipud $ fromRows s2)
504 where
505 rs = toRows $ fromBlocks [[x , y]]
506 s1 = pivotDown (rows x) 0 rs
507 s2 = pivotUp (rows x-1) (reverse s1)
508
509pivotDown t n xs
510 | t == n = []
511 | otherwise = y : pivotDown t (n+1) ys
512 where
513 y:ys = redu (pivot n xs)
514
515 pivot k = (const k &&& id)
516 . reverse . sortBy (compare `on` (abs. (!k))) -- FIXME
517
518 redu (k,x:zs)
519 | p == 0 = error "gauss: singular!" -- FIXME
520 | otherwise = u : map f zs
521 where
522 p = x!k
523 u = scale (recip (x!k)) x
524 f z = z - scale (z!k) u
525 redu (_,[]) = []
526
527
528pivotUp n xs
529 | n == -1 = []
530 | otherwise = y : pivotUp (n-1) ys
531 where
532 y:ys = redu' (n,xs)
533
534 redu' (k,x:zs) = u : map f zs
535 where
536 u = x
537 f z = z - scale (z!k) u
538 redu' (_,[]) = []
539
540--------------------------------------------------------------------------------
541
542instance Testable (Matrix I) where
543 checkT _ = test
544
545test :: (Bool, IO())
546test = (and ok, return ())
547 where
548 m = (3><4) [1..12] :: Matrix I
549 r = (2><3) [1,2,3,4,3,2]
550 c = (3><2) [0,4,4,1,2,3]
551 p = (9><10) [0..89] :: Matrix I
552 ep = (2><3) [10,24,32,44,31,23]
553 md = fromInt m :: Matrix Double
554 ok = [ tr m <> m == toInt (tr md <> md)
555 , m <> tr m == toInt (md <> tr md)
556 , m ?? (Take 2, Take 3) == remap (asColumn (range 2)) (asRow (range 3)) m
557 , remap r (tr c) p == ep
558 , tr p ?? (PosCyc (idxs[-5,13]), Pos (idxs[3,7,1])) == (2><3) [35,75,15,33,73,13]
559 ]
560