diff options
-rw-r--r-- | packages/base/src/Internal/Matrix.hs (renamed from packages/base/src/Data/Packed/Internal/Matrix.hs) | 163 |
1 files changed, 53 insertions, 110 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 66eef15..44365d0 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -2,10 +2,11 @@ | |||
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE BangPatterns #-} | 4 | {-# LANGUAGE BangPatterns #-} |
5 | {-# LANGUAGE TypeOperators #-} | ||
5 | 6 | ||
6 | -- | | 7 | -- | |
7 | -- Module : Data.Packed.Internal.Matrix | 8 | -- Module : Internal.Matrix |
8 | -- Copyright : (c) Alberto Ruiz 2007 | 9 | -- Copyright : (c) Alberto Ruiz 2007-15 |
9 | -- License : BSD3 | 10 | -- License : BSD3 |
10 | -- Maintainer : Alberto Ruiz | 11 | -- Maintainer : Alberto Ruiz |
11 | -- Stability : provisional | 12 | -- Stability : provisional |
@@ -13,36 +14,23 @@ | |||
13 | -- Internal matrix representation | 14 | -- Internal matrix representation |
14 | -- | 15 | -- |
15 | 16 | ||
16 | module Data.Packed.Internal.Matrix( | 17 | module Internal.Matrix where |
17 | Matrix(..), rows, cols, cdat, fdat, | 18 | |
18 | MatrixOrder(..), orderOf, | 19 | |
19 | createMatrix, mat, omat, | 20 | import Internal.Tools ( splitEvery, fi, compatdim, (//) ) |
20 | cmat, fmat, | 21 | import Internal.Vector |
21 | toLists, flatten, reshape, | 22 | import Internal.Devel |
22 | Element(..), | 23 | import Internal.Vectorized |
23 | trans, | 24 | import Data.Vector.Storable ( unsafeWith, fromList ) |
24 | fromRows, toRows, fromColumns, toColumns, | 25 | import Foreign.Marshal.Alloc ( free ) |
25 | matrixFromVector, | 26 | import Foreign.Ptr ( Ptr ) |
26 | subMatrix, | 27 | import Foreign.Storable ( Storable ) |
27 | liftMatrix, liftMatrix2, | 28 | import Data.Complex ( Complex ) |
28 | (@@>), atM', | 29 | import Foreign.C.Types ( CInt(..) ) |
29 | singleton, | 30 | import Foreign.C.String ( CString, newCString ) |
30 | emptyM, | 31 | import System.IO.Unsafe ( unsafePerformIO ) |
31 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | 32 | import Control.DeepSeq ( NFData(..) ) |
32 | ) where | 33 | |
33 | |||
34 | import Data.Packed.Internal.Common | ||
35 | import Data.Packed.Internal.Signatures | ||
36 | import Data.Packed.Internal.Vector | ||
37 | |||
38 | import Foreign.Marshal.Alloc(alloca, free) | ||
39 | import Foreign.Marshal.Array(newArray) | ||
40 | import Foreign.Ptr(Ptr, castPtr) | ||
41 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf) | ||
42 | import Data.Complex(Complex) | ||
43 | import Foreign.C.Types | ||
44 | import System.IO.Unsafe(unsafePerformIO) | ||
45 | import Control.DeepSeq | ||
46 | 34 | ||
47 | ----------------------------------------------------------------- | 35 | ----------------------------------------------------------------- |
48 | 36 | ||
@@ -93,8 +81,8 @@ data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int | |||
93 | -- RowMajor: preferred by C, fdat may require a transposition | 81 | -- RowMajor: preferred by C, fdat may require a transposition |
94 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | 82 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition |
95 | 83 | ||
96 | cdat = xdat | 84 | --cdat = xdat |
97 | fdat = xdat | 85 | --fdat = xdat |
98 | 86 | ||
99 | rows :: Matrix t -> Int | 87 | rows :: Matrix t -> Int |
100 | rows = irows | 88 | rows = irows |
@@ -204,9 +192,7 @@ toColumns m = toRows . trans $ m | |||
204 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t | 192 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t |
205 | infixl 9 @@> | 193 | infixl 9 @@> |
206 | m@Matrix {irows = r, icols = c} @@> (i,j) | 194 | m@Matrix {irows = r, icols = c} @@> (i,j) |
207 | | safe = if i<0 || i>=r || j<0 || j>=c | 195 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" |
208 | then error "matrix indexing out of range" | ||
209 | else atM' m i j | ||
210 | | otherwise = atM' m i j | 196 | | otherwise = atM' m i j |
211 | {-# INLINE (@@>) #-} | 197 | {-# INLINE (@@>) #-} |
212 | 198 | ||
@@ -243,7 +229,7 @@ reshape :: Storable t => Int -> Vector t -> Matrix t | |||
243 | reshape 0 v = matrixFromVector RowMajor 0 0 v | 229 | reshape 0 v = matrixFromVector RowMajor 0 0 v |
244 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | 230 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v |
245 | 231 | ||
246 | singleton x = reshape 1 (fromList [x]) | 232 | --singleton x = reshape 1 (fromList [x]) |
247 | 233 | ||
248 | -- | application of a vector function on the flattened matrix elements | 234 | -- | application of a vector function on the flattened matrix elements |
249 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 235 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
@@ -273,14 +259,8 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | |||
273 | >instance Element Foo | 259 | >instance Element Foo |
274 | -} | 260 | -} |
275 | class (Storable a) => Element a where | 261 | class (Storable a) => Element a where |
276 | subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position | ||
277 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
278 | -> Matrix a -> Matrix a | ||
279 | subMatrixD = subMatrix' | ||
280 | transdata :: Int -> Vector a -> Int -> Vector a | 262 | transdata :: Int -> Vector a -> Int -> Vector a |
281 | transdata = transdataP -- transdata' | ||
282 | constantD :: a -> Int -> Vector a | 263 | constantD :: a -> Int -> Vector a |
283 | constantD = constantP -- constant' | ||
284 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a | 264 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a |
285 | sortI :: Ord a => Vector a -> Vector CInt | 265 | sortI :: Ord a => Vector a -> Vector CInt |
286 | sortV :: Ord a => Vector a -> Vector a | 266 | sortV :: Ord a => Vector a -> Vector a |
@@ -357,57 +337,14 @@ transdataAux fun c1 d c2 = | |||
357 | r2 = dim d `div` c2 | 337 | r2 = dim d `div` c2 |
358 | noneed = dim d == 0 || r1 == 1 || c1 == 1 | 338 | noneed = dim d == 0 || r1 == 1 || c1 == 1 |
359 | 339 | ||
360 | transdataP :: Storable a => Int -> Vector a -> Int -> Vector a | ||
361 | transdataP c1 d c2 = | ||
362 | if noneed | ||
363 | then d | ||
364 | else unsafePerformIO $ do | ||
365 | v <- createVector (dim d) | ||
366 | unsafeWith d $ \pd -> | ||
367 | unsafeWith v $ \pv -> | ||
368 | ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP" | ||
369 | return v | ||
370 | where r1 = dim d `div` c1 | ||
371 | r2 = dim d `div` c2 | ||
372 | sz = sizeOf (d @> 0) | ||
373 | noneed = dim d == 0 || r1 == 1 || c1 == 1 | ||
374 | |||
375 | foreign import ccall unsafe "transF" ctransF :: TFMFM | ||
376 | foreign import ccall unsafe "transR" ctransR :: TMM | ||
377 | foreign import ccall unsafe "transQ" ctransQ :: TQMQM | ||
378 | foreign import ccall unsafe "transC" ctransC :: TCMCM | ||
379 | foreign import ccall unsafe "transI" ctransI :: CM CInt (CM CInt (IO CInt)) | ||
380 | foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt | ||
381 | |||
382 | ---------------------------------------------------------------------- | ||
383 | |||
384 | constantAux fun x n = unsafePerformIO $ do | ||
385 | v <- createVector n | ||
386 | px <- newArray [x] | ||
387 | app1 (fun px) vec v "constantAux" | ||
388 | free px | ||
389 | return v | ||
390 | |||
391 | foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF | ||
392 | |||
393 | foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV | ||
394 | 340 | ||
395 | foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV | 341 | type TMM t = t ..> t ..> Ok |
396 | 342 | ||
397 | foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV | 343 | foreign import ccall unsafe "transF" ctransF :: TMM Float |
398 | 344 | foreign import ccall unsafe "transR" ctransR :: TMM Double | |
399 | foreign import ccall unsafe "constantI" cconstantI :: Ptr CInt -> CV CInt (IO CInt) | 345 | foreign import ccall unsafe "transQ" ctransQ :: TMM (Complex Float) |
400 | 346 | foreign import ccall unsafe "transC" ctransC :: TMM (Complex Double) | |
401 | constantP :: Storable a => a -> Int -> Vector a | 347 | foreign import ccall unsafe "transI" ctransI :: TMM CInt |
402 | constantP a n = unsafePerformIO $ do | ||
403 | let sz = sizeOf a | ||
404 | v <- createVector n | ||
405 | unsafeWith v $ \p -> do | ||
406 | alloca $ \k -> do | ||
407 | poke k a | ||
408 | cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP" | ||
409 | return v | ||
410 | foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt | ||
411 | 348 | ||
412 | ---------------------------------------------------------------------- | 349 | ---------------------------------------------------------------------- |
413 | 350 | ||
@@ -418,26 +355,11 @@ subMatrix :: Element a | |||
418 | -> Matrix a -- ^ input matrix | 355 | -> Matrix a -- ^ input matrix |
419 | -> Matrix a -- ^ result | 356 | -> Matrix a -- ^ result |
420 | subMatrix (r0,c0) (rt,ct) m | 357 | subMatrix (r0,c0) (rt,ct) m |
421 | | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) && | 358 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && |
422 | 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m | 359 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) |
423 | | otherwise = error $ "wrong subMatrix "++ | 360 | | otherwise = error $ "wrong subMatrix "++ |
424 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | 361 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) |
425 | 362 | ||
426 | subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do | ||
427 | w <- createVector (rt*ct) | ||
428 | unsafeWith v $ \p -> | ||
429 | unsafeWith w $ \q -> do | ||
430 | let go (-1) _ = return () | ||
431 | go !i (-1) = go (i-1) (ct-1) | ||
432 | go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0) | ||
433 | pokeElemOff q (i*ct+j) x | ||
434 | go i (j-1) | ||
435 | go (rt-1) (ct-1) | ||
436 | return w | ||
437 | |||
438 | subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor | ||
439 | subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m) | ||
440 | |||
441 | -------------------------------------------------------------------------- | 363 | -------------------------------------------------------------------------- |
442 | 364 | ||
443 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 365 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
@@ -580,4 +502,25 @@ foreign import ccall unsafe "remapI" c_remapI :: Rem CInt | |||
580 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) | 502 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) |
581 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) | 503 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) |
582 | 504 | ||
505 | -------------------------------------------------------------------------------- | ||
506 | |||
507 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | ||
508 | :: CString -> CString -> Double ..> Ok | ||
509 | |||
510 | {- | save a matrix as a 2D ASCII table | ||
511 | -} | ||
512 | saveMatrix | ||
513 | :: FilePath | ||
514 | -> String -- ^ \"printf\" format (e.g. \"%.2f\", \"%g\", etc.) | ||
515 | -> Matrix Double | ||
516 | -> IO () | ||
517 | saveMatrix name format m = do | ||
518 | cname <- newCString name | ||
519 | cformat <- newCString format | ||
520 | app1 (c_saveMatrix cname cformat) mat m "saveMatrix" | ||
521 | free cname | ||
522 | free cformat | ||
523 | return () | ||
524 | |||
525 | -------------------------------------------------------------------------------- | ||
583 | 526 | ||