summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs5
-rw-r--r--lib/Data/Packed/Matrix.hs50
-rw-r--r--lib/Data/Packed/ST.hs8
-rw-r--r--lib/Numeric/Container.hs15
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs8
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Instances.hs2
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs2
-rw-r--r--lib/Numeric/Matrix.hs1
8 files changed, 44 insertions, 47 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index c0824a3..94b56cf 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -221,13 +221,13 @@ where r is the desired number of rows.)
221 , 9.0, 10.0, 11.0, 12.0 ]@ 221 , 9.0, 10.0, 11.0, 12.0 ]@
222 222
223-} 223-}
224reshape :: Element t => Int -> Vector t -> Matrix t 224reshape :: Storable t => Int -> Vector t -> Matrix t
225reshape c v = matrixFromVector RowMajor c v 225reshape c v = matrixFromVector RowMajor c v
226 226
227singleton x = reshape 1 (fromList [x]) 227singleton x = reshape 1 (fromList [x])
228 228
229-- | application of a vector function on the flattened matrix elements 229-- | application of a vector function on the flattened matrix elements
230liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 230liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
231liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) 231liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d)
232liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) 232liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d)
233 233
@@ -246,7 +246,6 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
246------------------------------------------------------------------ 246------------------------------------------------------------------
247 247
248-- | Supported element types for basic matrix operations. 248-- | Supported element types for basic matrix operations.
249--class (Storable a, Floating a) => Element a where
250class (Storable a) => Element a where 249class (Storable a) => Element a where
251 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position 250 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
252 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 251 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index b8c309c..ea16748 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -22,7 +22,7 @@ module Data.Packed.Matrix (
22 Element, 22 Element,
23 Matrix,rows,cols, 23 Matrix,rows,cols,
24 (><), 24 (><),
25 trans, ctrans, 25 trans,
26 reshape, flatten, 26 reshape, flatten,
27 fromLists, toLists, buildMatrix, 27 fromLists, toLists, buildMatrix,
28 (@@>), 28 (@@>),
@@ -33,7 +33,7 @@ module Data.Packed.Matrix (
33 flipud, fliprl, 33 flipud, fliprl,
34 subMatrix, takeRows, dropRows, takeColumns, dropColumns, 34 subMatrix, takeRows, dropRows, takeColumns, dropColumns,
35 extractRows, 35 extractRows,
36 ident, diag, diagRect, takeDiag, 36 diagRect, takeDiag,
37 liftMatrix, liftMatrix2, liftMatrix2Auto, 37 liftMatrix, liftMatrix2, liftMatrix2Auto,
38 dispf, disps, dispcf, vecdisp, latexFormat, format, 38 dispf, disps, dispcf, vecdisp, latexFormat, format,
39 loadMatrix, saveMatrix, fromFile, fileDimensions, 39 loadMatrix, saveMatrix, fromFile, fileDimensions,
@@ -169,28 +169,19 @@ fliprl m = fromColumns . reverse . toColumns $ m
169 169
170------------------------------------------------------------ 170------------------------------------------------------------
171 171
172-- | Creates a square matrix with a given diagonal. 172{- | creates a rectangular diagonal matrix:
173diag :: (Num a, Element a) => Vector a -> Matrix a
174diag v = ST.runSTMatrix $ do
175 let d = dim v
176 m <- ST.newMatrix 0 d d
177 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1]
178 return m
179 173
180{- | creates a rectangular diagonal matrix 174@> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double
181 175(4><5)
182@> diagRect (constant 5 3) 3 4 :: Matrix Double 176 [ 10.0, 7.0, 7.0, 7.0, 7.0
183(3><4) 177 , 7.0, 20.0, 7.0, 7.0, 7.0
184 [ 5.0, 0.0, 0.0, 0.0 178 , 7.0, 7.0, 30.0, 7.0, 7.0
185 , 0.0, 5.0, 0.0, 0.0 179 , 7.0, 7.0, 7.0, 7.0, 7.0 ]@
186 , 0.0, 0.0, 5.0, 0.0 ]@
187-} 180-}
188diagRect :: (Element t, Num t) => Vector t -> Int -> Int -> Matrix t 181diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t
189diagRect v r c 182diagRect z v r c = ST.runSTMatrix $ do
190 | dim v < min r c = error "diagRect called with dim v < min r c" 183 m <- ST.newMatrix z r c
191 | otherwise = ST.runSTMatrix $ do 184 let d = min r c `min` (dim v)
192 m <- ST.newMatrix 0 r c
193 let d = min r c
194 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] 185 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1]
195 return m 186 return m
196 187
@@ -198,10 +189,6 @@ diagRect v r c
198takeDiag :: (Element t) => Matrix t -> Vector t 189takeDiag :: (Element t) => Matrix t -> Vector t
199takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 190takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
200 191
201-- | creates the identity matrix of given dimension
202ident :: (Num a, Element a) => Int -> Matrix a
203ident n = diag (constantD 1 n)
204
205------------------------------------------------------------ 192------------------------------------------------------------
206 193
207{- | An easy way to create a matrix: 194{- | An easy way to create a matrix:
@@ -225,7 +212,7 @@ Example:
225 , 4.0, 5.0, 6.0 ]@ 212 , 4.0, 5.0, 6.0 ]@
226 213
227-} 214-}
228(><) :: (Element a) => Int -> Int -> [a] -> Matrix a 215(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a
229r >< c = f where 216r >< c = f where
230 f l | dim v == r*c = matrixFromVector RowMajor c v 217 f l | dim v == r*c = matrixFromVector RowMajor c v
231 | otherwise = error $ "inconsistent list size = " 218 | otherwise = error $ "inconsistent list size = "
@@ -261,16 +248,13 @@ fromLists :: Element t => [[t]] -> Matrix t
261fromLists = fromRows . map fromList 248fromLists = fromRows . map fromList
262 249
263-- | creates a 1-row matrix from a vector 250-- | creates a 1-row matrix from a vector
264asRow :: Element a => Vector a -> Matrix a 251asRow :: Storable a => Vector a -> Matrix a
265asRow v = reshape (dim v) v 252asRow v = reshape (dim v) v
266 253
267-- | creates a 1-column matrix from a vector 254-- | creates a 1-column matrix from a vector
268asColumn :: Element a => Vector a -> Matrix a 255asColumn :: Storable a => Vector a -> Matrix a
269asColumn v = reshape 1 v 256asColumn v = reshape 1 v
270 257
271-- | conjugate transpose
272ctrans :: Element e => Matrix e -> Matrix e
273ctrans = liftMatrix conjugateD . trans
274 258
275 259
276{- | creates a Matrix of the specified size using the supplied function to 260{- | creates a Matrix of the specified size using the supplied function to
@@ -289,7 +273,7 @@ buildMatrix rc cc f =
289 273
290----------------------------------------------------- 274-----------------------------------------------------
291 275
292fromArray2D :: (Element e) => Array (Int, Int) e -> Matrix e 276fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e
293fromArray2D m = (r><c) (elems m) 277fromArray2D m = (r><c) (elems m)
294 where ((r0,c0),(r1,c1)) = bounds m 278 where ((r0,c0),(r1,c1)) = bounds m
295 r = r1-r0+1 279 r = r1-r0+1
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs
index 48e35b4..652f43e 100644
--- a/lib/Data/Packed/ST.hs
+++ b/lib/Data/Packed/ST.hs
@@ -90,11 +90,11 @@ writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
90writeVector = safeIndexV unsafeWriteVector 90writeVector = safeIndexV unsafeWriteVector
91 91
92{-# NOINLINE newUndefinedVector #-} 92{-# NOINLINE newUndefinedVector #-}
93newUndefinedVector :: Element t => Int -> ST s (STVector s t) 93newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
94newUndefinedVector = unsafeIOToST . fmap STVector . createVector 94newUndefinedVector = unsafeIOToST . fmap STVector . createVector
95 95
96{-# INLINE newVector #-} 96{-# INLINE newVector #-}
97newVector :: Element t => t -> Int -> ST s (STVector s t) 97newVector :: Storable t => t -> Int -> ST s (STVector s t)
98newVector x n = do 98newVector x n = do
99 v <- newUndefinedVector n 99 v <- newUndefinedVector n
100 let go (-1) = return v 100 let go (-1) = return v
@@ -164,9 +164,9 @@ writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
164writeMatrix = safeIndexM unsafeWriteMatrix 164writeMatrix = safeIndexM unsafeWriteMatrix
165 165
166{-# NOINLINE newUndefinedMatrix #-} 166{-# NOINLINE newUndefinedMatrix #-}
167newUndefinedMatrix :: Element t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) 167newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
168newUndefinedMatrix order r c = unsafeIOToST $ fmap STMatrix $ createMatrix order r c 168newUndefinedMatrix order r c = unsafeIOToST $ fmap STMatrix $ createMatrix order r c
169 169
170{-# NOINLINE newMatrix #-} 170{-# NOINLINE newMatrix #-}
171newMatrix :: Element t => t -> Int -> Int -> ST s (STMatrix s t) 171newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
172newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) 172newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index 83bf44e..1afc5a1 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -22,6 +22,7 @@
22module Numeric.Container ( 22module Numeric.Container (
23 -- * Generic operations 23 -- * Generic operations
24 Container(..), 24 Container(..),
25 ctrans, diag, ident,
25 -- * Matrix product and related functions 26 -- * Matrix product and related functions
26 Product(..), 27 Product(..),
27 mXm,mXv,vXm, 28 mXm,mXv,vXm,
@@ -221,6 +222,20 @@ instance (Container Vector a) => Container Matrix a where
221 222
222---------------------------------------------------- 223----------------------------------------------------
223 224
225-- | conjugate transpose
226ctrans :: Element e => Matrix e -> Matrix e
227ctrans = liftMatrix conjugateD . trans
228
229-- | Creates a square matrix with a given diagonal.
230diag :: (Num a, Element a) => Vector a -> Matrix a
231diag v = diagRect 0 v n n where n = dim v
232
233-- | creates the identity matrix of given dimension
234ident :: (Num a, Element a) => Int -> Matrix a
235ident n = diag (constantD 1 n)
236
237----------------------------------------------------
238
224 239
225-- | Matrix product and related functions 240-- | Matrix product and related functions
226class Element e => Product e where 241class Element e => Product e where
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 4f6f54d..394a1d7 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -183,7 +183,7 @@ singularValues = {-# SCC "singularValues" #-} sv'
183fullSVD :: Field t => Matrix t -> (Matrix t, Matrix Double, Matrix t) 183fullSVD :: Field t => Matrix t -> (Matrix t, Matrix Double, Matrix t)
184fullSVD m = (u,d,v) where 184fullSVD m = (u,d,v) where
185 (u,s,v) = svd m 185 (u,s,v) = svd m
186 d = diagRect s r c 186 d = diagRect 0 s r c
187 r = rows m 187 r = rows m
188 c = cols m 188 c = cols m
189 189
@@ -210,7 +210,7 @@ leftSV m | vertical m = let (u,s,_) = svd m in (u,s)
210{-# DEPRECATED full "use fullSVD instead" #-} 210{-# DEPRECATED full "use fullSVD instead" #-}
211full svdFun m = (u, d ,v) where 211full svdFun m = (u, d ,v) where
212 (u,s,v) = svdFun m 212 (u,s,v) = svdFun m
213 d = diagRect s r c 213 d = diagRect 0 s r c
214 r = rows m 214 r = rows m
215 c = cols m 215 c = cols m
216 216
@@ -624,10 +624,10 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s)
624 c = cols l_u 624 c = cols l_u
625 tu = triang r c 0 1 625 tu = triang r c 0 1
626 tl = triang r c 0 0 626 tl = triang r c 0 0
627 l = takeColumns r (l_u |*| tl) |+| diagRect (konst 1 r) r r 627 l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst 1 r) r r
628 u = l_u |*| tu 628 u = l_u |*| tu
629 (p,s) = fixPerm r perm 629 (p,s) = fixPerm r perm
630 l' = (l_u |*| tl) |+| diagRect (konst 1 c) r c 630 l' = (l_u |*| tl) |+| diagRect 0 (konst 1 c) r c
631 u' = takeRows c (l_u |*| tu) 631 u' = takeRows c (l_u |*| tu)
632 (|+|) = add 632 (|+|) = add
633 (|*|) = mul 633 (|*|) = mul
diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
index 804c481..771739a 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -150,7 +150,7 @@ instance (ArbitraryField a) => Arbitrary (WC a) where
150 c = cols m 150 c = cols m
151 n = min r c 151 n = min r c
152 sv' <- replicateM n (choose (1,100)) 152 sv' <- replicateM n (choose (1,100))
153 let s = diagRect (fromList sv') r c 153 let s = diagRect 0 (fromList sv') r c
154 return $ WC (u <> real s <> trans v) 154 return $ WC (u <> real s <> trans v)
155 155
156#if MIN_VERSION_QuickCheck(2,0,0) 156#if MIN_VERSION_QuickCheck(2,0,0)
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index 623b78c..a35f591 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -138,7 +138,7 @@ svdProp1 m = m |~| u <> real d <> trans v && unitary u && unitary v
138 138
139svdProp1a svdfun m = m |~| u <> real d <> trans v && unitary u && unitary v where 139svdProp1a svdfun m = m |~| u <> real d <> trans v && unitary u && unitary v where
140 (u,s,v) = svdfun m 140 (u,s,v) = svdfun m
141 d = diagRect s (rows m) (cols m) 141 d = diagRect 0 s (rows m) (cols m)
142 142
143svdProp1b svdfun m = unitary u && unitary v where 143svdProp1b svdfun m = unitary u && unitary v where
144 (u,_,v) = svdfun m 144 (u,_,v) = svdfun m
diff --git a/lib/Numeric/Matrix.hs b/lib/Numeric/Matrix.hs
index d5c6f44..9260bd5 100644
--- a/lib/Numeric/Matrix.hs
+++ b/lib/Numeric/Matrix.hs
@@ -28,7 +28,6 @@ module Numeric.Matrix (
28 -- * Basic functions 28 -- * Basic functions
29 module Data.Packed.Matrix, 29 module Data.Packed.Matrix,
30 module Numeric.Vector, 30 module Numeric.Vector,
31 --module Numeric.Container,
32 optimiseMult, 31 optimiseMult,
33 -- * Operators 32 -- * Operators
34 (<>), (<\>) 33 (<>), (<\>)