summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Util.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-05 16:35:18 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-05 16:35:18 +0200
commit2876998f04380c9e835c6177b440447368dfe623 (patch)
tree233144431cbcfcabdbc1433cda81f534f4262bbe /packages/base/src/Internal/Util.hs
parent01e2dac904b37e5831617c28814cf5a65bb6f1e6 (diff)
move util
Diffstat (limited to 'packages/base/src/Internal/Util.hs')
-rw-r--r--packages/base/src/Internal/Util.hs568
1 files changed, 568 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
new file mode 100644
index 0000000..9900770
--- /dev/null
+++ b/packages/base/src/Internal/Util.hs
@@ -0,0 +1,568 @@
1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE FlexibleInstances #-}
3{-# LANGUAGE TypeFamilies #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE ViewPatterns #-}
7
8
9-----------------------------------------------------------------------------
10{- |
11Module : Internal.Util
12Copyright : (c) Alberto Ruiz 2013
13License : BSD3
14Maintainer : Alberto Ruiz
15Stability : provisional
16
17-}
18-----------------------------------------------------------------------------
19
20module Internal.Util(
21
22 -- * Convenience functions
23 vector, matrix,
24 disp,
25 formatSparse,
26 approxInt,
27 dispDots,
28 dispBlanks,
29 formatShort,
30 dispShort,
31 zeros, ones,
32 diagl,
33 row,
34 col,
35 (&), (¦), (|||), (——), (===), (#),
36 (?), (¿),
37 Indexable(..), size,
38 Numeric,
39 rand, randn,
40 cross,
41 norm,
42 ℕ,ℤ,ℝ,ℂ,iC,
43 Normed(..), norm_Frob, norm_nuclear,
44 unitary,
45 mt,
46 (~!~),
47 pairwiseD2,
48 rowOuters,
49 null1,
50 null1sym,
51 -- * Convolution
52 -- ** 1D
53 corr, conv, corrMin,
54 -- ** 2D
55 corr2, conv2, separable,
56 gaussElim
57) where
58
59import Internal.Tools
60import Internal.Vector
61import Internal.Matrix hiding (size)
62import Internal.Numeric
63import Internal.Element
64import Internal.Container
65import Internal.Vectorized
66import Internal.IO
67import Internal.Algorithms hiding (i,Normed)
68import Numeric.Matrix()
69import Numeric.Vector()
70import Internal.Random
71import Internal.Convolution
72import Control.Monad(when)
73import Text.Printf
74import Data.List.Split(splitOn)
75import Data.List(intercalate,sortBy)
76import Data.Function(on)
77import Control.Arrow((&&&))
78import Data.Complex
79import Data.Vector.Storable(fromList)
80
81type ℝ = Double
82type ℕ = Int
83type ℤ = Int
84type ℂ = Complex Double
85
86-- | imaginary unit
87iC :: ℂ
88iC = 0:+1
89
90{- | Create a real vector.
91
92>>> vector [1..5]
93fromList [1.0,2.0,3.0,4.0,5.0]
94
95-}
96vector :: [ℝ] -> Vector ℝ
97vector = fromList
98
99{- | Create a real matrix.
100
101>>> matrix 5 [1..15]
102(3><5)
103 [ 1.0, 2.0, 3.0, 4.0, 5.0
104 , 6.0, 7.0, 8.0, 9.0, 10.0
105 , 11.0, 12.0, 13.0, 14.0, 15.0 ]
106
107-}
108matrix
109 :: Int -- ^ number of columns
110 -> [ℝ] -- ^ elements in row order
111 -> Matrix ℝ
112matrix c = reshape c . fromList
113
114
115{- | print a real matrix with given number of digits after the decimal point
116
117>>> disp 5 $ ident 2 / 3
1182x2
1190.33333 0.00000
1200.00000 0.33333
121
122-}
123disp :: Int -> Matrix Double -> IO ()
124
125disp n = putStr . dispf n
126
127
128{- | create a real diagonal matrix from a list
129
130>>> diagl [1,2,3]
131(3><3)
132 [ 1.0, 0.0, 0.0
133 , 0.0, 2.0, 0.0
134 , 0.0, 0.0, 3.0 ]
135
136-}
137diagl :: [Double] -> Matrix Double
138diagl = diag . fromList
139
140-- | a real matrix of zeros
141zeros :: Int -- ^ rows
142 -> Int -- ^ columns
143 -> Matrix Double
144zeros r c = konst 0 (r,c)
145
146-- | a real matrix of ones
147ones :: Int -- ^ rows
148 -> Int -- ^ columns
149 -> Matrix Double
150ones r c = konst 1 (r,c)
151
152-- | concatenation of real vectors
153infixl 3 &
154(&) :: Vector Double -> Vector Double -> Vector Double
155a & b = vjoin [a,b]
156
157{- | horizontal concatenation of real matrices
158
159>>> ident 3 ||| konst 7 (3,4)
160(3><7)
161 [ 1.0, 0.0, 0.0, 7.0, 7.0, 7.0, 7.0
162 , 0.0, 1.0, 0.0, 7.0, 7.0, 7.0, 7.0
163 , 0.0, 0.0, 1.0, 7.0, 7.0, 7.0, 7.0 ]
164
165-}
166infixl 3 |||
167(|||) :: Matrix Double -> Matrix Double -> Matrix Double
168a ||| b = fromBlocks [[a,b]]
169
170-- | a synonym for ('|||') (unicode 0x00a6, broken bar)
171infixl 3 ¦
172(¦) :: Matrix Double -> Matrix Double -> Matrix Double
173(¦) = (|||)
174
175
176-- | vertical concatenation of real matrices
177--
178(===) :: Matrix Double -> Matrix Double -> Matrix Double
179infixl 2 ===
180a === b = fromBlocks [[a],[b]]
181
182-- | a synonym for ('===') (unicode 0x2014, em dash)
183(——) :: Matrix Double -> Matrix Double -> Matrix Double
184infixl 2 ——
185(——) = (===)
186
187
188(#) :: Matrix Double -> Matrix Double -> Matrix Double
189infixl 2 #
190a # b = fromBlocks [[a],[b]]
191
192-- | create a single row real matrix from a list
193--
194-- >>> row [2,3,1,8]
195-- (1><4)
196-- [ 2.0, 3.0, 1.0, 8.0 ]
197--
198row :: [Double] -> Matrix Double
199row = asRow . fromList
200
201-- | create a single column real matrix from a list
202--
203-- >>> col [7,-2,4]
204-- (3><1)
205-- [ 7.0
206-- , -2.0
207-- , 4.0 ]
208--
209col :: [Double] -> Matrix Double
210col = asColumn . fromList
211
212{- | extract rows
213
214>>> (20><4) [1..] ? [2,1,1]
215(3><4)
216 [ 9.0, 10.0, 11.0, 12.0
217 , 5.0, 6.0, 7.0, 8.0
218 , 5.0, 6.0, 7.0, 8.0 ]
219
220-}
221infixl 9 ?
222(?) :: Element t => Matrix t -> [Int] -> Matrix t
223(?) = flip extractRows
224
225{- | extract columns
226
227(unicode 0x00bf, inverted question mark, Alt-Gr ?)
228
229>>> (3><4) [1..] ¿ [3,0]
230(3><2)
231 [ 4.0, 1.0
232 , 8.0, 5.0
233 , 12.0, 9.0 ]
234
235-}
236infixl 9 ¿
237(¿) :: Element t => Matrix t -> [Int] -> Matrix t
238(¿)= flip extractColumns
239
240
241cross :: Product t => Vector t -> Vector t -> Vector t
242-- ^ cross product (for three-element vectors)
243cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3]
244 | otherwise = error $ "the cross product requires 3-element vectors (sizes given: "
245 ++show (dim x)++" and "++show (dim y)++")"
246 where
247 [x1,x2,x3] = toList x
248 [y1,y2,y3] = toList y
249 z1 = x2*y3-x3*y2
250 z2 = x3*y1-x1*y3
251 z3 = x1*y2-x2*y1
252
253{-# SPECIALIZE cross :: Vector Double -> Vector Double -> Vector Double #-}
254{-# SPECIALIZE cross :: Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double) #-}
255
256norm :: Vector Double -> Double
257-- ^ 2-norm of real vector
258norm = pnorm PNorm2
259
260class Normed a
261 where
262 norm_0 :: a -> ℝ
263 norm_1 :: a -> ℝ
264 norm_2 :: a -> ℝ
265 norm_Inf :: a -> ℝ
266
267
268instance Normed (Vector ℝ)
269 where
270 norm_0 v = sumElements (step (abs v - scalar (eps*normInf v)))
271 norm_1 = pnorm PNorm1
272 norm_2 = pnorm PNorm2
273 norm_Inf = pnorm Infinity
274
275instance Normed (Vector ℂ)
276 where
277 norm_0 v = sumElements (step (fst (fromComplex (abs v)) - scalar (eps*normInf v)))
278 norm_1 = pnorm PNorm1
279 norm_2 = pnorm PNorm2
280 norm_Inf = pnorm Infinity
281
282instance Normed (Matrix ℝ)
283 where
284 norm_0 = norm_0 . flatten
285 norm_1 = pnorm PNorm1
286 norm_2 = pnorm PNorm2
287 norm_Inf = pnorm Infinity
288
289instance Normed (Matrix ℂ)
290 where
291 norm_0 = norm_0 . flatten
292 norm_1 = pnorm PNorm1
293 norm_2 = pnorm PNorm2
294 norm_Inf = pnorm Infinity
295
296instance Normed (Vector I)
297 where
298 norm_0 = fromIntegral . sumElements . step . abs
299 norm_1 = fromIntegral . norm1
300 norm_2 v = sqrt . fromIntegral $ dot v v
301 norm_Inf = fromIntegral . normInf
302
303
304
305norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ
306norm_Frob = norm_2 . flatten
307
308norm_nuclear :: Field t => Matrix t -> ℝ
309norm_nuclear = sumElements . singularValues
310
311
312-- | Obtains a vector in the same direction with 2-norm=1
313unitary :: Vector Double -> Vector Double
314unitary v = v / scalar (norm v)
315
316
317-- | trans . inv
318mt :: Matrix Double -> Matrix Double
319mt = trans . inv
320
321--------------------------------------------------------------------------------
322{- |
323
324>>> size $ vector [1..10]
32510
326>>> size $ (2><5)[1..10::Double]
327(2,5)
328
329-}
330size :: Container c t => c t -> IndexOf c
331size = size'
332
333{- | Alternative indexing function.
334
335>>> vector [1..10] ! 3
3364.0
337
338On a matrix it gets the k-th row as a vector:
339
340>>> matrix 5 [1..15] ! 1
341fromList [6.0,7.0,8.0,9.0,10.0]
342
343>>> matrix 5 [1..15] ! 1 ! 3
3449.0
345
346-}
347class Indexable c t | c -> t , t -> c
348 where
349 infixl 9 !
350 (!) :: c -> Int -> t
351
352instance Indexable (Vector Double) Double
353 where
354 (!) = (@>)
355
356instance Indexable (Vector Float) Float
357 where
358 (!) = (@>)
359
360instance Indexable (Vector I) I
361 where
362 (!) = (@>)
363
364instance Indexable (Vector (Complex Double)) (Complex Double)
365 where
366 (!) = (@>)
367
368instance Indexable (Vector (Complex Float)) (Complex Float)
369 where
370 (!) = (@>)
371
372instance Element t => Indexable (Matrix t) (Vector t)
373 where
374 m!j = subVector (j*c) c (flatten m)
375 where
376 c = cols m
377
378--------------------------------------------------------------------------------
379
380-- | Matrix of pairwise squared distances of row vectors
381-- (using the matrix product trick in blog.smola.org)
382pairwiseD2 :: Matrix Double -> Matrix Double -> Matrix Double
383pairwiseD2 x y | ok = x2 `outer` oy + ox `outer` y2 - 2* x <> trans y
384 | otherwise = error $ "pairwiseD2 with different number of columns: "
385 ++ show (size x) ++ ", " ++ show (size y)
386 where
387 ox = one (rows x)
388 oy = one (rows y)
389 oc = one (cols x)
390 one k = konst 1 k
391 x2 = x * x <> oc
392 y2 = y * y <> oc
393 ok = cols x == cols y
394
395--------------------------------------------------------------------------------
396
397{- | outer products of rows
398
399>>> a
400(3><2)
401 [ 1.0, 2.0
402 , 10.0, 20.0
403 , 100.0, 200.0 ]
404>>> b
405(3><3)
406 [ 1.0, 2.0, 3.0
407 , 4.0, 5.0, 6.0
408 , 7.0, 8.0, 9.0 ]
409
410>>> rowOuters a (b ||| 1)
411(3><8)
412 [ 1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 6.0, 2.0
413 , 40.0, 50.0, 60.0, 10.0, 80.0, 100.0, 120.0, 20.0
414 , 700.0, 800.0, 900.0, 100.0, 1400.0, 1600.0, 1800.0, 200.0 ]
415
416-}
417rowOuters :: Matrix Double -> Matrix Double -> Matrix Double
418rowOuters a b = a' * b'
419 where
420 a' = kronecker a (ones 1 (cols b))
421 b' = kronecker (ones 1 (cols a)) b
422
423--------------------------------------------------------------------------------
424
425-- | solution of overconstrained homogeneous linear system
426null1 :: Matrix Double -> Vector Double
427null1 = last . toColumns . snd . rightSV
428
429-- | solution of overconstrained homogeneous symmetric linear system
430null1sym :: Matrix Double -> Vector Double
431null1sym = last . toColumns . snd . eigSH'
432
433--------------------------------------------------------------------------------
434
435infixl 0 ~!~
436c ~!~ msg = when c (error msg)
437
438--------------------------------------------------------------------------------
439
440formatSparse :: String -> String -> String -> Int -> Matrix Double -> String
441
442formatSparse zeroI _zeroF sep _ (approxInt -> Just m) = format sep f m
443 where
444 f 0 = zeroI
445 f x = printf "%.0f" x
446
447formatSparse zeroI zeroF sep n m = format sep f m
448 where
449 f x | abs (x::Double) < 2*peps = zeroI++zeroF
450 | abs (fromIntegral (round x::Int) - x) / abs x < 2*peps
451 = printf ("%.0f."++replicate n ' ') x
452 | otherwise = printf ("%."++show n++"f") x
453
454approxInt m
455 | norm_Inf (v - vi) < 2*peps * norm_Inf v = Just (reshape (cols m) vi)
456 | otherwise = Nothing
457 where
458 v = flatten m
459 vi = roundVector v
460
461dispDots n = putStr . formatSparse "." (replicate n ' ') " " n
462
463dispBlanks n = putStr . formatSparse "" "" " " n
464
465formatShort sep fmt maxr maxc m = auxm4
466 where
467 (rm,cm) = size m
468 (r1,r2,r3)
469 | rm <= maxr = (rm,0,0)
470 | otherwise = (maxr-3,rm-maxr+1,2)
471 (c1,c2,c3)
472 | cm <= maxc = (cm,0,0)
473 | otherwise = (maxc-3,cm-maxc+1,2)
474 [ [a,_,b]
475 ,[_,_,_]
476 ,[c,_,d]] = toBlocks [r1,r2,r3]
477 [c1,c2,c3] m
478 auxm = fromBlocks [[a,b],[c,d]]
479 auxm2
480 | cm > maxc = format "|" fmt auxm
481 | otherwise = format sep fmt auxm
482 auxm3
483 | cm > maxc = map (f . splitOn "|") (lines auxm2)
484 | otherwise = (lines auxm2)
485 f items = intercalate sep (take (maxc-3) items) ++ " .. " ++
486 intercalate sep (drop (maxc-3) items)
487 auxm4
488 | rm > maxr = unlines (take (maxr-3) auxm3 ++ vsep : drop (maxr-3) auxm3)
489 | otherwise = unlines auxm3
490 vsep = map g (head auxm3)
491 g '.' = ':'
492 g _ = ' '
493
494
495dispShort :: Int -> Int -> Int -> Matrix Double -> IO ()
496dispShort maxr maxc dec m =
497 printf "%dx%d\n%s" (rows m) (cols m) (formatShort " " fmt maxr maxc m)
498 where
499 fmt = printf ("%."++show dec ++"f")
500
501--------------------------------------------------------------------------------
502
503-- | generic reference implementation of gaussian elimination
504--
505-- @a <> gauss a b = b@
506--
507gaussElim
508 :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t)
509 => Matrix t -> Matrix t -> Matrix t
510
511gaussElim x y = dropColumns (rows x) (flipud $ fromRows s2)
512 where
513 rs = toRows $ fromBlocks [[x , y]]
514 s1 = pivotDown (rows x) 0 rs
515 s2 = pivotUp (rows x-1) (reverse s1)
516
517pivotDown t n xs
518 | t == n = []
519 | otherwise = y : pivotDown t (n+1) ys
520 where
521 y:ys = redu (pivot n xs)
522
523 pivot k = (const k &&& id)
524 . reverse . sortBy (compare `on` (abs. (!k))) -- FIXME
525
526 redu (k,x:zs)
527 | p == 0 = error "gauss: singular!" -- FIXME
528 | otherwise = u : map f zs
529 where
530 p = x!k
531 u = scale (recip (x!k)) x
532 f z = z - scale (z!k) u
533 redu (_,[]) = []
534
535
536pivotUp n xs
537 | n == -1 = []
538 | otherwise = y : pivotUp (n-1) ys
539 where
540 y:ys = redu' (n,xs)
541
542 redu' (k,x:zs) = u : map f zs
543 where
544 u = x
545 f z = z - scale (z!k) u
546 redu' (_,[]) = []
547
548--------------------------------------------------------------------------------
549
550instance Testable (Matrix I) where
551 checkT _ = test
552
553test :: (Bool, IO())
554test = (and ok, return ())
555 where
556 m = (3><4) [1..12] :: Matrix I
557 r = (2><3) [1,2,3,4,3,2]
558 c = (3><2) [0,4,4,1,2,3]
559 p = (9><10) [0..89] :: Matrix I
560 ep = (2><3) [10,24,32,44,31,23]
561 md = fromInt m :: Matrix Double
562 ok = [ tr m <> m == toInt (tr md <> md)
563 , m <> tr m == toInt (md <> tr md)
564 , m ?? (Take 2, Take 3) == remap (asColumn (range 2)) (asRow (range 3)) m
565 , remap r (tr c) p == ep
566 , tr p ?? (PosCyc (idxs[-5,13]), Pos (idxs[3,7,1])) == (2><3) [35,75,15,33,73,13]
567 ]
568