summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-04-24 10:30:01 +0200
committerAlberto Ruiz <aruiz@um.es>2014-04-24 10:30:01 +0200
commit6fbed842525491e280448a00a4b5426e6830ccaa (patch)
treeb78d2712f7ac23845fc29120d3a9fbcd7d189004
parent7c5adb83c9cb632c39eb2d844a1496e2a7a23e8b (diff)
cdot and (×)
added cdot dot renamed to udot <.> changed to cdot and moved to Numeric.LinearAlgebra.Util new general contraction operator (×) Plot functions moved to Numeric.LinearAlgebra.Util
-rw-r--r--examples/multiply.hs10
-rw-r--r--hmatrix.cabal4
-rw-r--r--lib/Numeric/Container.hs49
-rw-r--r--lib/Numeric/ContainerBoot.hs12
-rw-r--r--lib/Numeric/LinearAlgebra/Util.hs24
-rw-r--r--lib/Numeric/LinearAlgebra/Util/Convolution.hs2
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs26
7 files changed, 93 insertions, 34 deletions
diff --git a/examples/multiply.hs b/examples/multiply.hs
index d7c74ee..fbfb9d7 100644
--- a/examples/multiply.hs
+++ b/examples/multiply.hs
@@ -6,7 +6,7 @@
6-- , OverlappingInstances 6-- , OverlappingInstances
7 , UndecidableInstances #-} 7 , UndecidableInstances #-}
8 8
9import Numeric.LinearAlgebra 9import Numeric.LinearAlgebra hiding (Contraction(..))
10 10
11class Scaling a b c | a b -> c where 11class Scaling a b c | a b -> c where
12 -- ^ 0x22C5 8901 DOT OPERATOR, scaling 12 -- ^ 0x22C5 8901 DOT OPERATOR, scaling
@@ -43,7 +43,7 @@ instance Container Vector t => Scaling (Matrix t) t (Matrix t) where
43 43
44 44
45instance Product t => Contraction (Vector t) (Vector t) t where 45instance Product t => Contraction (Vector t) (Vector t) t where
46 (×) = dot 46 (×) = udot
47 47
48instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where 48instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
49 (×) = mXv 49 (×) = mXv
@@ -90,9 +90,9 @@ c = v ⊗ m ⊗ v ⊗ m
90d = s ⋅ (3 |> [10,20..] :: Vector Double) 90d = s ⋅ (3 |> [10,20..] :: Vector Double)
91 91
92main = do 92main = do
93 print $ scale s v <> m <.> v 93 print $ (scale s v <> m) `udot` v
94 print $ scale s v <.> (m <> v) 94 print $ scale s v `udot` (m <> v)
95 print $ s * (v <> m <.> v) 95 print $ s * ((v <> m) `udot` v)
96 print $ s ⋅ v × m × v 96 print $ s ⋅ v × m × v
97 print a 97 print a
98 print (b == c) 98 print (b == c)
diff --git a/hmatrix.cabal b/hmatrix.cabal
index 23e81dd..e9107f3 100644
--- a/hmatrix.cabal
+++ b/hmatrix.cabal
@@ -113,7 +113,6 @@ library
113 Numeric.LinearAlgebra.LAPACK, 113 Numeric.LinearAlgebra.LAPACK,
114 Numeric.LinearAlgebra.Algorithms, 114 Numeric.LinearAlgebra.Algorithms,
115 Numeric.LinearAlgebra.Util, 115 Numeric.LinearAlgebra.Util,
116 Graphics.Plot,
117 Data.Packed.ST, 116 Data.Packed.ST,
118 Data.Packed.Development 117 Data.Packed.Development
119 other-modules: Data.Packed.Internal, 118 other-modules: Data.Packed.Internal,
@@ -130,7 +129,8 @@ library
130 Numeric.Chain, 129 Numeric.Chain,
131 Numeric.Vector, 130 Numeric.Vector,
132 Numeric.Matrix, 131 Numeric.Matrix,
133 Numeric.LinearAlgebra.Util.Convolution 132 Numeric.LinearAlgebra.Util.Convolution,
133 Graphics.Plot
134 134
135 C-sources: lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c, 135 C-sources: lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c,
136 lib/Numeric/GSL/gsl-aux.c 136 lib/Numeric/GSL/gsl-aux.c
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index 345c1f1..ed6714f 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -8,8 +8,8 @@
8----------------------------------------------------------------------------- 8-----------------------------------------------------------------------------
9-- | 9-- |
10-- Module : Numeric.Container 10-- Module : Numeric.Container
11-- Copyright : (c) Alberto Ruiz 2010 11-- Copyright : (c) Alberto Ruiz 2010-14
12-- License : GPL-style 12-- License : GPL
13-- 13--
14-- Maintainer : Alberto Ruiz <aruiz@um.es> 14-- Maintainer : Alberto Ruiz <aruiz@um.es>
15-- Stability : provisional 15-- Stability : provisional
@@ -35,8 +35,9 @@ module Numeric.Container (
35 Container(..), 35 Container(..),
36 -- * Matrix product 36 -- * Matrix product
37 Product(..), 37 Product(..),
38 Contraction(..),
38 optimiseMult, 39 optimiseMult,
39 mXm,mXv,vXm,(<.>),Mul(..),LSDiv(..), 40 mXm,mXv,vXm,Mul(..),LSDiv(..), cdot,
40 outer, kronecker, 41 outer, kronecker,
41 -- * Random numbers 42 -- * Random numbers
42 RandDist(..), 43 RandDist(..),
@@ -95,12 +96,9 @@ linspace :: (Enum e, Container Vector e) => Int -> (e, e) -> Vector e
95linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1] 96linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1]
96 where s = (b-a)/fromIntegral (n-1) 97 where s = (b-a)/fromIntegral (n-1)
97 98
98-- | Dot product: @u \<.\> v = dot u v@ 99-- | dot product: @cdot u v = 'udot' ('conj' u) v@
99(<.>) :: Product t => Vector t -> Vector t -> t 100cdot :: (Container Vector t, Product t) => Vector t -> Vector t -> t
100infixl 7 <.> 101cdot u v = udot (conj u) v
101(<.>) = dot
102
103
104 102
105-------------------------------------------------------- 103--------------------------------------------------------
106 104
@@ -143,3 +141,36 @@ meanCov x = (med,cov) where
143 xc = x `sub` meds 141 xc = x `sub` meds
144 cov = scale (recip (fromIntegral (r-1))) (trans xc `mXm` xc) 142 cov = scale (recip (fromIntegral (r-1))) (trans xc `mXm` xc)
145 143
144--------------------------------------------------------------------------------
145
146-- | matrix-matrix product, matrix-vector product, unconjugated dot product, and scaling
147class Contraction a b c | a b -> c
148 where
149 -- ^ 0x00d7 multiplication sign
150 infixl 7 ×
151 (×) :: a -> b -> c
152
153instance Product t => Contraction (Vector t) (Vector t) t where
154 (×) = udot
155
156instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
157 (×) = mXv
158
159instance Product t => Contraction (Vector t) (Matrix t) (Vector t) where
160 (×) = vXm
161
162instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
163 (×) = mXm
164
165instance Container Vector t => Contraction t (Vector t) (Vector t) where
166 (×) = scale
167
168instance Container Vector t => Contraction (Vector t) t (Vector t) where
169 (×) = flip scale
170
171instance Container Matrix t => Contraction t (Matrix t) (Matrix t) where
172 (×) = scale
173
174instance Container Matrix t => Contraction (Matrix t) t (Matrix t) where
175 (×) = flip scale
176
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index dcb326c..4c5bbd0 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -327,8 +327,8 @@ instance (Container Vector a) => Container Matrix a where
327class Element e => Product e where 327class Element e => Product e where
328 -- | matrix product 328 -- | matrix product
329 multiply :: Matrix e -> Matrix e -> Matrix e 329 multiply :: Matrix e -> Matrix e -> Matrix e
330 -- | dot (inner) product 330 -- | (unconjugated) dot product
331 dot :: Vector e -> Vector e -> e 331 udot :: Vector e -> Vector e -> e
332 -- | sum of absolute value of elements (differs in complex case from @norm1@) 332 -- | sum of absolute value of elements (differs in complex case from @norm1@)
333 absSum :: Vector e -> RealOf e 333 absSum :: Vector e -> RealOf e
334 -- | sum of absolute value of elements 334 -- | sum of absolute value of elements
@@ -341,7 +341,7 @@ class Element e => Product e where
341instance Product Float where 341instance Product Float where
342 norm2 = toScalarF Norm2 342 norm2 = toScalarF Norm2
343 absSum = toScalarF AbsSum 343 absSum = toScalarF AbsSum
344 dot = dotF 344 udot = dotF
345 norm1 = toScalarF AbsSum 345 norm1 = toScalarF AbsSum
346 normInf = maxElement . vectorMapF Abs 346 normInf = maxElement . vectorMapF Abs
347 multiply = multiplyF 347 multiply = multiplyF
@@ -349,7 +349,7 @@ instance Product Float where
349instance Product Double where 349instance Product Double where
350 norm2 = toScalarR Norm2 350 norm2 = toScalarR Norm2
351 absSum = toScalarR AbsSum 351 absSum = toScalarR AbsSum
352 dot = dotR 352 udot = dotR
353 norm1 = toScalarR AbsSum 353 norm1 = toScalarR AbsSum
354 normInf = maxElement . vectorMapR Abs 354 normInf = maxElement . vectorMapR Abs
355 multiply = multiplyR 355 multiply = multiplyR
@@ -357,7 +357,7 @@ instance Product Double where
357instance Product (Complex Float) where 357instance Product (Complex Float) where
358 norm2 = toScalarQ Norm2 358 norm2 = toScalarQ Norm2
359 absSum = toScalarQ AbsSum 359 absSum = toScalarQ AbsSum
360 dot = dotQ 360 udot = dotQ
361 norm1 = sumElements . fst . fromComplex . vectorMapQ Abs 361 norm1 = sumElements . fst . fromComplex . vectorMapQ Abs
362 normInf = maxElement . fst . fromComplex . vectorMapQ Abs 362 normInf = maxElement . fst . fromComplex . vectorMapQ Abs
363 multiply = multiplyQ 363 multiply = multiplyQ
@@ -365,7 +365,7 @@ instance Product (Complex Float) where
365instance Product (Complex Double) where 365instance Product (Complex Double) where
366 norm2 = toScalarC Norm2 366 norm2 = toScalarC Norm2
367 absSum = toScalarC AbsSum 367 absSum = toScalarC AbsSum
368 dot = dotC 368 udot = dotC
369 norm1 = sumElements . fst . fromComplex . vectorMapC Abs 369 norm1 = sumElements . fst . fromComplex . vectorMapC Abs
370 normInf = maxElement . fst . fromComplex . vectorMapC Abs 370 normInf = maxElement . fst . fromComplex . vectorMapC Abs
371 multiply = multiplyC 371 multiply = multiplyC
diff --git a/lib/Numeric/LinearAlgebra/Util.hs b/lib/Numeric/LinearAlgebra/Util.hs
index f7c40d7..f6aa7da 100644
--- a/lib/Numeric/LinearAlgebra/Util.hs
+++ b/lib/Numeric/LinearAlgebra/Util.hs
@@ -19,10 +19,11 @@ module Numeric.LinearAlgebra.Util(
19 diagl, 19 diagl,
20 row, 20 row,
21 col, 21 col,
22 (&),(!), (¦), (#), 22 (&), (¦), (#),
23 (?),(¿), 23 (?), (¿),
24 rand, randn, 24 rand, randn,
25 cross, 25 cross,
26 (<.>),
26 norm, 27 norm,
27 unitary, 28 unitary,
28 mt, 29 mt,
@@ -45,7 +46,13 @@ module Numeric.LinearAlgebra.Util(
45 vec, 46 vec,
46 vech, 47 vech,
47 dup, 48 dup,
48 vtrans 49 vtrans,
50 -- * Plot
51 mplot,
52 plot, parametricPlot,
53 splot, mesh, meshdom,
54 matrixToPGM, imshow,
55 gnuplotX, gnuplotpdf, gnuplotWin
49) where 56) where
50 57
51import Numeric.Container 58import Numeric.Container
@@ -55,6 +62,7 @@ import Numeric.Vector()
55 62
56import System.Random(randomIO) 63import System.Random(randomIO)
57import Numeric.LinearAlgebra.Util.Convolution 64import Numeric.LinearAlgebra.Util.Convolution
65import Graphics.Plot
58 66
59 67
60disp :: Int -> Matrix Double -> IO () 68disp :: Int -> Matrix Double -> IO ()
@@ -99,11 +107,6 @@ infixl 3 &
99(&) :: Vector Double -> Vector Double -> Vector Double 107(&) :: Vector Double -> Vector Double -> Vector Double
100a & b = vjoin [a,b] 108a & b = vjoin [a,b]
101 109
102-- | horizontal concatenation of real matrices
103infixl 3 !
104(!) :: Matrix Double -> Matrix Double -> Matrix Double
105a ! b = fromBlocks [[a,b]]
106
107-- | (00A6) horizontal concatenation of real matrices 110-- | (00A6) horizontal concatenation of real matrices
108infixl 3 ¦ 111infixl 3 ¦
109(¦) :: Matrix Double -> Matrix Double -> Matrix Double 112(¦) :: Matrix Double -> Matrix Double -> Matrix Double
@@ -161,6 +164,11 @@ size m = (rows m, cols m)
161mt :: Matrix Double -> Matrix Double 164mt :: Matrix Double -> Matrix Double
162mt = trans . inv 165mt = trans . inv
163 166
167-- | dot product: @u \<.\> v = 'cdot' u v@
168(<.>) :: (Container Vector t, Product t) => Vector t -> Vector t -> t
169infixl 7 <.>
170u <.> v = cdot u v
171
164---------------------------------------------------------------------- 172----------------------------------------------------------------------
165 173
166-- | Matrix of pairwise squared distances of row vectors 174-- | Matrix of pairwise squared distances of row vectors
diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
index be9b1eb..1043614 100644
--- a/lib/Numeric/LinearAlgebra/Util/Convolution.hs
+++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
@@ -75,7 +75,7 @@ matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ]
75corr2 :: Product a => Matrix a -> Matrix a -> Matrix a 75corr2 :: Product a => Matrix a -> Matrix a -> Matrix a
76-- ^ 2D correlation 76-- ^ 2D correlation
77corr2 ker mat = dims 77corr2 ker mat = dims
78 . concatMap (map ((<.> ker') . flatten) . matSS c . trans) 78 . concatMap (map (udot ker' . flatten) . matSS c . trans)
79 . matSS r $ mat 79 . matSS r $ mat
80 where 80 where
81 r = rows ker 81 r = rows ker
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index 99c0c91..7e1799e 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -43,6 +43,8 @@ import Control.Arrow((***))
43import Debug.Trace 43import Debug.Trace
44import Control.Monad(when) 44import Control.Monad(when)
45import Numeric.LinearAlgebra.Util hiding (ones,row,col) 45import Numeric.LinearAlgebra.Util hiding (ones,row,col)
46import Control.Applicative
47import Control.Monad(ap)
46 48
47import Data.Packed.ST 49import Data.Packed.ST
48 50
@@ -266,9 +268,9 @@ normsVTest = TestList [
266 ] where v = fromList [1,-2,3:+4] :: Vector (Complex Double) 268 ] where v = fromList [1,-2,3:+4] :: Vector (Complex Double)
267 x = fromList [1,2,-3] :: Vector Double 269 x = fromList [1,2,-3] :: Vector Double
268#ifndef NONORMVTEST 270#ifndef NONORMVTEST
269 norm2PropR a = norm2 a =~= sqrt (dot a a) 271 norm2PropR a = norm2 a =~= sqrt (udot a a)
270#endif 272#endif
271 norm2PropC a = norm2 a =~= realPart (sqrt (dot a (conj a))) 273 norm2PropC a = norm2 a =~= realPart (sqrt (udot a (conj a)))
272 a =~= b = fromList [a] |~| fromList [b] 274 a =~= b = fromList [a] |~| fromList [b]
273 275
274normsMTest = TestList [ 276normsMTest = TestList [
@@ -330,6 +332,15 @@ conjuTest m = mapVector conjugate (flatten (trans m)) == flatten (ctrans m)
330 332
331newtype State s a = State { runState :: s -> (a,s) } 333newtype State s a = State { runState :: s -> (a,s) }
332 334
335instance Functor (State s)
336 where
337 fmap f x = pure f <*> x
338
339instance Applicative (State s)
340 where
341 pure = return
342 (<*>) = ap
343
333instance Monad (State s) where 344instance Monad (State s) where
334 return a = State $ \s -> (a,s) 345 return a = State $ \s -> (a,s)
335 m >>= f = State $ \s -> let (a,s') = runState m s 346 m >>= f = State $ \s -> let (a,s') = runState m s
@@ -347,6 +358,15 @@ evalState m s = let (a,s') = runState m s
347 358
348newtype MaybeT m a = MaybeT { runMaybeT :: m (Maybe a) } 359newtype MaybeT m a = MaybeT { runMaybeT :: m (Maybe a) }
349 360
361instance Monad m => Functor (MaybeT m)
362 where
363 fmap f x = pure f <*> x
364
365instance Monad m => Applicative (MaybeT m)
366 where
367 pure = return
368 (<*>) = ap
369
350instance Monad m => Monad (MaybeT m) where 370instance Monad m => Monad (MaybeT m) where
351 return a = MaybeT $ return $ Just a 371 return a = MaybeT $ return $ Just a
352 m >>= f = MaybeT $ do 372 m >>= f = MaybeT $ do
@@ -640,7 +660,7 @@ a |~~| b = a :~6~: b
640 660
641makeUnitary v | realPart n > 1 = v / scalar n 661makeUnitary v | realPart n > 1 = v / scalar n
642 | otherwise = v 662 | otherwise = v
643 where n = sqrt (conj v <.> v) 663 where n = sqrt (conj v `udot` v)
644 664
645-- -- | Some additional tests on big matrices. They take a few minutes. 665-- -- | Some additional tests on big matrices. They take a few minutes.
646-- runBigTests :: IO () 666-- runBigTests :: IO ()