diff options
author | Alberto Ruiz <aruiz@um.es> | 2010-09-20 17:08:34 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2010-09-20 17:08:34 +0000 |
commit | 05908719a7323110ba1955038d8341a8b7483351 (patch) | |
tree | f1f1fe28a8db64675dacc7eb4ec79d36e8174588 /lib | |
parent | 482b533c3fbfcd75d6c5c1d3ce32585bf9fc2ad7 (diff) |
generalized diagRect
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 5 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 50 | ||||
-rw-r--r-- | lib/Data/Packed/ST.hs | 8 | ||||
-rw-r--r-- | lib/Numeric/Container.hs | 15 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 8 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Instances.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Properties.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/Matrix.hs | 1 |
8 files changed, 44 insertions, 47 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index c0824a3..94b56cf 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -221,13 +221,13 @@ where r is the desired number of rows.) | |||
221 | , 9.0, 10.0, 11.0, 12.0 ]@ | 221 | , 9.0, 10.0, 11.0, 12.0 ]@ |
222 | 222 | ||
223 | -} | 223 | -} |
224 | reshape :: Element t => Int -> Vector t -> Matrix t | 224 | reshape :: Storable t => Int -> Vector t -> Matrix t |
225 | reshape c v = matrixFromVector RowMajor c v | 225 | reshape c v = matrixFromVector RowMajor c v |
226 | 226 | ||
227 | singleton x = reshape 1 (fromList [x]) | 227 | singleton x = reshape 1 (fromList [x]) |
228 | 228 | ||
229 | -- | application of a vector function on the flattened matrix elements | 229 | -- | application of a vector function on the flattened matrix elements |
230 | liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 230 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
231 | liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) | 231 | liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) |
232 | liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) | 232 | liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) |
233 | 233 | ||
@@ -246,7 +246,6 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | |||
246 | ------------------------------------------------------------------ | 246 | ------------------------------------------------------------------ |
247 | 247 | ||
248 | -- | Supported element types for basic matrix operations. | 248 | -- | Supported element types for basic matrix operations. |
249 | --class (Storable a, Floating a) => Element a where | ||
250 | class (Storable a) => Element a where | 249 | class (Storable a) => Element a where |
251 | subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position | 250 | subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position |
252 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 251 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index b8c309c..ea16748 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -22,7 +22,7 @@ module Data.Packed.Matrix ( | |||
22 | Element, | 22 | Element, |
23 | Matrix,rows,cols, | 23 | Matrix,rows,cols, |
24 | (><), | 24 | (><), |
25 | trans, ctrans, | 25 | trans, |
26 | reshape, flatten, | 26 | reshape, flatten, |
27 | fromLists, toLists, buildMatrix, | 27 | fromLists, toLists, buildMatrix, |
28 | (@@>), | 28 | (@@>), |
@@ -33,7 +33,7 @@ module Data.Packed.Matrix ( | |||
33 | flipud, fliprl, | 33 | flipud, fliprl, |
34 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, | 34 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, |
35 | extractRows, | 35 | extractRows, |
36 | ident, diag, diagRect, takeDiag, | 36 | diagRect, takeDiag, |
37 | liftMatrix, liftMatrix2, liftMatrix2Auto, | 37 | liftMatrix, liftMatrix2, liftMatrix2Auto, |
38 | dispf, disps, dispcf, vecdisp, latexFormat, format, | 38 | dispf, disps, dispcf, vecdisp, latexFormat, format, |
39 | loadMatrix, saveMatrix, fromFile, fileDimensions, | 39 | loadMatrix, saveMatrix, fromFile, fileDimensions, |
@@ -169,28 +169,19 @@ fliprl m = fromColumns . reverse . toColumns $ m | |||
169 | 169 | ||
170 | ------------------------------------------------------------ | 170 | ------------------------------------------------------------ |
171 | 171 | ||
172 | -- | Creates a square matrix with a given diagonal. | 172 | {- | creates a rectangular diagonal matrix: |
173 | diag :: (Num a, Element a) => Vector a -> Matrix a | ||
174 | diag v = ST.runSTMatrix $ do | ||
175 | let d = dim v | ||
176 | m <- ST.newMatrix 0 d d | ||
177 | mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] | ||
178 | return m | ||
179 | 173 | ||
180 | {- | creates a rectangular diagonal matrix | 174 | @> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double |
181 | 175 | (4><5) | |
182 | @> diagRect (constant 5 3) 3 4 :: Matrix Double | 176 | [ 10.0, 7.0, 7.0, 7.0, 7.0 |
183 | (3><4) | 177 | , 7.0, 20.0, 7.0, 7.0, 7.0 |
184 | [ 5.0, 0.0, 0.0, 0.0 | 178 | , 7.0, 7.0, 30.0, 7.0, 7.0 |
185 | , 0.0, 5.0, 0.0, 0.0 | 179 | , 7.0, 7.0, 7.0, 7.0, 7.0 ]@ |
186 | , 0.0, 0.0, 5.0, 0.0 ]@ | ||
187 | -} | 180 | -} |
188 | diagRect :: (Element t, Num t) => Vector t -> Int -> Int -> Matrix t | 181 | diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t |
189 | diagRect v r c | 182 | diagRect z v r c = ST.runSTMatrix $ do |
190 | | dim v < min r c = error "diagRect called with dim v < min r c" | 183 | m <- ST.newMatrix z r c |
191 | | otherwise = ST.runSTMatrix $ do | 184 | let d = min r c `min` (dim v) |
192 | m <- ST.newMatrix 0 r c | ||
193 | let d = min r c | ||
194 | mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] | 185 | mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] |
195 | return m | 186 | return m |
196 | 187 | ||
@@ -198,10 +189,6 @@ diagRect v r c | |||
198 | takeDiag :: (Element t) => Matrix t -> Vector t | 189 | takeDiag :: (Element t) => Matrix t -> Vector t |
199 | takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 190 | takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] |
200 | 191 | ||
201 | -- | creates the identity matrix of given dimension | ||
202 | ident :: (Num a, Element a) => Int -> Matrix a | ||
203 | ident n = diag (constantD 1 n) | ||
204 | |||
205 | ------------------------------------------------------------ | 192 | ------------------------------------------------------------ |
206 | 193 | ||
207 | {- | An easy way to create a matrix: | 194 | {- | An easy way to create a matrix: |
@@ -225,7 +212,7 @@ Example: | |||
225 | , 4.0, 5.0, 6.0 ]@ | 212 | , 4.0, 5.0, 6.0 ]@ |
226 | 213 | ||
227 | -} | 214 | -} |
228 | (><) :: (Element a) => Int -> Int -> [a] -> Matrix a | 215 | (><) :: (Storable a) => Int -> Int -> [a] -> Matrix a |
229 | r >< c = f where | 216 | r >< c = f where |
230 | f l | dim v == r*c = matrixFromVector RowMajor c v | 217 | f l | dim v == r*c = matrixFromVector RowMajor c v |
231 | | otherwise = error $ "inconsistent list size = " | 218 | | otherwise = error $ "inconsistent list size = " |
@@ -261,16 +248,13 @@ fromLists :: Element t => [[t]] -> Matrix t | |||
261 | fromLists = fromRows . map fromList | 248 | fromLists = fromRows . map fromList |
262 | 249 | ||
263 | -- | creates a 1-row matrix from a vector | 250 | -- | creates a 1-row matrix from a vector |
264 | asRow :: Element a => Vector a -> Matrix a | 251 | asRow :: Storable a => Vector a -> Matrix a |
265 | asRow v = reshape (dim v) v | 252 | asRow v = reshape (dim v) v |
266 | 253 | ||
267 | -- | creates a 1-column matrix from a vector | 254 | -- | creates a 1-column matrix from a vector |
268 | asColumn :: Element a => Vector a -> Matrix a | 255 | asColumn :: Storable a => Vector a -> Matrix a |
269 | asColumn v = reshape 1 v | 256 | asColumn v = reshape 1 v |
270 | 257 | ||
271 | -- | conjugate transpose | ||
272 | ctrans :: Element e => Matrix e -> Matrix e | ||
273 | ctrans = liftMatrix conjugateD . trans | ||
274 | 258 | ||
275 | 259 | ||
276 | {- | creates a Matrix of the specified size using the supplied function to | 260 | {- | creates a Matrix of the specified size using the supplied function to |
@@ -289,7 +273,7 @@ buildMatrix rc cc f = | |||
289 | 273 | ||
290 | ----------------------------------------------------- | 274 | ----------------------------------------------------- |
291 | 275 | ||
292 | fromArray2D :: (Element e) => Array (Int, Int) e -> Matrix e | 276 | fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e |
293 | fromArray2D m = (r><c) (elems m) | 277 | fromArray2D m = (r><c) (elems m) |
294 | where ((r0,c0),(r1,c1)) = bounds m | 278 | where ((r0,c0),(r1,c1)) = bounds m |
295 | r = r1-r0+1 | 279 | r = r1-r0+1 |
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs index 48e35b4..652f43e 100644 --- a/lib/Data/Packed/ST.hs +++ b/lib/Data/Packed/ST.hs | |||
@@ -90,11 +90,11 @@ writeVector :: Storable t => STVector s t -> Int -> t -> ST s () | |||
90 | writeVector = safeIndexV unsafeWriteVector | 90 | writeVector = safeIndexV unsafeWriteVector |
91 | 91 | ||
92 | {-# NOINLINE newUndefinedVector #-} | 92 | {-# NOINLINE newUndefinedVector #-} |
93 | newUndefinedVector :: Element t => Int -> ST s (STVector s t) | 93 | newUndefinedVector :: Storable t => Int -> ST s (STVector s t) |
94 | newUndefinedVector = unsafeIOToST . fmap STVector . createVector | 94 | newUndefinedVector = unsafeIOToST . fmap STVector . createVector |
95 | 95 | ||
96 | {-# INLINE newVector #-} | 96 | {-# INLINE newVector #-} |
97 | newVector :: Element t => t -> Int -> ST s (STVector s t) | 97 | newVector :: Storable t => t -> Int -> ST s (STVector s t) |
98 | newVector x n = do | 98 | newVector x n = do |
99 | v <- newUndefinedVector n | 99 | v <- newUndefinedVector n |
100 | let go (-1) = return v | 100 | let go (-1) = return v |
@@ -164,9 +164,9 @@ writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | |||
164 | writeMatrix = safeIndexM unsafeWriteMatrix | 164 | writeMatrix = safeIndexM unsafeWriteMatrix |
165 | 165 | ||
166 | {-# NOINLINE newUndefinedMatrix #-} | 166 | {-# NOINLINE newUndefinedMatrix #-} |
167 | newUndefinedMatrix :: Element t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) | 167 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) |
168 | newUndefinedMatrix order r c = unsafeIOToST $ fmap STMatrix $ createMatrix order r c | 168 | newUndefinedMatrix order r c = unsafeIOToST $ fmap STMatrix $ createMatrix order r c |
169 | 169 | ||
170 | {-# NOINLINE newMatrix #-} | 170 | {-# NOINLINE newMatrix #-} |
171 | newMatrix :: Element t => t -> Int -> Int -> ST s (STMatrix s t) | 171 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) |
172 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) | 172 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) |
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index 83bf44e..1afc5a1 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs | |||
@@ -22,6 +22,7 @@ | |||
22 | module Numeric.Container ( | 22 | module Numeric.Container ( |
23 | -- * Generic operations | 23 | -- * Generic operations |
24 | Container(..), | 24 | Container(..), |
25 | ctrans, diag, ident, | ||
25 | -- * Matrix product and related functions | 26 | -- * Matrix product and related functions |
26 | Product(..), | 27 | Product(..), |
27 | mXm,mXv,vXm, | 28 | mXm,mXv,vXm, |
@@ -221,6 +222,20 @@ instance (Container Vector a) => Container Matrix a where | |||
221 | 222 | ||
222 | ---------------------------------------------------- | 223 | ---------------------------------------------------- |
223 | 224 | ||
225 | -- | conjugate transpose | ||
226 | ctrans :: Element e => Matrix e -> Matrix e | ||
227 | ctrans = liftMatrix conjugateD . trans | ||
228 | |||
229 | -- | Creates a square matrix with a given diagonal. | ||
230 | diag :: (Num a, Element a) => Vector a -> Matrix a | ||
231 | diag v = diagRect 0 v n n where n = dim v | ||
232 | |||
233 | -- | creates the identity matrix of given dimension | ||
234 | ident :: (Num a, Element a) => Int -> Matrix a | ||
235 | ident n = diag (constantD 1 n) | ||
236 | |||
237 | ---------------------------------------------------- | ||
238 | |||
224 | 239 | ||
225 | -- | Matrix product and related functions | 240 | -- | Matrix product and related functions |
226 | class Element e => Product e where | 241 | class Element e => Product e where |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 4f6f54d..394a1d7 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -183,7 +183,7 @@ singularValues = {-# SCC "singularValues" #-} sv' | |||
183 | fullSVD :: Field t => Matrix t -> (Matrix t, Matrix Double, Matrix t) | 183 | fullSVD :: Field t => Matrix t -> (Matrix t, Matrix Double, Matrix t) |
184 | fullSVD m = (u,d,v) where | 184 | fullSVD m = (u,d,v) where |
185 | (u,s,v) = svd m | 185 | (u,s,v) = svd m |
186 | d = diagRect s r c | 186 | d = diagRect 0 s r c |
187 | r = rows m | 187 | r = rows m |
188 | c = cols m | 188 | c = cols m |
189 | 189 | ||
@@ -210,7 +210,7 @@ leftSV m | vertical m = let (u,s,_) = svd m in (u,s) | |||
210 | {-# DEPRECATED full "use fullSVD instead" #-} | 210 | {-# DEPRECATED full "use fullSVD instead" #-} |
211 | full svdFun m = (u, d ,v) where | 211 | full svdFun m = (u, d ,v) where |
212 | (u,s,v) = svdFun m | 212 | (u,s,v) = svdFun m |
213 | d = diagRect s r c | 213 | d = diagRect 0 s r c |
214 | r = rows m | 214 | r = rows m |
215 | c = cols m | 215 | c = cols m |
216 | 216 | ||
@@ -624,10 +624,10 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) | |||
624 | c = cols l_u | 624 | c = cols l_u |
625 | tu = triang r c 0 1 | 625 | tu = triang r c 0 1 |
626 | tl = triang r c 0 0 | 626 | tl = triang r c 0 0 |
627 | l = takeColumns r (l_u |*| tl) |+| diagRect (konst 1 r) r r | 627 | l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst 1 r) r r |
628 | u = l_u |*| tu | 628 | u = l_u |*| tu |
629 | (p,s) = fixPerm r perm | 629 | (p,s) = fixPerm r perm |
630 | l' = (l_u |*| tl) |+| diagRect (konst 1 c) r c | 630 | l' = (l_u |*| tl) |+| diagRect 0 (konst 1 c) r c |
631 | u' = takeRows c (l_u |*| tu) | 631 | u' = takeRows c (l_u |*| tu) |
632 | (|+|) = add | 632 | (|+|) = add |
633 | (|*|) = mul | 633 | (|*|) = mul |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs index 804c481..771739a 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -150,7 +150,7 @@ instance (ArbitraryField a) => Arbitrary (WC a) where | |||
150 | c = cols m | 150 | c = cols m |
151 | n = min r c | 151 | n = min r c |
152 | sv' <- replicateM n (choose (1,100)) | 152 | sv' <- replicateM n (choose (1,100)) |
153 | let s = diagRect (fromList sv') r c | 153 | let s = diagRect 0 (fromList sv') r c |
154 | return $ WC (u <> real s <> trans v) | 154 | return $ WC (u <> real s <> trans v) |
155 | 155 | ||
156 | #if MIN_VERSION_QuickCheck(2,0,0) | 156 | #if MIN_VERSION_QuickCheck(2,0,0) |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index 623b78c..a35f591 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -138,7 +138,7 @@ svdProp1 m = m |~| u <> real d <> trans v && unitary u && unitary v | |||
138 | 138 | ||
139 | svdProp1a svdfun m = m |~| u <> real d <> trans v && unitary u && unitary v where | 139 | svdProp1a svdfun m = m |~| u <> real d <> trans v && unitary u && unitary v where |
140 | (u,s,v) = svdfun m | 140 | (u,s,v) = svdfun m |
141 | d = diagRect s (rows m) (cols m) | 141 | d = diagRect 0 s (rows m) (cols m) |
142 | 142 | ||
143 | svdProp1b svdfun m = unitary u && unitary v where | 143 | svdProp1b svdfun m = unitary u && unitary v where |
144 | (u,_,v) = svdfun m | 144 | (u,_,v) = svdfun m |
diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs index d5c6f44..9260bd5 100644 --- a/lib/Numeric/Matrix.hs +++ b/lib/Numeric/Matrix.hs | |||
@@ -28,7 +28,6 @@ module Numeric.Matrix ( | |||
28 | -- * Basic functions | 28 | -- * Basic functions |
29 | module Data.Packed.Matrix, | 29 | module Data.Packed.Matrix, |
30 | module Numeric.Vector, | 30 | module Numeric.Vector, |
31 | --module Numeric.Container, | ||
32 | optimiseMult, | 31 | optimiseMult, |
33 | -- * Operators | 32 | -- * Operators |
34 | (<>), (<\>) | 33 | (<>), (<\>) |