summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2012-06-15 11:09:15 +0200
committerAlberto Ruiz <aruiz@um.es>2012-06-15 11:09:15 +0200
commit899c1f71f64f49c5d3b2c264501565227977cd9c (patch)
tree4abe298e0be5bcfaad0a8b5bad137229ece13d8c
parentf10af430ec9ab1cc71f8931dec9a4247fc780933 (diff)
kronecker tools
-rw-r--r--CHANGES.md4
-rw-r--r--hmatrix.cabal4
-rw-r--r--lib/Numeric/LinearAlgebra/Util.hs56
-rw-r--r--lib/Numeric/LinearAlgebra/Util/Convolution.hs18
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs20
5 files changed, 83 insertions, 19 deletions
diff --git a/CHANGES.md b/CHANGES.md
index 0ce7699..3b19ec9 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,7 +1,9 @@
10.14.1.0 10.14.1.0
2-------- 2--------
3 3
4- convolution 4- In Numeric.LinearAlgebra.Util:
5 convolution: corr, conv, corr2, conv2, separable, corrMin
6 kronecker: vec, vech, dup, vtrans
5 7
60.14.0.0 80.14.0.0
7-------- 9--------
diff --git a/hmatrix.cabal b/hmatrix.cabal
index 0ef5cd7..21d6fa4 100644
--- a/hmatrix.cabal
+++ b/hmatrix.cabal
@@ -112,7 +112,6 @@ library
112 Numeric.LinearAlgebra.LAPACK, 112 Numeric.LinearAlgebra.LAPACK,
113 Numeric.LinearAlgebra.Algorithms, 113 Numeric.LinearAlgebra.Algorithms,
114 Numeric.LinearAlgebra.Util, 114 Numeric.LinearAlgebra.Util,
115 Numeric.LinearAlgebra.Util.Convolution,
116 Graphics.Plot, 115 Graphics.Plot,
117 Data.Packed.ST, 116 Data.Packed.ST,
118 Data.Packed.Development 117 Data.Packed.Development
@@ -129,7 +128,8 @@ library
129 Numeric.IO, 128 Numeric.IO,
130 Numeric.Chain, 129 Numeric.Chain,
131 Numeric.Vector, 130 Numeric.Vector,
132 Numeric.Matrix 131 Numeric.Matrix,
132 Numeric.LinearAlgebra.Util.Convolution
133 133
134 C-sources: lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c, 134 C-sources: lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c,
135 lib/Numeric/GSL/gsl-aux.c 135 lib/Numeric/GSL/gsl-aux.c
diff --git a/lib/Numeric/LinearAlgebra/Util.hs b/lib/Numeric/LinearAlgebra/Util.hs
index 79b8774..25eb239 100644
--- a/lib/Numeric/LinearAlgebra/Util.hs
+++ b/lib/Numeric/LinearAlgebra/Util.hs
@@ -1,3 +1,4 @@
1{-# LANGUAGE FlexibleContexts #-}
1----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
2{- | 3{- |
3Module : Numeric.LinearAlgebra.Util 4Module : Numeric.LinearAlgebra.Util
@@ -11,6 +12,7 @@ Stability : provisional
11----------------------------------------------------------------------------- 12-----------------------------------------------------------------------------
12 13
13module Numeric.LinearAlgebra.Util( 14module Numeric.LinearAlgebra.Util(
15 -- * Convenience functions for real elements
14 disp, 16 disp,
15 zeros, ones, 17 zeros, ones,
16 diagl, 18 diagl,
@@ -19,11 +21,24 @@ module Numeric.LinearAlgebra.Util(
19 (&),(!), (#), 21 (&),(!), (#),
20 rand, randn, 22 rand, randn,
21 cross, 23 cross,
22 norm 24 norm,
25 -- * Convolution
26 -- ** 1D
27 corr, conv, corrMin,
28 -- ** 2D
29 corr2, conv2, separable,
30 -- * Tools for the Kronecker product
31 --
32 -- | @`vec` (a \<> x \<> b) == ('trans' b ` 'kronecker' ` a) \<> 'vec' x@
33 vec,
34 vech,
35 dup,
36 vtrans
23) where 37) where
24 38
25import Numeric.LinearAlgebra 39import Numeric.LinearAlgebra hiding (i)
26import System.Random(randomIO) 40import System.Random(randomIO)
41import Numeric.LinearAlgebra.Util.Convolution
27 42
28 43
29disp :: Int -> Matrix Double -> IO () 44disp :: Int -> Matrix Double -> IO ()
@@ -87,7 +102,7 @@ col :: [Double] -> Matrix Double
87col = asColumn . fromList 102col = asColumn . fromList
88 103
89cross :: Vector Double -> Vector Double -> Vector Double 104cross :: Vector Double -> Vector Double -> Vector Double
90-- ^ cross product of dimension 3 real vectors 105-- ^ cross product (for three-element real vectors)
91cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3] 106cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3]
92 | otherwise = error $ "cross ("++show x++") ("++show y++")" 107 | otherwise = error $ "cross ("++show x++") ("++show y++")"
93 where 108 where
@@ -98,7 +113,40 @@ cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3]
98 z3 = x1*y2-x2*y1 113 z3 = x1*y2-x2*y1
99 114
100norm :: Vector Double -> Double 115norm :: Vector Double -> Double
101-- ^ 2-norm of real vectors 116-- ^ 2-norm of real vector
102norm = pnorm PNorm2 117norm = pnorm PNorm2
103 118
119--------------------------------------------------------------------------------
120
121vec :: Element t => Matrix t -> Vector t
122-- ^ stacking of columns
123vec = flatten . trans
124
125
126vech :: Element t => Matrix t -> Vector t
127-- ^ half-vectorization (of the lower triangular part)
128vech m = join . zipWith f [0..] . toColumns $ m
129 where
130 f k v = subVector k (dim v - k) v
131
132
133dup :: (Num t, Num (Vector t), Element t) => Int -> Matrix t
134-- ^ duplication matrix (@'dup' k \<> 'vech' m == 'vec' m@, for symmetric m of 'dim' k)
135dup k = trans $ fromRows $ map f es
136 where
137 rs = zip [0..] (toRows (ident (k^(2::Int))))
138 es = [(i,j) | j <- [0..k-1], i <- [0..k-1], i>=j ]
139 f (i,j) | i == j = g (k*j + i)
140 | otherwise = g (k*j + i) + g (k*i + j)
141 g j = v
142 where
143 Just v = lookup j rs
144
145
146vtrans :: Element t => Int -> Matrix t -> Matrix t
147-- ^ generalized \"vector\" transposition: @'vtrans' 1 == 'trans'@, and @'vtrans' ('rows' m) m == 'asColumn' ('vec' m)@
148vtrans p m | r == 0 = fromBlocks . map (map asColumn . takesV (replicate q p)) . toColumns $ m
149 | otherwise = error $ "vtrans " ++ show p ++ " of matrix with " ++ show (rows m) ++ " rows"
150 where
151 (q,r) = divMod (rows m) p
104 152
diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
index b64b169..32cb188 100644
--- a/lib/Numeric/LinearAlgebra/Util/Convolution.hs
+++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
@@ -12,12 +12,8 @@ Stability : provisional
12----------------------------------------------------------------------------- 12-----------------------------------------------------------------------------
13 13
14module Numeric.LinearAlgebra.Util.Convolution( 14module Numeric.LinearAlgebra.Util.Convolution(
15 -- * 1D 15 corr, conv, corrMin,
16 corr, conv, 16 corr2, conv2, separable
17 -- * 2D
18 corr2, conv2,
19 -- * Misc
20 separable, corrMin
21) where 17) where
22 18
23import Numeric.LinearAlgebra 19import Numeric.LinearAlgebra
@@ -32,8 +28,8 @@ corr :: Product t => Vector t -- ^ kernel
32 -> Vector t 28 -> Vector t
33{- ^ correlation 29{- ^ correlation
34 30
35@\> (fromList[1,2,3]) (fromList [1..10]) 31>>> corr (fromList[1,2,3]) (fromList [1..10])
36fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0]@ 32fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0]
37 33
38-} 34-}
39corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker 35corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker
@@ -43,8 +39,8 @@ corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker
43conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t 39conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t
44{- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product) 40{- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product)
45 41
46@\> conv (fromList[1,1]) (fromList [-1,1]) 42>>> conv (fromList[1,1]) (fromList [-1,1])
47fromList [-1.0,0.0,1.0]@ 43fromList [-1.0,0.0,1.0]
48 44
49-} 45-}
50conv ker v = corr ker' v' 46conv ker v = corr ker' v'
@@ -113,6 +109,6 @@ conv2 k m = corr2 (fliprl . flipud $ k) pm
113 109
114 110
115separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t 111separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t
116-- ^ 2D process implemented as separated 1D processes by rows and columns. 112-- ^ matrix computation implemented as separated vector operations by rows and columns.
117separable f = fromColumns . map f . toColumns . fromRows . map f . toRows 113separable f = fromColumns . map f . toColumns . fromRows . map f . toRows
118 114
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index 4e9a521..99c0c91 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -42,7 +42,7 @@ import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr)
42import Control.Arrow((***)) 42import Control.Arrow((***))
43import Debug.Trace 43import Debug.Trace
44import Control.Monad(when) 44import Control.Monad(when)
45import Numeric.LinearAlgebra.Util.Convolution 45import Numeric.LinearAlgebra.Util hiding (ones,row,col)
46 46
47import Data.Packed.ST 47import Data.Packed.ST
48 48
@@ -435,6 +435,23 @@ convolutionTest = utest "convolution" ok
435 435
436-------------------------------------------------------------------------------- 436--------------------------------------------------------------------------------
437 437
438kroneckerTest = utest "kronecker" ok
439 where
440 a,x,b :: Matrix Double
441 a = (3><4) [1..]
442 x = (4><2) [3,5..]
443 b = (2><5) [0,5..]
444 v1 = vec (a <> x <> b)
445 v2 = (trans b `kronecker` a) <> vec x
446 s = trans b <> b
447 v3 = vec s
448 v4 = dup 5 <> vech s
449 ok = v1 == v2 && v3 == v4
450 && vtrans 1 a == trans a
451 && vtrans (rows a) a == asColumn (vec a)
452
453--------------------------------------------------------------------------------
454
438 455
439 456
440-- | All tests must pass with a maximum dimension of about 20 457-- | All tests must pass with a maximum dimension of about 20
@@ -611,6 +628,7 @@ runTests n = do
611 , conformTest 628 , conformTest
612 , accumTest 629 , accumTest
613 , convolutionTest 630 , convolutionTest
631 , kroneckerTest
614 ] 632 ]
615 when (errors c + failures c > 0) exitFailure 633 when (errors c + failures c > 0) exitFailure
616 return () 634 return ()