summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES.md5
-rw-r--r--hmatrix.cabal3
-rw-r--r--lib/Numeric/LinearAlgebra/Util/Convolution.hs118
-rw-r--r--packages/tests/hmatrix-tests.cabal4
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs17
5 files changed, 143 insertions, 4 deletions
diff --git a/CHANGES.md b/CHANGES.md
index 2788c52..0ce7699 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,8 @@
10.14.1.0
2--------
3
4- convolution
5
10.14.0.0 60.14.0.0
2-------- 7--------
3 8
diff --git a/hmatrix.cabal b/hmatrix.cabal
index 4d3beb6..0ef5cd7 100644
--- a/hmatrix.cabal
+++ b/hmatrix.cabal
@@ -1,5 +1,5 @@
1Name: hmatrix 1Name: hmatrix
2Version: 0.14.0.1 2Version: 0.14.1.0
3License: GPL 3License: GPL
4License-file: LICENSE 4License-file: LICENSE
5Author: Alberto Ruiz 5Author: Alberto Ruiz
@@ -112,6 +112,7 @@ library
112 Numeric.LinearAlgebra.LAPACK, 112 Numeric.LinearAlgebra.LAPACK,
113 Numeric.LinearAlgebra.Algorithms, 113 Numeric.LinearAlgebra.Algorithms,
114 Numeric.LinearAlgebra.Util, 114 Numeric.LinearAlgebra.Util,
115 Numeric.LinearAlgebra.Util.Convolution,
115 Graphics.Plot, 116 Graphics.Plot,
116 Data.Packed.ST, 117 Data.Packed.ST,
117 Data.Packed.Development 118 Data.Packed.Development
diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
new file mode 100644
index 0000000..b64b169
--- /dev/null
+++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
@@ -0,0 +1,118 @@
1{-# LANGUAGE FlexibleContexts #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.LinearAlgebra.Util.Convolution
5Copyright : (c) Alberto Ruiz 2012
6License : GPL
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10
11-}
12-----------------------------------------------------------------------------
13
14module Numeric.LinearAlgebra.Util.Convolution(
15 -- * 1D
16 corr, conv,
17 -- * 2D
18 corr2, conv2,
19 -- * Misc
20 separable, corrMin
21) where
22
23import Numeric.LinearAlgebra
24
25
26vectSS :: Element t => Int -> Vector t -> Matrix t
27vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ]
28
29
30corr :: Product t => Vector t -- ^ kernel
31 -> Vector t -- ^ source
32 -> Vector t
33{- ^ correlation
34
35@\> (fromList[1,2,3]) (fromList [1..10])
36fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0]@
37
38-}
39corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker
40 | otherwise = error $ "corr: dim kernel ("++show (dim ker)++") > dim vector ("++show (dim v)++")"
41
42
43conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t
44{- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product)
45
46@\> conv (fromList[1,1]) (fromList [-1,1])
47fromList [-1.0,0.0,1.0]@
48
49-}
50conv ker v = corr ker' v'
51 where
52 ker' = (flatten.fliprl.asRow) ker
53 v' | dim ker > 1 = join [z,v,z]
54 | otherwise = v
55 z = constant 0 (dim ker -1)
56
57corrMin :: (Container Vector t, RealElement t, Product t)
58 => Vector t
59 -> Vector t
60 -> Vector t
61-- ^ similar to 'corr', using 'min' instead of (*)
62corrMin ker v = minEvery ss (asRow ker) <> ones
63 where
64 minEvery a b = cond a b a a b
65 ss = vectSS (dim ker) v
66 ones = konst' 1 (dim ker)
67
68
69
70matSS :: Element t => Int -> Matrix t -> [Matrix t]
71matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ]
72 where
73 v = flatten m
74 c = cols m
75 r = rows m
76 n = dr*c
77
78
79corr2 :: Product a => Matrix a -> Matrix a -> Matrix a
80-- ^ 2D correlation
81corr2 ker mat = dims
82 . concatMap (map ((<.> ker') . flatten) . matSS c . trans)
83 . matSS r $ mat
84 where
85 r = rows ker
86 c = cols ker
87 ker' = flatten (trans ker)
88 rr = rows mat - r + 1
89 rc = cols mat - c + 1
90 dims | rr > 0 && rc > 0 = (rr >< rc)
91 | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")"
92 sz m = show (rows m)++"x"++show (cols m)
93
94conv2 :: (Num a, Product a) => Matrix a -> Matrix a -> Matrix a
95-- ^ 2D convolution
96conv2 k m = corr2 (fliprl . flipud $ k) pm
97 where
98 pm | r == 0 && c == 0 = m
99 | r == 0 = fromBlocks [[z3,m,z3]]
100 | c == 0 = fromBlocks [[z2],[m],[z2]]
101 | otherwise = fromBlocks [[z1,z2,z1]
102 ,[z3, m,z3]
103 ,[z1,z2,z1]]
104 r = rows k - 1
105 c = cols k - 1
106 h = rows m
107 w = cols m
108 z1 = konst' 0 (r,c)
109 z2 = konst' 0 (r,w)
110 z3 = konst' 0 (h,c)
111
112-- TODO: could be simplified using future empty arrays
113
114
115separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t
116-- ^ 2D process implemented as separated 1D processes by rows and columns.
117separable f = fromColumns . map f . toColumns . fromRows . map f . toRows
118
diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal
index 1c6443e..10b7c80 100644
--- a/packages/tests/hmatrix-tests.cabal
+++ b/packages/tests/hmatrix-tests.cabal
@@ -1,5 +1,5 @@
1Name: hmatrix-tests 1Name: hmatrix-tests
2Version: 0.2 2Version: 0.3
3License: GPL 3License: GPL
4License-file: LICENSE 4License-file: LICENSE
5Author: Alberto Ruiz 5Author: Alberto Ruiz
@@ -21,7 +21,7 @@ extra-source-files: CHANGES
21library 21library
22 22
23 Build-Depends: base >= 4 && < 5, 23 Build-Depends: base >= 4 && < 5,
24 hmatrix >= 0.14, 24 hmatrix >= 0.14.1,
25 QuickCheck >= 2, HUnit, random 25 QuickCheck >= 2, HUnit, random
26 26
27 hs-source-dirs: src 27 hs-source-dirs: src
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index fd66767..4e9a521 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -42,6 +42,7 @@ import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr)
42import Control.Arrow((***)) 42import Control.Arrow((***))
43import Debug.Trace 43import Debug.Trace
44import Control.Monad(when) 44import Control.Monad(when)
45import Numeric.LinearAlgebra.Util.Convolution
45 46
46import Data.Packed.ST 47import Data.Packed.ST
47 48
@@ -421,7 +422,20 @@ accumTest = utest "accum" ok
421 && 422 &&
422 toList (flatten x) == [1,0,0,0,1,0,0,0,1] 423 toList (flatten x) == [1,0,0,0,1,0,0,0,1]
423 424
424--------------------------------------------------------------------- 425--------------------------------------------------------------------------------
426
427convolutionTest = utest "convolution" ok
428 where
429-- a = fromList [1..10] :: Vector Double
430 b = fromList [1..3] :: Vector Double
431 c = (5><7) [1..] :: Matrix Double
432-- d = (3><3) [0,-1,0,-1,4,-1,0,-1,0] :: Matrix Double
433 ok = separable (corr b) c == corr2 (outer b b) c
434 && separable (conv b) c == conv2 (outer b b) c
435
436--------------------------------------------------------------------------------
437
438
425 439
426-- | All tests must pass with a maximum dimension of about 20 440-- | All tests must pass with a maximum dimension of about 20
427-- (some tests may fail with bigger sizes due to precision loss). 441-- (some tests may fail with bigger sizes due to precision loss).
@@ -596,6 +610,7 @@ runTests n = do
596 , condTest 610 , condTest
597 , conformTest 611 , conformTest
598 , accumTest 612 , accumTest
613 , convolutionTest
599 ] 614 ]
600 when (errors c + failures c > 0) exitFailure 615 when (errors c + failures c > 0) exitFailure
601 return () 616 return ()