diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 74 |
1 files changed, 68 insertions, 6 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index db53cd1..bd333d4 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -74,8 +74,7 @@ common f = commonval . map f where | |||
74 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | 74 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing |
75 | 75 | ||
76 | 76 | ||
77 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | 77 | toLists m = partit (cols m) . toList . cdat $ m |
78 | | otherwise = partit (cols m) . toList . dat $ m | ||
79 | 78 | ||
80 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 79 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
81 | where | 80 | where |
@@ -145,6 +144,8 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | |||
145 | --{-# RULES "transdataR" transdata=transdataR #-} | 144 | --{-# RULES "transdataR" transdata=transdataR #-} |
146 | --{-# RULES "transdataC" transdata=transdataC #-} | 145 | --{-# RULES "transdataC" transdata=transdataC #-} |
147 | 146 | ||
147 | ----------------------------------------------------------------------------- | ||
148 | |||
148 | -- | creates a Matrix from a list of vectors | 149 | -- | creates a Matrix from a list of vectors |
149 | fromRows :: Field t => [Vector t] -> Matrix t | 150 | fromRows :: Field t => [Vector t] -> Matrix t |
150 | fromRows vs = case common dim vs of | 151 | fromRows vs = case common dim vs of |
@@ -160,6 +161,34 @@ toRows m = toRows' 0 where | |||
160 | toRows' k | k == r*c = [] | 161 | toRows' k | k == r*c = [] |
161 | | otherwise = subVector k c v : toRows' (k+c) | 162 | | otherwise = subVector k c v : toRows' (k+c) |
162 | 163 | ||
164 | -- | Creates a matrix from a list of vectors, as columns | ||
165 | fromColumns :: Field t => [Vector t] -> Matrix t | ||
166 | fromColumns m = trans . fromRows $ m | ||
167 | |||
168 | -- | Creates a list of vectors from the columns of a matrix | ||
169 | toColumns :: Field t => Matrix t -> [Vector t] | ||
170 | toColumns m = toRows . trans $ m | ||
171 | |||
172 | -- | creates a matrix from a vertical list of matrices | ||
173 | joinVert :: Field t => [Matrix t] -> Matrix t | ||
174 | joinVert ms = case common cols ms of | ||
175 | Nothing -> error "joinVert on matrices with different number of columns" | ||
176 | Just c -> reshape c $ join (map cdat ms) | ||
177 | |||
178 | -- | creates a matrix from a horizontal list of matrices | ||
179 | joinHoriz :: Field t => [Matrix t] -> Matrix t | ||
180 | joinHoriz ms = trans. joinVert . map trans $ ms | ||
181 | |||
182 | ------------------------------------------------------------------------------ | ||
183 | |||
184 | -- | Reverse rows | ||
185 | flipud :: Field t => Matrix t -> Matrix t | ||
186 | flipud m = fromRows . reverse . toRows $ m | ||
187 | |||
188 | -- | Reverse columns | ||
189 | fliprl :: Field t => Matrix t -> Matrix t | ||
190 | fliprl m = fromColumns . reverse . toColumns $ m | ||
191 | |||
163 | ----------------------------------------------------------------- | 192 | ----------------------------------------------------------------- |
164 | 193 | ||
165 | liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes | 194 | liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes |
@@ -168,7 +197,11 @@ liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes | |||
168 | 197 | ||
169 | dotL a b = sum (zipWith (*) a b) | 198 | dotL a b = sum (zipWith (*) a b) |
170 | 199 | ||
171 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | 200 | multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] |
201 | | otherwise = error "inconsistent dimensions in contraction " | ||
202 | where ok = case common length a of | ||
203 | Nothing -> False | ||
204 | Just c -> c == length b | ||
172 | 205 | ||
173 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) | 206 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) |
174 | 207 | ||
@@ -201,9 +234,8 @@ foreign import ccall safe "aux.h multiplyC" | |||
201 | 234 | ||
202 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 235 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
203 | multiply RowMajor a b = multiplyD RowMajor a b | 236 | multiply RowMajor a b = multiplyD RowMajor a b |
204 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | 237 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} |
205 | 238 | where m = multiplyD RowMajor (trans b) (trans a) | |
206 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
207 | 239 | ||
208 | multiplyD order a b | 240 | multiplyD order a b |
209 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | 241 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) |
@@ -253,3 +285,33 @@ subMatrix st sz m | |||
253 | 285 | ||
254 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) | 286 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) |
255 | where subList s n = take n . drop s | 287 | where subList s n = take n . drop s |
288 | |||
289 | --------------------------------------------------------------------- | ||
290 | |||
291 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | ||
292 | m <- createMatrix RowMajor n n | ||
293 | fun // vec v // mat dat m // check msg [dat m] | ||
294 | return m | ||
295 | |||
296 | -- | diagonal matrix from a real vector | ||
297 | diagR :: Vector Double -> Matrix Double | ||
298 | diagR = diagAux c_diagR "diagR" | ||
299 | foreign import ccall "aux.h diagR" c_diagR :: Double :> Double ::> IO Int | ||
300 | |||
301 | -- | diagonal matrix from a real vector | ||
302 | diagC :: Vector (Complex Double) -> Matrix (Complex Double) | ||
303 | diagC = diagAux c_diagC "diagC" | ||
304 | foreign import ccall "aux.h diagC" c_diagC :: (Complex Double) :> (Complex Double) ::> IO Int | ||
305 | |||
306 | -- | diagonal matrix from a vector | ||
307 | diag :: (Num a, Field a) => Vector a -> Matrix a | ||
308 | diag v | ||
309 | | isReal (baseOf) v = scast $ diagR (scast v) | ||
310 | | isComp (baseOf) v = scast $ diagC (scast v) | ||
311 | | otherwise = diagG v | ||
312 | |||
313 | diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | ||
314 | where c = dim v | ||
315 | l = toList v | ||
316 | delta i j | i==j = 1 | ||
317 | | otherwise = 0 | ||