summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Numeric/Container.hs99
-rw-r--r--lib/Numeric/HMatrix.hs43
-rw-r--r--lib/Numeric/HMatrix/Data.hs6
-rw-r--r--lib/Numeric/HMatrix/Devel.hs13
-rw-r--r--lib/Numeric/LinearAlgebra/Util.hs8
-rw-r--r--lib/Numeric/Matrix.hs27
6 files changed, 143 insertions, 53 deletions
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index d1ce588..a71fdfe 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -37,7 +37,8 @@ module Numeric.Container (
37 Container(..), 37 Container(..),
38 -- * Matrix product 38 -- * Matrix product
39 Product(..), 39 Product(..),
40 Contraction(..), 40 Mul(..),
41 Contraction(..), mmul,
41 optimiseMult, 42 optimiseMult,
42 mXm,mXv,vXm,LSDiv(..), cdot, (·), dot, (<.>), 43 mXm,mXv,vXm,LSDiv(..), cdot, (·), dot, (<.>),
43 outer, kronecker, 44 outer, kronecker,
@@ -102,12 +103,19 @@ cdot u v = udot (conj u) v
102 103
103-------------------------------------------------------- 104--------------------------------------------------------
104 105
105class Contraction a b c | a b -> c, a c -> b, b c -> a 106class Contraction a b c | a b -> c, c -> a b
106 where 107 where
107 infixl 7 <> 108 infixr 7 ×
108 {- | Matrix-matrix product, matrix-vector product, and unconjugated dot product 109 {- | Matrix-matrix product, matrix-vector product, and unconjugated dot product
109 110
110>>> let a = (3><4) [1..] :: Matrix Double 111(unicode 0x00d7, multiplication sign)
112
113Examples:
114
115>>> let a = (3><4) [1..] :: Matrix Double
116>>> let v = fromList [1,0,2,-1] :: Vector Double
117>>> let u = fromList [1,2,3] :: Vector Double
118
111>>> a 119>>> a
112(3><4) 120(3><4)
113 [ 1.0, 2.0, 3.0, 4.0 121 [ 1.0, 2.0, 3.0, 4.0
@@ -116,7 +124,7 @@ class Contraction a b c | a b -> c, a c -> b, b c -> a
116 124
117matrix × matrix: 125matrix × matrix:
118 126
119>>> disp 2 (a <> trans a) 127>>> disp 2 (a × trans a)
1203x3 1283x3
121 30 70 110 129 30 70 110
122 70 174 278 130 70 174 278
@@ -124,52 +132,79 @@ matrix × matrix:
124 132
125matrix × vector: 133matrix × vector:
126 134
127>>> a <> fromList [1,0,2,-1::Double] 135>>> a × v
128fromList [3.0,11.0,19.0] 136fromList [3.0,11.0,19.0]
129 137
130vector × matrix:
131
132>>> fromList [1,2,3::Double] <> a
133fromList [38.0,44.0,50.0,56.0]
134
135unconjugated dot product: 138unconjugated dot product:
136 139
137>>> fromList [1,i] <> fromList[2*i+1,3] 140>>> fromList [1,i] × fromList[2*i+1,3]
1381.0 :+ 5.0 1411.0 :+ 5.0
139 142
140-} 143(×) is right associative, so we can write:
141 (<>) :: a -> b -> c
142 144
143instance Product t => Contraction (Vector t) (Vector t) t where 145>>> u × a × v
144 (<>) = udot 14682.0 :: Double
145 147
146instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where 148-}
147 (<>) = mXv 149 (×) :: a -> b -> c
148 150
149instance Product t => Contraction (Vector t) (Matrix t) (Vector t) where 151instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
150 (<>) = vXm 152 (×) = mXv
151 153
152instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where 154instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
155 (×) = mXm
156
157instance Contraction (Vector Double) (Vector Double) Double where
158 (×) = udot
159
160instance Contraction (Vector Float) (Vector Float) Float where
161 (×) = udot
162
163instance Contraction (Vector (Complex Double)) (Vector (Complex Double)) (Complex Double) where
164 (×) = udot
165
166instance Contraction (Vector (Complex Float)) (Vector (Complex Float)) (Complex Float) where
167 (×) = udot
168
169
170-- | alternative function for the matrix product (×)
171mmul :: Contraction a b c => a -> b -> c
172mmul = (×)
173
174--------------------------------------------------------------------------------
175
176class Mul a b c | a b -> c where
177 infixl 7 <>
178 -- | Matrix-matrix, matrix-vector, and vector-matrix products.
179 (<>) :: Product t => a t -> b t -> c t
180
181instance Mul Matrix Matrix Matrix where
153 (<>) = mXm 182 (<>) = mXm
154 183
155-------------------------------------------------------- 184instance Mul Matrix Vector Vector where
185 (<>) m v = flatten $ m <> asColumn v
156 186
157class LSDiv b c | b -> c, c->b where 187instance Mul Vector Matrix Vector where
188 (<>) v m = flatten $ asRow v <> m
189
190--------------------------------------------------------------------------------
191
192class LSDiv c where
158 infixl 7 <\> 193 infixl 7 <\>
159 -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD) 194 -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD)
160 (<\>) :: Field t => Matrix t -> b t -> c t 195 (<\>) :: Field t => Matrix t -> c t -> c t
161 196
162instance LSDiv Vector Vector where 197instance LSDiv Vector where
163 m <\> v = flatten (linearSolveSVD m (reshape 1 v)) 198 m <\> v = flatten (linearSolveSVD m (reshape 1 v))
164 199
165instance LSDiv Matrix Matrix where 200instance LSDiv Matrix where
166 (<\>) = linearSolveSVD 201 (<\>) = linearSolveSVD
167 202
168-------------------------------------------------------- 203--------------------------------------------------------
169 204
170{- | Dot product : @u · v = 'cdot' u v@ 205{- | Dot product : @u · v = 'cdot' u v@
171 206
172 (unicode 0x00b7, Alt-Gr .) 207 (unicode 0x00b7, middle dot, Alt-Gr .)
173 208
174>>> fromList [1,i] · fromList[2*i+1,3] 209>>> fromList [1,i] · fromList[2*i+1,3]
1751.0 :+ (-1.0) 2101.0 :+ (-1.0)
@@ -233,7 +268,15 @@ instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e
233 268
234-------------------------------------------------------------------------------- 269--------------------------------------------------------------------------------
235 270
236-- | Compute mean vector and covariance matrix of the rows of a matrix. 271{- | Compute mean vector and covariance matrix of the rows of a matrix.
272
273>>> meanCov $ gaussianSample 666 1000 (fromList[4,5]) (diagl[2,3])
274(fromList [4.010341078059521,5.0197204699640405],
275(2><2)
276 [ 1.9862461923890056, -1.0127225830525157e-2
277 , -1.0127225830525157e-2, 3.0373954915729318 ])
278
279-}
237meanCov :: Matrix Double -> (Vector Double, Matrix Double) 280meanCov :: Matrix Double -> (Vector Double, Matrix Double)
238meanCov x = (med,cov) where 281meanCov x = (med,cov) where
239 r = rows x 282 r = rows x
@@ -249,7 +292,7 @@ meanCov x = (med,cov) where
249dot :: Product e => Vector e -> Vector e -> e 292dot :: Product e => Vector e -> Vector e -> e
250dot = udot 293dot = udot
251 294
252{-# DEPRECATED (<.>) "use udot or (<>)" #-} 295{-# DEPRECATED (<.>) "use udot or (×)" #-}
253infixl 7 <.> 296infixl 7 <.>
254(<.>) :: Product e => Vector e -> Vector e -> e 297(<.>) :: Product e => Vector e -> Vector e -> e
255(<.>) = udot 298(<.>) = udot
diff --git a/lib/Numeric/HMatrix.hs b/lib/Numeric/HMatrix.hs
index 8e0b4a2..a2f09df 100644
--- a/lib/Numeric/HMatrix.hs
+++ b/lib/Numeric/HMatrix.hs
@@ -16,25 +16,45 @@ module Numeric.HMatrix (
16 -- * Basic types and data processing 16 -- * Basic types and data processing
17 module Numeric.HMatrix.Data, 17 module Numeric.HMatrix.Data,
18 18
19 -- | The standard numeric classes are defined elementwise. 19 -- | The standard numeric classes are defined elementwise:
20 -- 20 --
21 -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double] 21 -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double]
22 -- fromList [3.0,0.0,-6.0] 22 -- fromList [3.0,0.0,-6.0]
23 -- 23 --
24 -- In arithmetic operations single-element vectors and matrices automatically 24 -- >>> (3><3) [1..9] * ident 3 :: Matrix Double
25 -- expand to match the dimensions of the other operand. 25 -- (3><3)
26 -- [ 1.0, 0.0, 0.0
27 -- , 0.0, 5.0, 0.0
28 -- , 0.0, 0.0, 9.0 ]
29 --
30 -- In arithmetic operations single-element vectors and matrices
31 -- (created from numeric literals or using 'scalar') automatically
32 -- expand to match the dimensions of the other operand:
26 -- 33 --
27 -- >>> 2 * ident 3 34 -- >>> 5 + 2*ident 3 :: Matrix Double
28 -- 2 * ident 3 :: Matrix Double
29 -- (3><3) 35 -- (3><3)
30 -- [ 2.0, 0.0, 0.0 36 -- [ 7.0, 5.0, 5.0
31 -- , 0.0, 2.0, 0.0 37 -- , 5.0, 7.0, 5.0
32 -- , 0.0, 0.0, 2.0 ] 38 -- , 5.0, 5.0, 7.0 ]
33 -- 39 --
34 40
35 -- * Products 41 -- * Products
36 (<>), (·), outer, kronecker, cross, 42 (×),
37 optimiseMult, scale, 43
44 -- | The matrix product is also implemented in the "Data.Monoid" instance for Matrix, where
45 -- single-element matrices (created from numeric literals or using 'scalar')
46 -- are used for scaling.
47 --
48 -- >>> let m = (2><3)[1..] :: Matrix Double
49 -- >>> m <> 2 <> diagl[0.5,1,0]
50 -- (2><3)
51 -- [ 1.0, 4.0, 0.0
52 -- , 4.0, 10.0, 0.0 ]
53 --
54 -- mconcat uses 'optimiseMult' to get the optimal association order.
55
56 (·), outer, kronecker, cross,
57 scale,
38 sumElements, prodElements, absSum, 58 sumElements, prodElements, absSum,
39 59
40 -- * Linear Systems 60 -- * Linear Systems
@@ -103,7 +123,7 @@ module Numeric.HMatrix (
103 rand, randn, RandDist(..), randomVector, gaussianSample, uniformSample, 123 rand, randn, RandDist(..), randomVector, gaussianSample, uniformSample,
104 124
105 -- * Misc 125 -- * Misc
106 meanCov, peps, relativeError, haussholder 126 meanCov, peps, relativeError, haussholder, optimiseMult, udot, cdot, mmul
107) where 127) where
108 128
109import Numeric.HMatrix.Data 129import Numeric.HMatrix.Data
@@ -114,4 +134,3 @@ import Numeric.Container
114import Numeric.LinearAlgebra.Algorithms 134import Numeric.LinearAlgebra.Algorithms
115import Numeric.LinearAlgebra.Util 135import Numeric.LinearAlgebra.Util
116 136
117
diff --git a/lib/Numeric/HMatrix/Data.hs b/lib/Numeric/HMatrix/Data.hs
index 49dad10..288b0af 100644
--- a/lib/Numeric/HMatrix/Data.hs
+++ b/lib/Numeric/HMatrix/Data.hs
@@ -51,12 +51,6 @@ module Numeric.HMatrix.Data(
51 51
52-- * Conversion 52-- * Conversion
53 Convert(..), 53 Convert(..),
54 Complexable(),
55 RealElement(),
56
57 RealOf, ComplexOf, SingleOf, DoubleOf,
58
59 IndexOf,
60 54
61 -- * Misc 55 -- * Misc
62 arctan2, 56 arctan2,
diff --git a/lib/Numeric/HMatrix/Devel.hs b/lib/Numeric/HMatrix/Devel.hs
index 37bf826..7363477 100644
--- a/lib/Numeric/HMatrix/Devel.hs
+++ b/lib/Numeric/HMatrix/Devel.hs
@@ -50,15 +50,20 @@ module Numeric.HMatrix.Devel(
50 mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, 50 mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_,
51 liftMatrix, liftMatrix2, liftMatrix2Auto, 51 liftMatrix, liftMatrix2, liftMatrix2Auto,
52 52
53 -- * Misc 53 -- * Auxiliary classes
54 Element, Container, Product, Contraction, LSDiv, Field 54 Element, Container, Product, Contraction, LSDiv,
55 Complexable(), RealElement(),
56 RealOf, ComplexOf, SingleOf, DoubleOf,
57 IndexOf,
58 Field,
55) where 59) where
56 60
57import Data.Packed.Foreign 61import Data.Packed.Foreign
58import Data.Packed.Development 62import Data.Packed.Development
59import Data.Packed.ST 63import Data.Packed.ST
60import Numeric.Container(Container,Contraction,LSDiv,Product) 64import Numeric.Container(Container,Contraction,LSDiv,Product,
65 Complexable(),RealElement(),
66 RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf)
61import Data.Packed 67import Data.Packed
62import Numeric.LinearAlgebra.Algorithms(Field) 68import Numeric.LinearAlgebra.Algorithms(Field)
63 69
64
diff --git a/lib/Numeric/LinearAlgebra/Util.hs b/lib/Numeric/LinearAlgebra/Util.hs
index 21b6188..7164827 100644
--- a/lib/Numeric/LinearAlgebra/Util.hs
+++ b/lib/Numeric/LinearAlgebra/Util.hs
@@ -134,7 +134,7 @@ a & b = vjoin [a,b]
134 134
135{- | horizontal concatenation of real matrices 135{- | horizontal concatenation of real matrices
136 136
137 (0x00a6 broken bar) 137 (unicode 0x00a6, broken bar)
138 138
139>>> ident 3 ¦ konst 7 (3,4) 139>>> ident 3 ¦ konst 7 (3,4)
140(3><7) 140(3><7)
@@ -149,7 +149,7 @@ a ¦ b = fromBlocks [[a,b]]
149 149
150-- | vertical concatenation of real matrices 150-- | vertical concatenation of real matrices
151-- 151--
152-- (0x2014, em dash) 152-- (unicode 0x2014, em dash)
153(——) :: Matrix Double -> Matrix Double -> Matrix Double 153(——) :: Matrix Double -> Matrix Double -> Matrix Double
154infixl 2 —— 154infixl 2 ——
155a —— b = fromBlocks [[a],[b]] 155a —— b = fromBlocks [[a],[b]]
@@ -179,7 +179,9 @@ infixl 9 ?
179(?) :: Element t => Matrix t -> [Int] -> Matrix t 179(?) :: Element t => Matrix t -> [Int] -> Matrix t
180(?) = flip extractRows 180(?) = flip extractRows
181 181
182-- | (00BF) extract selected columns 182-- | extract selected columns
183--
184-- (unicode 0x00bf, inverted question mark)
183infixl 9 ¿ 185infixl 9 ¿
184(¿) :: Element t => Matrix t -> [Int] -> Matrix t 186(¿) :: Element t => Matrix t -> [Int] -> Matrix t
185m ¿ ks = trans . extractRows ks . trans $ m 187m ¿ ks = trans . extractRows ks . trans $ m
diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs
index 8397911..e285ff2 100644
--- a/lib/Numeric/Matrix.hs
+++ b/lib/Numeric/Matrix.hs
@@ -28,6 +28,8 @@ module Numeric.Matrix (
28------------------------------------------------------------------- 28-------------------------------------------------------------------
29 29
30import Numeric.Container 30import Numeric.Container
31import qualified Data.Monoid as M
32import Data.List(partition)
31 33
32------------------------------------------------------------------- 34-------------------------------------------------------------------
33 35
@@ -69,3 +71,28 @@ instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matri
69 (**) = liftMatrix2Auto (**) 71 (**) = liftMatrix2Auto (**)
70 sqrt = liftMatrix sqrt 72 sqrt = liftMatrix sqrt
71 pi = (1><1) [pi] 73 pi = (1><1) [pi]
74
75--------------------------------------------------------------------------------
76
77isScalar m = rows m == 1 && cols m == 1
78
79adaptScalarM f1 f2 f3 x y
80 | isScalar x = f1 (x @@>(0,0) ) y
81 | isScalar y = f3 x (y @@>(0,0) )
82 | otherwise = f2 x y
83
84instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matrix t)
85 where
86 mempty = 1
87 mappend = adaptScalarM scale mXm (flip scale)
88
89 mconcat xs = work (partition isScalar xs)
90 where
91 work (ss,[]) = product ss
92 work (ss,ms) = scale' (product ss) (optimiseMult ms)
93 scale' x m
94 | isScalar x && x00 == 1 = m
95 | otherwise = scale x00 m
96 where
97 x00 = x @@> (0,0)
98