diff options
-rw-r--r-- | CHANGES.md | 5 | ||||
-rw-r--r-- | hmatrix.cabal | 3 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Util/Convolution.hs | 118 | ||||
-rw-r--r-- | packages/tests/hmatrix-tests.cabal | 4 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 17 |
5 files changed, 143 insertions, 4 deletions
@@ -1,3 +1,8 @@ | |||
1 | 0.14.1.0 | ||
2 | -------- | ||
3 | |||
4 | - convolution | ||
5 | |||
1 | 0.14.0.0 | 6 | 0.14.0.0 |
2 | -------- | 7 | -------- |
3 | 8 | ||
diff --git a/hmatrix.cabal b/hmatrix.cabal index 4d3beb6..0ef5cd7 100644 --- a/hmatrix.cabal +++ b/hmatrix.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix | 1 | Name: hmatrix |
2 | Version: 0.14.0.1 | 2 | Version: 0.14.1.0 |
3 | License: GPL | 3 | License: GPL |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -112,6 +112,7 @@ library | |||
112 | Numeric.LinearAlgebra.LAPACK, | 112 | Numeric.LinearAlgebra.LAPACK, |
113 | Numeric.LinearAlgebra.Algorithms, | 113 | Numeric.LinearAlgebra.Algorithms, |
114 | Numeric.LinearAlgebra.Util, | 114 | Numeric.LinearAlgebra.Util, |
115 | Numeric.LinearAlgebra.Util.Convolution, | ||
115 | Graphics.Plot, | 116 | Graphics.Plot, |
116 | Data.Packed.ST, | 117 | Data.Packed.ST, |
117 | Data.Packed.Development | 118 | Data.Packed.Development |
diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs new file mode 100644 index 0000000..b64b169 --- /dev/null +++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs | |||
@@ -0,0 +1,118 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | {- | | ||
4 | Module : Numeric.LinearAlgebra.Util.Convolution | ||
5 | Copyright : (c) Alberto Ruiz 2012 | ||
6 | License : GPL | ||
7 | |||
8 | Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
9 | Stability : provisional | ||
10 | |||
11 | -} | ||
12 | ----------------------------------------------------------------------------- | ||
13 | |||
14 | module Numeric.LinearAlgebra.Util.Convolution( | ||
15 | -- * 1D | ||
16 | corr, conv, | ||
17 | -- * 2D | ||
18 | corr2, conv2, | ||
19 | -- * Misc | ||
20 | separable, corrMin | ||
21 | ) where | ||
22 | |||
23 | import Numeric.LinearAlgebra | ||
24 | |||
25 | |||
26 | vectSS :: Element t => Int -> Vector t -> Matrix t | ||
27 | vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ] | ||
28 | |||
29 | |||
30 | corr :: Product t => Vector t -- ^ kernel | ||
31 | -> Vector t -- ^ source | ||
32 | -> Vector t | ||
33 | {- ^ correlation | ||
34 | |||
35 | @\> (fromList[1,2,3]) (fromList [1..10]) | ||
36 | fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0]@ | ||
37 | |||
38 | -} | ||
39 | corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker | ||
40 | | otherwise = error $ "corr: dim kernel ("++show (dim ker)++") > dim vector ("++show (dim v)++")" | ||
41 | |||
42 | |||
43 | conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t | ||
44 | {- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product) | ||
45 | |||
46 | @\> conv (fromList[1,1]) (fromList [-1,1]) | ||
47 | fromList [-1.0,0.0,1.0]@ | ||
48 | |||
49 | -} | ||
50 | conv ker v = corr ker' v' | ||
51 | where | ||
52 | ker' = (flatten.fliprl.asRow) ker | ||
53 | v' | dim ker > 1 = join [z,v,z] | ||
54 | | otherwise = v | ||
55 | z = constant 0 (dim ker -1) | ||
56 | |||
57 | corrMin :: (Container Vector t, RealElement t, Product t) | ||
58 | => Vector t | ||
59 | -> Vector t | ||
60 | -> Vector t | ||
61 | -- ^ similar to 'corr', using 'min' instead of (*) | ||
62 | corrMin ker v = minEvery ss (asRow ker) <> ones | ||
63 | where | ||
64 | minEvery a b = cond a b a a b | ||
65 | ss = vectSS (dim ker) v | ||
66 | ones = konst' 1 (dim ker) | ||
67 | |||
68 | |||
69 | |||
70 | matSS :: Element t => Int -> Matrix t -> [Matrix t] | ||
71 | matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ] | ||
72 | where | ||
73 | v = flatten m | ||
74 | c = cols m | ||
75 | r = rows m | ||
76 | n = dr*c | ||
77 | |||
78 | |||
79 | corr2 :: Product a => Matrix a -> Matrix a -> Matrix a | ||
80 | -- ^ 2D correlation | ||
81 | corr2 ker mat = dims | ||
82 | . concatMap (map ((<.> ker') . flatten) . matSS c . trans) | ||
83 | . matSS r $ mat | ||
84 | where | ||
85 | r = rows ker | ||
86 | c = cols ker | ||
87 | ker' = flatten (trans ker) | ||
88 | rr = rows mat - r + 1 | ||
89 | rc = cols mat - c + 1 | ||
90 | dims | rr > 0 && rc > 0 = (rr >< rc) | ||
91 | | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")" | ||
92 | sz m = show (rows m)++"x"++show (cols m) | ||
93 | |||
94 | conv2 :: (Num a, Product a) => Matrix a -> Matrix a -> Matrix a | ||
95 | -- ^ 2D convolution | ||
96 | conv2 k m = corr2 (fliprl . flipud $ k) pm | ||
97 | where | ||
98 | pm | r == 0 && c == 0 = m | ||
99 | | r == 0 = fromBlocks [[z3,m,z3]] | ||
100 | | c == 0 = fromBlocks [[z2],[m],[z2]] | ||
101 | | otherwise = fromBlocks [[z1,z2,z1] | ||
102 | ,[z3, m,z3] | ||
103 | ,[z1,z2,z1]] | ||
104 | r = rows k - 1 | ||
105 | c = cols k - 1 | ||
106 | h = rows m | ||
107 | w = cols m | ||
108 | z1 = konst' 0 (r,c) | ||
109 | z2 = konst' 0 (r,w) | ||
110 | z3 = konst' 0 (h,c) | ||
111 | |||
112 | -- TODO: could be simplified using future empty arrays | ||
113 | |||
114 | |||
115 | separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t | ||
116 | -- ^ 2D process implemented as separated 1D processes by rows and columns. | ||
117 | separable f = fromColumns . map f . toColumns . fromRows . map f . toRows | ||
118 | |||
diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal index 1c6443e..10b7c80 100644 --- a/packages/tests/hmatrix-tests.cabal +++ b/packages/tests/hmatrix-tests.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix-tests | 1 | Name: hmatrix-tests |
2 | Version: 0.2 | 2 | Version: 0.3 |
3 | License: GPL | 3 | License: GPL |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
@@ -21,7 +21,7 @@ extra-source-files: CHANGES | |||
21 | library | 21 | library |
22 | 22 | ||
23 | Build-Depends: base >= 4 && < 5, | 23 | Build-Depends: base >= 4 && < 5, |
24 | hmatrix >= 0.14, | 24 | hmatrix >= 0.14.1, |
25 | QuickCheck >= 2, HUnit, random | 25 | QuickCheck >= 2, HUnit, random |
26 | 26 | ||
27 | hs-source-dirs: src | 27 | hs-source-dirs: src |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index fd66767..4e9a521 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -42,6 +42,7 @@ import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) | |||
42 | import Control.Arrow((***)) | 42 | import Control.Arrow((***)) |
43 | import Debug.Trace | 43 | import Debug.Trace |
44 | import Control.Monad(when) | 44 | import Control.Monad(when) |
45 | import Numeric.LinearAlgebra.Util.Convolution | ||
45 | 46 | ||
46 | import Data.Packed.ST | 47 | import Data.Packed.ST |
47 | 48 | ||
@@ -421,7 +422,20 @@ accumTest = utest "accum" ok | |||
421 | && | 422 | && |
422 | toList (flatten x) == [1,0,0,0,1,0,0,0,1] | 423 | toList (flatten x) == [1,0,0,0,1,0,0,0,1] |
423 | 424 | ||
424 | --------------------------------------------------------------------- | 425 | -------------------------------------------------------------------------------- |
426 | |||
427 | convolutionTest = utest "convolution" ok | ||
428 | where | ||
429 | -- a = fromList [1..10] :: Vector Double | ||
430 | b = fromList [1..3] :: Vector Double | ||
431 | c = (5><7) [1..] :: Matrix Double | ||
432 | -- d = (3><3) [0,-1,0,-1,4,-1,0,-1,0] :: Matrix Double | ||
433 | ok = separable (corr b) c == corr2 (outer b b) c | ||
434 | && separable (conv b) c == conv2 (outer b b) c | ||
435 | |||
436 | -------------------------------------------------------------------------------- | ||
437 | |||
438 | |||
425 | 439 | ||
426 | -- | All tests must pass with a maximum dimension of about 20 | 440 | -- | All tests must pass with a maximum dimension of about 20 |
427 | -- (some tests may fail with bigger sizes due to precision loss). | 441 | -- (some tests may fail with bigger sizes due to precision loss). |
@@ -596,6 +610,7 @@ runTests n = do | |||
596 | , condTest | 610 | , condTest |
597 | , conformTest | 611 | , conformTest |
598 | , accumTest | 612 | , accumTest |
613 | , convolutionTest | ||
599 | ] | 614 | ] |
600 | when (errors c + failures c > 0) exitFailure | 615 | when (errors c + failures c > 0) exitFailure |
601 | return () | 616 | return () |