summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs116
1 files changed, 80 insertions, 36 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 2db4838..f9dd9a9 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -12,7 +12,7 @@
12-- Internal matrix representation 12-- Internal matrix representation
13-- 13--
14----------------------------------------------------------------------------- 14-----------------------------------------------------------------------------
15-- --#hide 15-- #hide
16 16
17module Data.Packed.Internal.Matrix where 17module Data.Packed.Internal.Matrix where
18 18
@@ -57,10 +57,14 @@ import Data.Maybe(fromJust)
57 57
58data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 58data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
59 59
60-- | Matrix representation suitable for GSL and LAPACK computations.
60data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } 61data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t }
61 | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } 62 | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t }
62 63
63-- transposition just changes the data order 64-- MC: preferred by C, fdat may require a transposition
65-- MF: preferred by LAPACK, cdat may require a transposition
66
67-- | matrix transpose
64trans :: Matrix t -> Matrix t 68trans :: Matrix t -> Matrix t
65trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } 69trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt }
66trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } 70trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt }
@@ -166,32 +170,29 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
166 170
167---------------------------------------------------------------- 171----------------------------------------------------------------
168 172
169-- | element types for which optimized matrix computations are provided 173-- | Optimized matrix computations are provided for elements in the Field class.
170class Storable a => Field a where 174class Storable a => Field a where
171 -- | @constant val n@ creates a vector with @n@ elements, all equal to @val@. 175 constantD :: a -> Int -> Vector a
172 constant :: a -> Int -> Vector a
173 transdata :: Int -> Vector a -> Int -> Vector a 176 transdata :: Int -> Vector a -> Int -> Vector a
174 multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a 177 multiplyD :: Matrix a -> Matrix a -> Matrix a
175 -- | extracts a submatrix froma a matrix 178 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
176 subMatrix :: (Int,Int) -- ^ (r0,c0) starting position 179 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
177 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 180 -> Matrix a -> Matrix a
178 -> Matrix a -> Matrix a 181 diagD :: Vector a -> Matrix a
179 -- | creates a square matrix with the given diagonal
180 diag :: Vector a -> Matrix a
181 182
182instance Field Double where 183instance Field Double where
183 constant = constantR 184 constantD = constantR
184 transdata = transdataR 185 transdata = transdataR
185 multiplyD = multiplyR 186 multiplyD = multiplyR
186 subMatrix = subMatrixR 187 subMatrixD = subMatrixR
187 diag = diagR 188 diagD = diagR
188 189
189instance Field (Complex Double) where 190instance Field (Complex Double) where
190 constant = constantC 191 constantD = constantC
191 transdata = transdataC 192 transdata = transdataC
192 multiplyD = multiplyC 193 multiplyD = multiplyC
193 subMatrix = subMatrixC 194 subMatrixD = subMatrixC
194 diag = diagC 195 diagD = diagC
195 196
196------------------------------------------------------------------ 197------------------------------------------------------------------
197 198
@@ -209,6 +210,15 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw
209 210
210------------------------------------------------------------------ 211------------------------------------------------------------------
211 212
213(>|<) :: (Field a) => Int -> Int -> [a] -> Matrix a
214r >|< c = f where
215 f l | dim v == r*c = matrixFromVector ColumnMajor c v
216 | otherwise = error $ "inconsistent list size = "
217 ++show (dim v) ++" in ("++show r++"><"++show c++")"
218 where v = fromList l
219
220-------------------------------------------------------------------
221
212transdataR :: Int -> Vector Double -> Int -> Vector Double 222transdataR :: Int -> Vector Double -> Int -> Vector Double
213transdataR = transdataAux ctransR 223transdataR = transdataAux ctransR
214 224
@@ -237,10 +247,10 @@ foreign import ccall safe "aux.h transC"
237gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) 247gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d)
238gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) 248gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d)
239 249
240multiplyAux fun order a b = unsafePerformIO $ do 250multiplyAux fun a b = unsafePerformIO $ do
241 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ 251 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
242 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 252 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
243 r <- createMatrix order (rows a) (cols b) 253 r <- createMatrix RowMajor (rows a) (cols b)
244 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] 254 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b]
245 return r 255 return r
246 256
@@ -258,33 +268,41 @@ foreign import ccall safe "aux.h multiplyC"
258 -> Int -> Int -> Ptr (Complex Double) 268 -> Int -> Int -> Ptr (Complex Double)
259 -> IO Int 269 -> IO Int
260 270
261multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 271multiply' :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
262multiply RowMajor a b = multiplyD RowMajor a b 272multiply' RowMajor a b = multiplyD a b
263multiply ColumnMajor a b = MF {rows = c, cols = r, fdat = d, cdat = dt } 273multiply' ColumnMajor a b = trans $ multiplyD (trans b) (trans a)
264 where MC {rows = r, cols = c, cdat = d, fdat = dt } = multiplyD RowMajor (trans b) (trans a) 274
265-- FIXME using MatrixFromVector 275
276-- | matrix product
277multiply :: (Field a) => Matrix a -> Matrix a -> Matrix a
278multiply = multiplyD
266 279
267---------------------------------------------------------------------- 280----------------------------------------------------------------------
268 281
269-- | extraction of a submatrix of a real matrix 282-- | extraction of a submatrix from a real matrix
270subMatrixR :: (Int,Int) -- ^ (r0,c0) starting position 283subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double
271 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
272 -> Matrix Double -> Matrix Double
273subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do 284subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do
274 r <- createMatrix RowMajor rt ct 285 r <- createMatrix RowMajor rt ct
275 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] 286 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r]
276 return r 287 return r
277foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM 288foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM
278 289
279-- | extraction of a submatrix of a complex matrix 290-- | extraction of a submatrix from a complex matrix
280subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position 291subMatrixC :: (Int,Int) -> (Int,Int) -> Matrix (Complex Double) -> Matrix (Complex Double)
281 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
282 -> Matrix (Complex Double) -> Matrix (Complex Double)
283subMatrixC (r0,c0) (rt,ct) x = 292subMatrixC (r0,c0) (rt,ct) x =
284 reshape ct . asComplex . cdat . 293 reshape ct . asComplex . cdat .
285 subMatrixR (r0,2*c0) (rt,2*ct) . 294 subMatrixR (r0,2*c0) (rt,2*ct) .
286 reshape (2*cols x) . asReal . cdat $ x 295 reshape (2*cols x) . asReal . cdat $ x
287 296
297-- | Extracts a submatrix from a matrix.
298subMatrix :: Field a
299 => (Int,Int) -- ^ (r0,c0) starting position
300 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
301 -> Matrix a -- ^ input matrix
302 -> Matrix a -- ^ result
303subMatrix = subMatrixD
304
305
288--------------------------------------------------------------------- 306---------------------------------------------------------------------
289 307
290diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do 308diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
@@ -302,6 +320,10 @@ diagC :: Vector (Complex Double) -> Matrix (Complex Double)
302diagC = diagAux c_diagC "diagC" 320diagC = diagAux c_diagC "diagC"
303foreign import ccall "aux.h diagC" c_diagC :: TCVCM 321foreign import ccall "aux.h diagC" c_diagC :: TCVCM
304 322
323-- | creates a square matrix with the given diagonal
324diag :: Field a => Vector a -> Matrix a
325diag = diagD
326
305------------------------------------------------------------------------ 327------------------------------------------------------------------------
306 328
307constantAux fun x n = unsafePerformIO $ do 329constantAux fun x n = unsafePerformIO $ do
@@ -321,6 +343,28 @@ constantC = constantAux cconstantC
321foreign import ccall safe "aux.h constantC" 343foreign import ccall safe "aux.h constantC"
322 cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int 344 cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int
323 345
346{- | creates a vector with a given number of equal components:
347
348@> constant 2 7
3497 |> [2.0,2.0,2.0,2.0,2.0,2.0,2.0]@
350-}
351constant :: Field a => a -> Int -> Vector a
352constant = constantD
353
354--------------------------------------------------------------------------
355
356-- | obtains the complex conjugate of a complex vector
357conj :: Vector (Complex Double) -> Vector (Complex Double)
358conj v = asComplex $ cdat $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1])
359
360-- | creates a complex vector from vectors with real and imaginary parts
361toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double)
362toComplex (r,i) = asComplex $ cdat $ fromColumns [r,i]
363
364-- | converts a real vector into a complex representation (with zero imaginary parts)
365comp :: Vector Double -> Vector (Complex Double)
366comp v = toComplex (v,constant 0 (dim v))
367
324------------------------------------------------------------------------- 368-------------------------------------------------------------------------
325 369
326-- Generic definitions 370-- Generic definitions