diff options
Diffstat (limited to 'packages/base')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/Convolution.hs | 83 |
1 files changed, 58 insertions, 25 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Convolution.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Convolution.hs index 3cad8d7..1d4e089 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/Convolution.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/Convolution.hs | |||
@@ -32,8 +32,10 @@ corr :: Product t => Vector t -- ^ kernel | |||
32 | fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0] | 32 | fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0] |
33 | 33 | ||
34 | -} | 34 | -} |
35 | corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker | 35 | corr ker v |
36 | | otherwise = error $ "corr: dim kernel ("++show (dim ker)++") > dim vector ("++show (dim v)++")" | 36 | | dim ker == 0 = constant 0 (dim v) |
37 | | dim ker <= dim v = vectSS (dim ker) v <> ker | ||
38 | | otherwise = error $ "corr: dim kernel ("++show (dim ker)++") > dim vector ("++show (dim v)++")" | ||
37 | 39 | ||
38 | 40 | ||
39 | conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t | 41 | conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t |
@@ -43,11 +45,12 @@ conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t | |||
43 | fromList [-1.0,0.0,1.0] | 45 | fromList [-1.0,0.0,1.0] |
44 | 46 | ||
45 | -} | 47 | -} |
46 | conv ker v = corr ker' v' | 48 | conv ker v |
49 | | dim ker == 0 = constant 0 (dim v) | ||
50 | | otherwise = corr ker' v' | ||
47 | where | 51 | where |
48 | ker' = (flatten.fliprl.asRow) ker | 52 | ker' = (flatten.fliprl.asRow) ker |
49 | v' | dim ker > 1 = vjoin [z,v,z] | 53 | v' = vjoin [z,v,z] |
50 | | otherwise = v | ||
51 | z = constant 0 (dim ker -1) | 54 | z = constant 0 (dim ker -1) |
52 | 55 | ||
53 | corrMin :: (Container Vector t, RealElement t, Product t) | 56 | corrMin :: (Container Vector t, RealElement t, Product t) |
@@ -55,7 +58,9 @@ corrMin :: (Container Vector t, RealElement t, Product t) | |||
55 | -> Vector t | 58 | -> Vector t |
56 | -> Vector t | 59 | -> Vector t |
57 | -- ^ similar to 'corr', using 'min' instead of (*) | 60 | -- ^ similar to 'corr', using 'min' instead of (*) |
58 | corrMin ker v = minEvery ss (asRow ker) <> ones | 61 | corrMin ker v |
62 | | dim ker == 0 = error "corrMin: empty kernel" | ||
63 | | otherwise = minEvery ss (asRow ker) <> ones | ||
59 | where | 64 | where |
60 | minEvery a b = cond a b a a b | 65 | minEvery a b = cond a b a a b |
61 | ss = vectSS (dim ker) v | 66 | ss = vectSS (dim ker) v |
@@ -72,8 +77,21 @@ matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ] | |||
72 | n = dr*c | 77 | n = dr*c |
73 | 78 | ||
74 | 79 | ||
80 | {- | 2D correlation (without padding) | ||
81 | |||
82 | >>> disp 5 $ corr2 (konst 1 (3,3)) (ident 10 :: Matrix Double) | ||
83 | 8x8 | ||
84 | 3 2 1 0 0 0 0 0 | ||
85 | 2 3 2 1 0 0 0 0 | ||
86 | 1 2 3 2 1 0 0 0 | ||
87 | 0 1 2 3 2 1 0 0 | ||
88 | 0 0 1 2 3 2 1 0 | ||
89 | 0 0 0 1 2 3 2 1 | ||
90 | 0 0 0 0 1 2 3 2 | ||
91 | 0 0 0 0 0 1 2 3 | ||
92 | |||
93 | -} | ||
75 | corr2 :: Product a => Matrix a -> Matrix a -> Matrix a | 94 | corr2 :: Product a => Matrix a -> Matrix a -> Matrix a |
76 | -- ^ 2D correlation | ||
77 | corr2 ker mat = dims | 95 | corr2 ker mat = dims |
78 | . concatMap (map (udot ker' . flatten) . matSS c . trans) | 96 | . concatMap (map (udot ker' . flatten) . matSS c . trans) |
79 | . matSS r $ mat | 97 | . matSS r $ mat |
@@ -86,26 +104,41 @@ corr2 ker mat = dims | |||
86 | dims | rr > 0 && rc > 0 = (rr >< rc) | 104 | dims | rr > 0 && rc > 0 = (rr >< rc) |
87 | | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")" | 105 | | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")" |
88 | sz m = show (rows m)++"x"++show (cols m) | 106 | sz m = show (rows m)++"x"++show (cols m) |
107 | -- TODO check empty kernel | ||
108 | |||
109 | {- | 2D convolution | ||
110 | |||
111 | >>> disp 5 $ conv2 (konst 1 (3,3)) (ident 10 :: Matrix Double) | ||
112 | 12x12 | ||
113 | 1 1 1 0 0 0 0 0 0 0 0 0 | ||
114 | 1 2 2 1 0 0 0 0 0 0 0 0 | ||
115 | 1 2 3 2 1 0 0 0 0 0 0 0 | ||
116 | 0 1 2 3 2 1 0 0 0 0 0 0 | ||
117 | 0 0 1 2 3 2 1 0 0 0 0 0 | ||
118 | 0 0 0 1 2 3 2 1 0 0 0 0 | ||
119 | 0 0 0 0 1 2 3 2 1 0 0 0 | ||
120 | 0 0 0 0 0 1 2 3 2 1 0 0 | ||
121 | 0 0 0 0 0 0 1 2 3 2 1 0 | ||
122 | 0 0 0 0 0 0 0 1 2 3 2 1 | ||
123 | 0 0 0 0 0 0 0 0 1 2 2 1 | ||
124 | 0 0 0 0 0 0 0 0 0 1 1 1 | ||
89 | 125 | ||
90 | conv2 :: (Num a, Product a, Container Vector a) => Matrix a -> Matrix a -> Matrix a | 126 | -} |
91 | -- ^ 2D convolution | 127 | conv2 |
92 | conv2 k m = corr2 (fliprl . flipud $ k) pm | 128 | :: (Num (Matrix a), Product a, Container Vector a) |
129 | => Matrix a -- ^ kernel | ||
130 | -> Matrix a -> Matrix a | ||
131 | conv2 k m | ||
132 | | empty = konst 0 (rows m + r -1, cols m + c -1) | ||
133 | | otherwise = corr2 (fliprl . flipud $ k) padded | ||
93 | where | 134 | where |
94 | pm | r == 0 && c == 0 = m | 135 | padded = fromBlocks [[z,0,0] |
95 | | r == 0 = fromBlocks [[z3,m,z3]] | 136 | ,[0,m,0] |
96 | | c == 0 = fromBlocks [[z2],[m],[z2]] | 137 | ,[0,0,z]] |
97 | | otherwise = fromBlocks [[z1,z2,z1] | 138 | r = rows k |
98 | ,[z3, m,z3] | 139 | c = cols k |
99 | ,[z1,z2,z1]] | 140 | z = konst 0 (r-1,c-1) |
100 | r = rows k - 1 | 141 | empty = r == 0 || c == 0 |
101 | c = cols k - 1 | ||
102 | h = rows m | ||
103 | w = cols m | ||
104 | z1 = konst 0 (r,c) | ||
105 | z2 = konst 0 (r,w) | ||
106 | z3 = konst 0 (h,c) | ||
107 | |||
108 | -- TODO: could be simplified using future empty arrays | ||
109 | 142 | ||
110 | 143 | ||
111 | separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t | 144 | separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t |