diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 116 |
1 files changed, 80 insertions, 36 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2db4838..f9dd9a9 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -12,7 +12,7 @@ | |||
12 | -- Internal matrix representation | 12 | -- Internal matrix representation |
13 | -- | 13 | -- |
14 | ----------------------------------------------------------------------------- | 14 | ----------------------------------------------------------------------------- |
15 | -- --#hide | 15 | -- #hide |
16 | 16 | ||
17 | module Data.Packed.Internal.Matrix where | 17 | module Data.Packed.Internal.Matrix where |
18 | 18 | ||
@@ -57,10 +57,14 @@ import Data.Maybe(fromJust) | |||
57 | 57 | ||
58 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 58 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
59 | 59 | ||
60 | -- | Matrix representation suitable for GSL and LAPACK computations. | ||
60 | data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } | 61 | data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } |
61 | | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } | 62 | | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } |
62 | 63 | ||
63 | -- transposition just changes the data order | 64 | -- MC: preferred by C, fdat may require a transposition |
65 | -- MF: preferred by LAPACK, cdat may require a transposition | ||
66 | |||
67 | -- | matrix transpose | ||
64 | trans :: Matrix t -> Matrix t | 68 | trans :: Matrix t -> Matrix t |
65 | trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } | 69 | trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } |
66 | trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } | 70 | trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } |
@@ -166,32 +170,29 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | |||
166 | 170 | ||
167 | ---------------------------------------------------------------- | 171 | ---------------------------------------------------------------- |
168 | 172 | ||
169 | -- | element types for which optimized matrix computations are provided | 173 | -- | Optimized matrix computations are provided for elements in the Field class. |
170 | class Storable a => Field a where | 174 | class Storable a => Field a where |
171 | -- | @constant val n@ creates a vector with @n@ elements, all equal to @val@. | 175 | constantD :: a -> Int -> Vector a |
172 | constant :: a -> Int -> Vector a | ||
173 | transdata :: Int -> Vector a -> Int -> Vector a | 176 | transdata :: Int -> Vector a -> Int -> Vector a |
174 | multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 177 | multiplyD :: Matrix a -> Matrix a -> Matrix a |
175 | -- | extracts a submatrix froma a matrix | 178 | subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position |
176 | subMatrix :: (Int,Int) -- ^ (r0,c0) starting position | 179 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
177 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 180 | -> Matrix a -> Matrix a |
178 | -> Matrix a -> Matrix a | 181 | diagD :: Vector a -> Matrix a |
179 | -- | creates a square matrix with the given diagonal | ||
180 | diag :: Vector a -> Matrix a | ||
181 | 182 | ||
182 | instance Field Double where | 183 | instance Field Double where |
183 | constant = constantR | 184 | constantD = constantR |
184 | transdata = transdataR | 185 | transdata = transdataR |
185 | multiplyD = multiplyR | 186 | multiplyD = multiplyR |
186 | subMatrix = subMatrixR | 187 | subMatrixD = subMatrixR |
187 | diag = diagR | 188 | diagD = diagR |
188 | 189 | ||
189 | instance Field (Complex Double) where | 190 | instance Field (Complex Double) where |
190 | constant = constantC | 191 | constantD = constantC |
191 | transdata = transdataC | 192 | transdata = transdataC |
192 | multiplyD = multiplyC | 193 | multiplyD = multiplyC |
193 | subMatrix = subMatrixC | 194 | subMatrixD = subMatrixC |
194 | diag = diagC | 195 | diagD = diagC |
195 | 196 | ||
196 | ------------------------------------------------------------------ | 197 | ------------------------------------------------------------------ |
197 | 198 | ||
@@ -209,6 +210,15 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw | |||
209 | 210 | ||
210 | ------------------------------------------------------------------ | 211 | ------------------------------------------------------------------ |
211 | 212 | ||
213 | (>|<) :: (Field a) => Int -> Int -> [a] -> Matrix a | ||
214 | r >|< c = f where | ||
215 | f l | dim v == r*c = matrixFromVector ColumnMajor c v | ||
216 | | otherwise = error $ "inconsistent list size = " | ||
217 | ++show (dim v) ++" in ("++show r++"><"++show c++")" | ||
218 | where v = fromList l | ||
219 | |||
220 | ------------------------------------------------------------------- | ||
221 | |||
212 | transdataR :: Int -> Vector Double -> Int -> Vector Double | 222 | transdataR :: Int -> Vector Double -> Int -> Vector Double |
213 | transdataR = transdataAux ctransR | 223 | transdataR = transdataAux ctransR |
214 | 224 | ||
@@ -237,10 +247,10 @@ foreign import ccall safe "aux.h transC" | |||
237 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) | 247 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) |
238 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) | 248 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) |
239 | 249 | ||
240 | multiplyAux fun order a b = unsafePerformIO $ do | 250 | multiplyAux fun a b = unsafePerformIO $ do |
241 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 251 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
242 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 252 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
243 | r <- createMatrix order (rows a) (cols b) | 253 | r <- createMatrix RowMajor (rows a) (cols b) |
244 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | 254 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] |
245 | return r | 255 | return r |
246 | 256 | ||
@@ -258,33 +268,41 @@ foreign import ccall safe "aux.h multiplyC" | |||
258 | -> Int -> Int -> Ptr (Complex Double) | 268 | -> Int -> Int -> Ptr (Complex Double) |
259 | -> IO Int | 269 | -> IO Int |
260 | 270 | ||
261 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 271 | multiply' :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
262 | multiply RowMajor a b = multiplyD RowMajor a b | 272 | multiply' RowMajor a b = multiplyD a b |
263 | multiply ColumnMajor a b = MF {rows = c, cols = r, fdat = d, cdat = dt } | 273 | multiply' ColumnMajor a b = trans $ multiplyD (trans b) (trans a) |
264 | where MC {rows = r, cols = c, cdat = d, fdat = dt } = multiplyD RowMajor (trans b) (trans a) | 274 | |
265 | -- FIXME using MatrixFromVector | 275 | |
276 | -- | matrix product | ||
277 | multiply :: (Field a) => Matrix a -> Matrix a -> Matrix a | ||
278 | multiply = multiplyD | ||
266 | 279 | ||
267 | ---------------------------------------------------------------------- | 280 | ---------------------------------------------------------------------- |
268 | 281 | ||
269 | -- | extraction of a submatrix of a real matrix | 282 | -- | extraction of a submatrix from a real matrix |
270 | subMatrixR :: (Int,Int) -- ^ (r0,c0) starting position | 283 | subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double |
271 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
272 | -> Matrix Double -> Matrix Double | ||
273 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do | 284 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do |
274 | r <- createMatrix RowMajor rt ct | 285 | r <- createMatrix RowMajor rt ct |
275 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] | 286 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] |
276 | return r | 287 | return r |
277 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM | 288 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
278 | 289 | ||
279 | -- | extraction of a submatrix of a complex matrix | 290 | -- | extraction of a submatrix from a complex matrix |
280 | subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position | 291 | subMatrixC :: (Int,Int) -> (Int,Int) -> Matrix (Complex Double) -> Matrix (Complex Double) |
281 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
282 | -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
283 | subMatrixC (r0,c0) (rt,ct) x = | 292 | subMatrixC (r0,c0) (rt,ct) x = |
284 | reshape ct . asComplex . cdat . | 293 | reshape ct . asComplex . cdat . |
285 | subMatrixR (r0,2*c0) (rt,2*ct) . | 294 | subMatrixR (r0,2*c0) (rt,2*ct) . |
286 | reshape (2*cols x) . asReal . cdat $ x | 295 | reshape (2*cols x) . asReal . cdat $ x |
287 | 296 | ||
297 | -- | Extracts a submatrix from a matrix. | ||
298 | subMatrix :: Field a | ||
299 | => (Int,Int) -- ^ (r0,c0) starting position | ||
300 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
301 | -> Matrix a -- ^ input matrix | ||
302 | -> Matrix a -- ^ result | ||
303 | subMatrix = subMatrixD | ||
304 | |||
305 | |||
288 | --------------------------------------------------------------------- | 306 | --------------------------------------------------------------------- |
289 | 307 | ||
290 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | 308 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do |
@@ -302,6 +320,10 @@ diagC :: Vector (Complex Double) -> Matrix (Complex Double) | |||
302 | diagC = diagAux c_diagC "diagC" | 320 | diagC = diagAux c_diagC "diagC" |
303 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM | 321 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM |
304 | 322 | ||
323 | -- | creates a square matrix with the given diagonal | ||
324 | diag :: Field a => Vector a -> Matrix a | ||
325 | diag = diagD | ||
326 | |||
305 | ------------------------------------------------------------------------ | 327 | ------------------------------------------------------------------------ |
306 | 328 | ||
307 | constantAux fun x n = unsafePerformIO $ do | 329 | constantAux fun x n = unsafePerformIO $ do |
@@ -321,6 +343,28 @@ constantC = constantAux cconstantC | |||
321 | foreign import ccall safe "aux.h constantC" | 343 | foreign import ccall safe "aux.h constantC" |
322 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int | 344 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int |
323 | 345 | ||
346 | {- | creates a vector with a given number of equal components: | ||
347 | |||
348 | @> constant 2 7 | ||
349 | 7 |> [2.0,2.0,2.0,2.0,2.0,2.0,2.0]@ | ||
350 | -} | ||
351 | constant :: Field a => a -> Int -> Vector a | ||
352 | constant = constantD | ||
353 | |||
354 | -------------------------------------------------------------------------- | ||
355 | |||
356 | -- | obtains the complex conjugate of a complex vector | ||
357 | conj :: Vector (Complex Double) -> Vector (Complex Double) | ||
358 | conj v = asComplex $ cdat $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) | ||
359 | |||
360 | -- | creates a complex vector from vectors with real and imaginary parts | ||
361 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) | ||
362 | toComplex (r,i) = asComplex $ cdat $ fromColumns [r,i] | ||
363 | |||
364 | -- | converts a real vector into a complex representation (with zero imaginary parts) | ||
365 | comp :: Vector Double -> Vector (Complex Double) | ||
366 | comp v = toComplex (v,constant 0 (dim v)) | ||
367 | |||
324 | ------------------------------------------------------------------------- | 368 | ------------------------------------------------------------------------- |
325 | 369 | ||
326 | -- Generic definitions | 370 | -- Generic definitions |