summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs74
1 files changed, 68 insertions, 6 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index db53cd1..bd333d4 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -74,8 +74,7 @@ common f = commonval . map f where
74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing 74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
75 75
76 76
77toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m 77toLists m = partit (cols m) . toList . cdat $ m
78 | otherwise = partit (cols m) . toList . dat $ m
79 78
80dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp 79dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
81 where 80 where
@@ -145,6 +144,8 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
145--{-# RULES "transdataR" transdata=transdataR #-} 144--{-# RULES "transdataR" transdata=transdataR #-}
146--{-# RULES "transdataC" transdata=transdataC #-} 145--{-# RULES "transdataC" transdata=transdataC #-}
147 146
147-----------------------------------------------------------------------------
148
148-- | creates a Matrix from a list of vectors 149-- | creates a Matrix from a list of vectors
149fromRows :: Field t => [Vector t] -> Matrix t 150fromRows :: Field t => [Vector t] -> Matrix t
150fromRows vs = case common dim vs of 151fromRows vs = case common dim vs of
@@ -160,6 +161,34 @@ toRows m = toRows' 0 where
160 toRows' k | k == r*c = [] 161 toRows' k | k == r*c = []
161 | otherwise = subVector k c v : toRows' (k+c) 162 | otherwise = subVector k c v : toRows' (k+c)
162 163
164-- | Creates a matrix from a list of vectors, as columns
165fromColumns :: Field t => [Vector t] -> Matrix t
166fromColumns m = trans . fromRows $ m
167
168-- | Creates a list of vectors from the columns of a matrix
169toColumns :: Field t => Matrix t -> [Vector t]
170toColumns m = toRows . trans $ m
171
172-- | creates a matrix from a vertical list of matrices
173joinVert :: Field t => [Matrix t] -> Matrix t
174joinVert ms = case common cols ms of
175 Nothing -> error "joinVert on matrices with different number of columns"
176 Just c -> reshape c $ join (map cdat ms)
177
178-- | creates a matrix from a horizontal list of matrices
179joinHoriz :: Field t => [Matrix t] -> Matrix t
180joinHoriz ms = trans. joinVert . map trans $ ms
181
182------------------------------------------------------------------------------
183
184-- | Reverse rows
185flipud :: Field t => Matrix t -> Matrix t
186flipud m = fromRows . reverse . toRows $ m
187
188-- | Reverse columns
189fliprl :: Field t => Matrix t -> Matrix t
190fliprl m = fromColumns . reverse . toColumns $ m
191
163----------------------------------------------------------------- 192-----------------------------------------------------------------
164 193
165liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes 194liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes
@@ -168,7 +197,11 @@ liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes
168 197
169dotL a b = sum (zipWith (*) a b) 198dotL a b = sum (zipWith (*) a b)
170 199
171multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] 200multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a]
201 | otherwise = error "inconsistent dimensions in contraction "
202 where ok = case common length a of
203 Nothing -> False
204 Just c -> c == length b
172 205
173transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) 206transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m)
174 207
@@ -201,9 +234,8 @@ foreign import ccall safe "aux.h multiplyC"
201 234
202multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 235multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
203multiply RowMajor a b = multiplyD RowMajor a b 236multiply RowMajor a b = multiplyD RowMajor a b
204multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b 237multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor}
205 238 where m = multiplyD RowMajor (trans b) (trans a)
206multiplyT order a b = multiplyD order (trans b) (trans a)
207 239
208multiplyD order a b 240multiplyD order a b
209 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) 241 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
@@ -253,3 +285,33 @@ subMatrix st sz m
253 285
254subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) 286subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x))
255 where subList s n = take n . drop s 287 where subList s n = take n . drop s
288
289---------------------------------------------------------------------
290
291diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
292 m <- createMatrix RowMajor n n
293 fun // vec v // mat dat m // check msg [dat m]
294 return m
295
296-- | diagonal matrix from a real vector
297diagR :: Vector Double -> Matrix Double
298diagR = diagAux c_diagR "diagR"
299foreign import ccall "aux.h diagR" c_diagR :: Double :> Double ::> IO Int
300
301-- | diagonal matrix from a real vector
302diagC :: Vector (Complex Double) -> Matrix (Complex Double)
303diagC = diagAux c_diagC "diagC"
304foreign import ccall "aux.h diagC" c_diagC :: (Complex Double) :> (Complex Double) ::> IO Int
305
306-- | diagonal matrix from a vector
307diag :: (Num a, Field a) => Vector a -> Matrix a
308diag v
309 | isReal (baseOf) v = scast $ diagR (scast v)
310 | isComp (baseOf) v = scast $ diagC (scast v)
311 | otherwise = diagG v
312
313diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]]
314 where c = dim v
315 l = toList v
316 delta i j | i==j = 1
317 | otherwise = 0