summaryrefslogtreecommitdiff
path: root/lib/Data/Packed
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs5
-rw-r--r--lib/Data/Packed/Matrix.hs50
-rw-r--r--lib/Data/Packed/ST.hs8
3 files changed, 23 insertions, 40 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index c0824a3..94b56cf 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -221,13 +221,13 @@ where r is the desired number of rows.)
221 , 9.0, 10.0, 11.0, 12.0 ]@ 221 , 9.0, 10.0, 11.0, 12.0 ]@
222 222
223-} 223-}
224reshape :: Element t => Int -> Vector t -> Matrix t 224reshape :: Storable t => Int -> Vector t -> Matrix t
225reshape c v = matrixFromVector RowMajor c v 225reshape c v = matrixFromVector RowMajor c v
226 226
227singleton x = reshape 1 (fromList [x]) 227singleton x = reshape 1 (fromList [x])
228 228
229-- | application of a vector function on the flattened matrix elements 229-- | application of a vector function on the flattened matrix elements
230liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 230liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
231liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) 231liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d)
232liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) 232liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d)
233 233
@@ -246,7 +246,6 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
246------------------------------------------------------------------ 246------------------------------------------------------------------
247 247
248-- | Supported element types for basic matrix operations. 248-- | Supported element types for basic matrix operations.
249--class (Storable a, Floating a) => Element a where
250class (Storable a) => Element a where 249class (Storable a) => Element a where
251 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position 250 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
252 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 251 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index b8c309c..ea16748 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -22,7 +22,7 @@ module Data.Packed.Matrix (
22 Element, 22 Element,
23 Matrix,rows,cols, 23 Matrix,rows,cols,
24 (><), 24 (><),
25 trans, ctrans, 25 trans,
26 reshape, flatten, 26 reshape, flatten,
27 fromLists, toLists, buildMatrix, 27 fromLists, toLists, buildMatrix,
28 (@@>), 28 (@@>),
@@ -33,7 +33,7 @@ module Data.Packed.Matrix (
33 flipud, fliprl, 33 flipud, fliprl,
34 subMatrix, takeRows, dropRows, takeColumns, dropColumns, 34 subMatrix, takeRows, dropRows, takeColumns, dropColumns,
35 extractRows, 35 extractRows,
36 ident, diag, diagRect, takeDiag, 36 diagRect, takeDiag,
37 liftMatrix, liftMatrix2, liftMatrix2Auto, 37 liftMatrix, liftMatrix2, liftMatrix2Auto,
38 dispf, disps, dispcf, vecdisp, latexFormat, format, 38 dispf, disps, dispcf, vecdisp, latexFormat, format,
39 loadMatrix, saveMatrix, fromFile, fileDimensions, 39 loadMatrix, saveMatrix, fromFile, fileDimensions,
@@ -169,28 +169,19 @@ fliprl m = fromColumns . reverse . toColumns $ m
169 169
170------------------------------------------------------------ 170------------------------------------------------------------
171 171
172-- | Creates a square matrix with a given diagonal. 172{- | creates a rectangular diagonal matrix:
173diag :: (Num a, Element a) => Vector a -> Matrix a
174diag v = ST.runSTMatrix $ do
175 let d = dim v
176 m <- ST.newMatrix 0 d d
177 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1]
178 return m
179 173
180{- | creates a rectangular diagonal matrix 174@> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double
181 175(4><5)
182@> diagRect (constant 5 3) 3 4 :: Matrix Double 176 [ 10.0, 7.0, 7.0, 7.0, 7.0
183(3><4) 177 , 7.0, 20.0, 7.0, 7.0, 7.0
184 [ 5.0, 0.0, 0.0, 0.0 178 , 7.0, 7.0, 30.0, 7.0, 7.0
185 , 0.0, 5.0, 0.0, 0.0 179 , 7.0, 7.0, 7.0, 7.0, 7.0 ]@
186 , 0.0, 0.0, 5.0, 0.0 ]@
187-} 180-}
188diagRect :: (Element t, Num t) => Vector t -> Int -> Int -> Matrix t 181diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t
189diagRect v r c 182diagRect z v r c = ST.runSTMatrix $ do
190 | dim v < min r c = error "diagRect called with dim v < min r c" 183 m <- ST.newMatrix z r c
191 | otherwise = ST.runSTMatrix $ do 184 let d = min r c `min` (dim v)
192 m <- ST.newMatrix 0 r c
193 let d = min r c
194 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] 185 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1]
195 return m 186 return m
196 187
@@ -198,10 +189,6 @@ diagRect v r c
198takeDiag :: (Element t) => Matrix t -> Vector t 189takeDiag :: (Element t) => Matrix t -> Vector t
199takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 190takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
200 191
201-- | creates the identity matrix of given dimension
202ident :: (Num a, Element a) => Int -> Matrix a
203ident n = diag (constantD 1 n)
204
205------------------------------------------------------------ 192------------------------------------------------------------
206 193
207{- | An easy way to create a matrix: 194{- | An easy way to create a matrix:
@@ -225,7 +212,7 @@ Example:
225 , 4.0, 5.0, 6.0 ]@ 212 , 4.0, 5.0, 6.0 ]@
226 213
227-} 214-}
228(><) :: (Element a) => Int -> Int -> [a] -> Matrix a 215(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a
229r >< c = f where 216r >< c = f where
230 f l | dim v == r*c = matrixFromVector RowMajor c v 217 f l | dim v == r*c = matrixFromVector RowMajor c v
231 | otherwise = error $ "inconsistent list size = " 218 | otherwise = error $ "inconsistent list size = "
@@ -261,16 +248,13 @@ fromLists :: Element t => [[t]] -> Matrix t
261fromLists = fromRows . map fromList 248fromLists = fromRows . map fromList
262 249
263-- | creates a 1-row matrix from a vector 250-- | creates a 1-row matrix from a vector
264asRow :: Element a => Vector a -> Matrix a 251asRow :: Storable a => Vector a -> Matrix a
265asRow v = reshape (dim v) v 252asRow v = reshape (dim v) v
266 253
267-- | creates a 1-column matrix from a vector 254-- | creates a 1-column matrix from a vector
268asColumn :: Element a => Vector a -> Matrix a 255asColumn :: Storable a => Vector a -> Matrix a
269asColumn v = reshape 1 v 256asColumn v = reshape 1 v
270 257
271-- | conjugate transpose
272ctrans :: Element e => Matrix e -> Matrix e
273ctrans = liftMatrix conjugateD . trans
274 258
275 259
276{- | creates a Matrix of the specified size using the supplied function to 260{- | creates a Matrix of the specified size using the supplied function to
@@ -289,7 +273,7 @@ buildMatrix rc cc f =
289 273
290----------------------------------------------------- 274-----------------------------------------------------
291 275
292fromArray2D :: (Element e) => Array (Int, Int) e -> Matrix e 276fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e
293fromArray2D m = (r><c) (elems m) 277fromArray2D m = (r><c) (elems m)
294 where ((r0,c0),(r1,c1)) = bounds m 278 where ((r0,c0),(r1,c1)) = bounds m
295 r = r1-r0+1 279 r = r1-r0+1
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs
index 48e35b4..652f43e 100644
--- a/lib/Data/Packed/ST.hs
+++ b/lib/Data/Packed/ST.hs
@@ -90,11 +90,11 @@ writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
90writeVector = safeIndexV unsafeWriteVector 90writeVector = safeIndexV unsafeWriteVector
91 91
92{-# NOINLINE newUndefinedVector #-} 92{-# NOINLINE newUndefinedVector #-}
93newUndefinedVector :: Element t => Int -> ST s (STVector s t) 93newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
94newUndefinedVector = unsafeIOToST . fmap STVector . createVector 94newUndefinedVector = unsafeIOToST . fmap STVector . createVector
95 95
96{-# INLINE newVector #-} 96{-# INLINE newVector #-}
97newVector :: Element t => t -> Int -> ST s (STVector s t) 97newVector :: Storable t => t -> Int -> ST s (STVector s t)
98newVector x n = do 98newVector x n = do
99 v <- newUndefinedVector n 99 v <- newUndefinedVector n
100 let go (-1) = return v 100 let go (-1) = return v
@@ -164,9 +164,9 @@ writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
164writeMatrix = safeIndexM unsafeWriteMatrix 164writeMatrix = safeIndexM unsafeWriteMatrix
165 165
166{-# NOINLINE newUndefinedMatrix #-} 166{-# NOINLINE newUndefinedMatrix #-}
167newUndefinedMatrix :: Element t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) 167newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
168newUndefinedMatrix order r c = unsafeIOToST $ fmap STMatrix $ createMatrix order r c 168newUndefinedMatrix order r c = unsafeIOToST $ fmap STMatrix $ createMatrix order r c
169 169
170{-# NOINLINE newMatrix #-} 170{-# NOINLINE newMatrix #-}
171newMatrix :: Element t => t -> Int -> Int -> ST s (STMatrix s t) 171newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
172newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) 172newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)