diff options
author | Alberto Ruiz <aruiz@um.es> | 2012-06-15 11:09:15 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2012-06-15 11:09:15 +0200 |
commit | 899c1f71f64f49c5d3b2c264501565227977cd9c (patch) | |
tree | 4abe298e0be5bcfaad0a8b5bad137229ece13d8c | |
parent | f10af430ec9ab1cc71f8931dec9a4247fc780933 (diff) |
kronecker tools
-rw-r--r-- | CHANGES.md | 4 | ||||
-rw-r--r-- | hmatrix.cabal | 4 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Util.hs | 56 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Util/Convolution.hs | 18 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 20 |
5 files changed, 83 insertions, 19 deletions
@@ -1,7 +1,9 @@ | |||
1 | 0.14.1.0 | 1 | 0.14.1.0 |
2 | -------- | 2 | -------- |
3 | 3 | ||
4 | - convolution | 4 | - In Numeric.LinearAlgebra.Util: |
5 | convolution: corr, conv, corr2, conv2, separable, corrMin | ||
6 | kronecker: vec, vech, dup, vtrans | ||
5 | 7 | ||
6 | 0.14.0.0 | 8 | 0.14.0.0 |
7 | -------- | 9 | -------- |
diff --git a/hmatrix.cabal b/hmatrix.cabal index 0ef5cd7..21d6fa4 100644 --- a/hmatrix.cabal +++ b/hmatrix.cabal | |||
@@ -112,7 +112,6 @@ library | |||
112 | Numeric.LinearAlgebra.LAPACK, | 112 | Numeric.LinearAlgebra.LAPACK, |
113 | Numeric.LinearAlgebra.Algorithms, | 113 | Numeric.LinearAlgebra.Algorithms, |
114 | Numeric.LinearAlgebra.Util, | 114 | Numeric.LinearAlgebra.Util, |
115 | Numeric.LinearAlgebra.Util.Convolution, | ||
116 | Graphics.Plot, | 115 | Graphics.Plot, |
117 | Data.Packed.ST, | 116 | Data.Packed.ST, |
118 | Data.Packed.Development | 117 | Data.Packed.Development |
@@ -129,7 +128,8 @@ library | |||
129 | Numeric.IO, | 128 | Numeric.IO, |
130 | Numeric.Chain, | 129 | Numeric.Chain, |
131 | Numeric.Vector, | 130 | Numeric.Vector, |
132 | Numeric.Matrix | 131 | Numeric.Matrix, |
132 | Numeric.LinearAlgebra.Util.Convolution | ||
133 | 133 | ||
134 | C-sources: lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c, | 134 | C-sources: lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c, |
135 | lib/Numeric/GSL/gsl-aux.c | 135 | lib/Numeric/GSL/gsl-aux.c |
diff --git a/lib/Numeric/LinearAlgebra/Util.hs b/lib/Numeric/LinearAlgebra/Util.hs index 79b8774..25eb239 100644 --- a/lib/Numeric/LinearAlgebra/Util.hs +++ b/lib/Numeric/LinearAlgebra/Util.hs | |||
@@ -1,3 +1,4 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
1 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
2 | {- | | 3 | {- | |
3 | Module : Numeric.LinearAlgebra.Util | 4 | Module : Numeric.LinearAlgebra.Util |
@@ -11,6 +12,7 @@ Stability : provisional | |||
11 | ----------------------------------------------------------------------------- | 12 | ----------------------------------------------------------------------------- |
12 | 13 | ||
13 | module Numeric.LinearAlgebra.Util( | 14 | module Numeric.LinearAlgebra.Util( |
15 | -- * Convenience functions for real elements | ||
14 | disp, | 16 | disp, |
15 | zeros, ones, | 17 | zeros, ones, |
16 | diagl, | 18 | diagl, |
@@ -19,11 +21,24 @@ module Numeric.LinearAlgebra.Util( | |||
19 | (&),(!), (#), | 21 | (&),(!), (#), |
20 | rand, randn, | 22 | rand, randn, |
21 | cross, | 23 | cross, |
22 | norm | 24 | norm, |
25 | -- * Convolution | ||
26 | -- ** 1D | ||
27 | corr, conv, corrMin, | ||
28 | -- ** 2D | ||
29 | corr2, conv2, separable, | ||
30 | -- * Tools for the Kronecker product | ||
31 | -- | ||
32 | -- | @`vec` (a \<> x \<> b) == ('trans' b ` 'kronecker' ` a) \<> 'vec' x@ | ||
33 | vec, | ||
34 | vech, | ||
35 | dup, | ||
36 | vtrans | ||
23 | ) where | 37 | ) where |
24 | 38 | ||
25 | import Numeric.LinearAlgebra | 39 | import Numeric.LinearAlgebra hiding (i) |
26 | import System.Random(randomIO) | 40 | import System.Random(randomIO) |
41 | import Numeric.LinearAlgebra.Util.Convolution | ||
27 | 42 | ||
28 | 43 | ||
29 | disp :: Int -> Matrix Double -> IO () | 44 | disp :: Int -> Matrix Double -> IO () |
@@ -87,7 +102,7 @@ col :: [Double] -> Matrix Double | |||
87 | col = asColumn . fromList | 102 | col = asColumn . fromList |
88 | 103 | ||
89 | cross :: Vector Double -> Vector Double -> Vector Double | 104 | cross :: Vector Double -> Vector Double -> Vector Double |
90 | -- ^ cross product of dimension 3 real vectors | 105 | -- ^ cross product (for three-element real vectors) |
91 | cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3] | 106 | cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3] |
92 | | otherwise = error $ "cross ("++show x++") ("++show y++")" | 107 | | otherwise = error $ "cross ("++show x++") ("++show y++")" |
93 | where | 108 | where |
@@ -98,7 +113,40 @@ cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3] | |||
98 | z3 = x1*y2-x2*y1 | 113 | z3 = x1*y2-x2*y1 |
99 | 114 | ||
100 | norm :: Vector Double -> Double | 115 | norm :: Vector Double -> Double |
101 | -- ^ 2-norm of real vectors | 116 | -- ^ 2-norm of real vector |
102 | norm = pnorm PNorm2 | 117 | norm = pnorm PNorm2 |
103 | 118 | ||
119 | -------------------------------------------------------------------------------- | ||
120 | |||
121 | vec :: Element t => Matrix t -> Vector t | ||
122 | -- ^ stacking of columns | ||
123 | vec = flatten . trans | ||
124 | |||
125 | |||
126 | vech :: Element t => Matrix t -> Vector t | ||
127 | -- ^ half-vectorization (of the lower triangular part) | ||
128 | vech m = join . zipWith f [0..] . toColumns $ m | ||
129 | where | ||
130 | f k v = subVector k (dim v - k) v | ||
131 | |||
132 | |||
133 | dup :: (Num t, Num (Vector t), Element t) => Int -> Matrix t | ||
134 | -- ^ duplication matrix (@'dup' k \<> 'vech' m == 'vec' m@, for symmetric m of 'dim' k) | ||
135 | dup k = trans $ fromRows $ map f es | ||
136 | where | ||
137 | rs = zip [0..] (toRows (ident (k^(2::Int)))) | ||
138 | es = [(i,j) | j <- [0..k-1], i <- [0..k-1], i>=j ] | ||
139 | f (i,j) | i == j = g (k*j + i) | ||
140 | | otherwise = g (k*j + i) + g (k*i + j) | ||
141 | g j = v | ||
142 | where | ||
143 | Just v = lookup j rs | ||
144 | |||
145 | |||
146 | vtrans :: Element t => Int -> Matrix t -> Matrix t | ||
147 | -- ^ generalized \"vector\" transposition: @'vtrans' 1 == 'trans'@, and @'vtrans' ('rows' m) m == 'asColumn' ('vec' m)@ | ||
148 | vtrans p m | r == 0 = fromBlocks . map (map asColumn . takesV (replicate q p)) . toColumns $ m | ||
149 | | otherwise = error $ "vtrans " ++ show p ++ " of matrix with " ++ show (rows m) ++ " rows" | ||
150 | where | ||
151 | (q,r) = divMod (rows m) p | ||
104 | 152 | ||
diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs index b64b169..32cb188 100644 --- a/lib/Numeric/LinearAlgebra/Util/Convolution.hs +++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs | |||
@@ -12,12 +12,8 @@ Stability : provisional | |||
12 | ----------------------------------------------------------------------------- | 12 | ----------------------------------------------------------------------------- |
13 | 13 | ||
14 | module Numeric.LinearAlgebra.Util.Convolution( | 14 | module Numeric.LinearAlgebra.Util.Convolution( |
15 | -- * 1D | 15 | corr, conv, corrMin, |
16 | corr, conv, | 16 | corr2, conv2, separable |
17 | -- * 2D | ||
18 | corr2, conv2, | ||
19 | -- * Misc | ||
20 | separable, corrMin | ||
21 | ) where | 17 | ) where |
22 | 18 | ||
23 | import Numeric.LinearAlgebra | 19 | import Numeric.LinearAlgebra |
@@ -32,8 +28,8 @@ corr :: Product t => Vector t -- ^ kernel | |||
32 | -> Vector t | 28 | -> Vector t |
33 | {- ^ correlation | 29 | {- ^ correlation |
34 | 30 | ||
35 | @\> (fromList[1,2,3]) (fromList [1..10]) | 31 | >>> corr (fromList[1,2,3]) (fromList [1..10]) |
36 | fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0]@ | 32 | fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0] |
37 | 33 | ||
38 | -} | 34 | -} |
39 | corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker | 35 | corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker |
@@ -43,8 +39,8 @@ corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker | |||
43 | conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t | 39 | conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t |
44 | {- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product) | 40 | {- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product) |
45 | 41 | ||
46 | @\> conv (fromList[1,1]) (fromList [-1,1]) | 42 | >>> conv (fromList[1,1]) (fromList [-1,1]) |
47 | fromList [-1.0,0.0,1.0]@ | 43 | fromList [-1.0,0.0,1.0] |
48 | 44 | ||
49 | -} | 45 | -} |
50 | conv ker v = corr ker' v' | 46 | conv ker v = corr ker' v' |
@@ -113,6 +109,6 @@ conv2 k m = corr2 (fliprl . flipud $ k) pm | |||
113 | 109 | ||
114 | 110 | ||
115 | separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t | 111 | separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t |
116 | -- ^ 2D process implemented as separated 1D processes by rows and columns. | 112 | -- ^ matrix computation implemented as separated vector operations by rows and columns. |
117 | separable f = fromColumns . map f . toColumns . fromRows . map f . toRows | 113 | separable f = fromColumns . map f . toColumns . fromRows . map f . toRows |
118 | 114 | ||
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index 4e9a521..99c0c91 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -42,7 +42,7 @@ import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) | |||
42 | import Control.Arrow((***)) | 42 | import Control.Arrow((***)) |
43 | import Debug.Trace | 43 | import Debug.Trace |
44 | import Control.Monad(when) | 44 | import Control.Monad(when) |
45 | import Numeric.LinearAlgebra.Util.Convolution | 45 | import Numeric.LinearAlgebra.Util hiding (ones,row,col) |
46 | 46 | ||
47 | import Data.Packed.ST | 47 | import Data.Packed.ST |
48 | 48 | ||
@@ -435,6 +435,23 @@ convolutionTest = utest "convolution" ok | |||
435 | 435 | ||
436 | -------------------------------------------------------------------------------- | 436 | -------------------------------------------------------------------------------- |
437 | 437 | ||
438 | kroneckerTest = utest "kronecker" ok | ||
439 | where | ||
440 | a,x,b :: Matrix Double | ||
441 | a = (3><4) [1..] | ||
442 | x = (4><2) [3,5..] | ||
443 | b = (2><5) [0,5..] | ||
444 | v1 = vec (a <> x <> b) | ||
445 | v2 = (trans b `kronecker` a) <> vec x | ||
446 | s = trans b <> b | ||
447 | v3 = vec s | ||
448 | v4 = dup 5 <> vech s | ||
449 | ok = v1 == v2 && v3 == v4 | ||
450 | && vtrans 1 a == trans a | ||
451 | && vtrans (rows a) a == asColumn (vec a) | ||
452 | |||
453 | -------------------------------------------------------------------------------- | ||
454 | |||
438 | 455 | ||
439 | 456 | ||
440 | -- | All tests must pass with a maximum dimension of about 20 | 457 | -- | All tests must pass with a maximum dimension of about 20 |
@@ -611,6 +628,7 @@ runTests n = do | |||
611 | , conformTest | 628 | , conformTest |
612 | , accumTest | 629 | , accumTest |
613 | , convolutionTest | 630 | , convolutionTest |
631 | , kroneckerTest | ||
614 | ] | 632 | ] |
615 | when (errors c + failures c > 0) exitFailure | 633 | when (errors c + failures c > 0) exitFailure |
616 | return () | 634 | return () |