summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs36
1 files changed, 18 insertions, 18 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index f63ee52..fbab33c 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -84,17 +84,17 @@ type Mt t s = Int -> Int -> Ptr t -> s
84-- type t ::> s = Mt t s 84-- type t ::> s = Mt t s
85 85
86-- | the inverse of 'Data.Packed.Matrix.fromLists' 86-- | the inverse of 'Data.Packed.Matrix.fromLists'
87toLists :: (Field t) => Matrix t -> [[t]] 87toLists :: (Element t) => Matrix t -> [[t]]
88toLists m = partit (cols m) . toList . cdat $ m 88toLists m = partit (cols m) . toList . cdat $ m
89 89
90-- | creates a Matrix from a list of vectors 90-- | creates a Matrix from a list of vectors
91fromRows :: Field t => [Vector t] -> Matrix t 91fromRows :: Element t => [Vector t] -> Matrix t
92fromRows vs = case common dim vs of 92fromRows vs = case common dim vs of
93 Nothing -> error "fromRows applied to [] or to vectors with different sizes" 93 Nothing -> error "fromRows applied to [] or to vectors with different sizes"
94 Just c -> reshape c (join vs) 94 Just c -> reshape c (join vs)
95 95
96-- | extracts the rows of a matrix as a list of vectors 96-- | extracts the rows of a matrix as a list of vectors
97toRows :: Field t => Matrix t -> [Vector t] 97toRows :: Element t => Matrix t -> [Vector t]
98toRows m = toRows' 0 where 98toRows m = toRows' 0 where
99 v = cdat m 99 v = cdat m
100 r = rows m 100 r = rows m
@@ -103,11 +103,11 @@ toRows m = toRows' 0 where
103 | otherwise = subVector k c v : toRows' (k+c) 103 | otherwise = subVector k c v : toRows' (k+c)
104 104
105-- | Creates a matrix from a list of vectors, as columns 105-- | Creates a matrix from a list of vectors, as columns
106fromColumns :: Field t => [Vector t] -> Matrix t 106fromColumns :: Element t => [Vector t] -> Matrix t
107fromColumns m = trans . fromRows $ m 107fromColumns m = trans . fromRows $ m
108 108
109-- | Creates a list of vectors from the columns of a matrix 109-- | Creates a list of vectors from the columns of a matrix
110toColumns :: Field t => Matrix t -> [Vector t] 110toColumns :: Element t => Matrix t -> [Vector t]
111toColumns m = toRows . trans $ m 111toColumns m = toRows . trans $ m
112 112
113 113
@@ -152,18 +152,18 @@ where r is the desired number of rows.)
152 , 9.0, 10.0, 11.0, 12.0 ]@ 152 , 9.0, 10.0, 11.0, 12.0 ]@
153 153
154-} 154-}
155reshape :: Field t => Int -> Vector t -> Matrix t 155reshape :: Element t => Int -> Vector t -> Matrix t
156reshape c v = matrixFromVector RowMajor c v 156reshape c v = matrixFromVector RowMajor c v
157 157
158singleton x = reshape 1 (fromList [x]) 158singleton x = reshape 1 (fromList [x])
159 159
160-- | application of a vector function on the flattened matrix elements 160-- | application of a vector function on the flattened matrix elements
161liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 161liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
162liftMatrix f MC { cols = c, cdat = d } = matrixFromVector RowMajor c (f d) 162liftMatrix f MC { cols = c, cdat = d } = matrixFromVector RowMajor c (f d)
163liftMatrix f MF { cols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) 163liftMatrix f MF { cols = c, fdat = d } = matrixFromVector ColumnMajor c (f d)
164 164
165-- | application of a vector function on the flattened matrices elements 165-- | application of a vector function on the flattened matrices elements
166liftMatrix2 :: (Field t, Field a, Field b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 166liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
167liftMatrix2 f m1 m2 167liftMatrix2 f m1 m2
168 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" 168 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
169 | otherwise = case m1 of 169 | otherwise = case m1 of
@@ -176,8 +176,8 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
176 176
177---------------------------------------------------------------- 177----------------------------------------------------------------
178 178
179-- | Optimized matrix computations are provided for elements in the Field class. 179-- | Optimized matrix computations are provided for elements in the Element class.
180class (Storable a, Floating a) => Field a where 180class (Storable a, Floating a) => Element a where
181 constantD :: a -> Int -> Vector a 181 constantD :: a -> Int -> Vector a
182 transdata :: Int -> Vector a -> Int -> Vector a 182 transdata :: Int -> Vector a -> Int -> Vector a
183 multiplyD :: Matrix a -> Matrix a -> Matrix a 183 multiplyD :: Matrix a -> Matrix a -> Matrix a
@@ -186,14 +186,14 @@ class (Storable a, Floating a) => Field a where
186 -> Matrix a -> Matrix a 186 -> Matrix a -> Matrix a
187 diagD :: Vector a -> Matrix a 187 diagD :: Vector a -> Matrix a
188 188
189instance Field Double where 189instance Element Double where
190 constantD = constantR 190 constantD = constantR
191 transdata = transdataR 191 transdata = transdataR
192 multiplyD = multiplyR 192 multiplyD = multiplyR
193 subMatrixD = subMatrixR 193 subMatrixD = subMatrixR
194 diagD = diagR 194 diagD = diagR
195 195
196instance Field (Complex Double) where 196instance Element (Complex Double) where
197 constantD = constantC 197 constantD = constantC
198 transdata = transdataC 198 transdata = transdataC
199 multiplyD = multiplyC 199 multiplyD = multiplyC
@@ -202,7 +202,7 @@ instance Field (Complex Double) where
202 202
203------------------------------------------------------------------ 203------------------------------------------------------------------
204 204
205(>|<) :: (Field a) => Int -> Int -> [a] -> Matrix a 205(>|<) :: (Element a) => Int -> Int -> [a] -> Matrix a
206r >|< c = f where 206r >|< c = f where
207 f l | dim v == r*c = matrixFromVector ColumnMajor c v 207 f l | dim v == r*c = matrixFromVector ColumnMajor c v
208 | otherwise = error $ "inconsistent list size = " 208 | otherwise = error $ "inconsistent list size = "
@@ -260,13 +260,13 @@ foreign import ccall safe "auxi.h multiplyC"
260 -> Int -> Int -> Ptr (Complex Double) 260 -> Int -> Int -> Ptr (Complex Double)
261 -> IO Int 261 -> IO Int
262 262
263multiply' :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 263multiply' :: (Element a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
264multiply' RowMajor a b = multiplyD a b 264multiply' RowMajor a b = multiplyD a b
265multiply' ColumnMajor a b = trans $ multiplyD (trans b) (trans a) 265multiply' ColumnMajor a b = trans $ multiplyD (trans b) (trans a)
266 266
267 267
268-- | matrix product 268-- | matrix product
269multiply :: (Field a) => Matrix a -> Matrix a -> Matrix a 269multiply :: (Element a) => Matrix a -> Matrix a -> Matrix a
270multiply = multiplyD 270multiply = multiplyD
271 271
272---------------------------------------------------------------------- 272----------------------------------------------------------------------
@@ -287,7 +287,7 @@ subMatrixC (r0,c0) (rt,ct) x =
287 reshape (2*cols x) . asReal . cdat $ x 287 reshape (2*cols x) . asReal . cdat $ x
288 288
289-- | Extracts a submatrix from a matrix. 289-- | Extracts a submatrix from a matrix.
290subMatrix :: Field a 290subMatrix :: Element a
291 => (Int,Int) -- ^ (r0,c0) starting position 291 => (Int,Int) -- ^ (r0,c0) starting position
292 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 292 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
293 -> Matrix a -- ^ input matrix 293 -> Matrix a -- ^ input matrix
@@ -313,7 +313,7 @@ diagC = diagAux c_diagC "diagC"
313foreign import ccall "auxi.h diagC" c_diagC :: TCVCM 313foreign import ccall "auxi.h diagC" c_diagC :: TCVCM
314 314
315-- | creates a square matrix with the given diagonal 315-- | creates a square matrix with the given diagonal
316diag :: Field a => Vector a -> Matrix a 316diag :: Element a => Vector a -> Matrix a
317diag = diagD 317diag = diagD
318 318
319------------------------------------------------------------------------ 319------------------------------------------------------------------------
@@ -340,7 +340,7 @@ foreign import ccall safe "auxi.h constantC"
340@> constant 2 7 340@> constant 2 7
3417 |> [2.0,2.0,2.0,2.0,2.0,2.0,2.0]@ 3417 |> [2.0,2.0,2.0,2.0,2.0,2.0,2.0]@
342-} 342-}
343constant :: Field a => a -> Int -> Vector a 343constant :: Element a => a -> Int -> Vector a
344constant = constantD 344constant = constantD
345 345
346-------------------------------------------------------------------------- 346--------------------------------------------------------------------------