diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-30 14:38:52 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-30 14:38:52 +0200 |
commit | 4ada25636995115f2b26707870f611a138f4e20b (patch) | |
tree | 07ade269d5b98ae8284d2064bb3a8ea928e3e405 /packages | |
parent | 4730254f061832591d4a44c86d3bdfa4620f4322 (diff) |
subMatrix changed to non copying slice
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 18 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 2 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 2 |
6 files changed, 9 insertions, 19 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 13340f2..f2fc68d 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -513,7 +513,7 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co | |||
513 | qrgrC = qrgrAux zungqr "qrgrC" | 513 | qrgrC = qrgrAux zungqr "qrgrC" |
514 | 514 | ||
515 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 515 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
516 | res <- copy ColumnMajor (sliceMatrix (0,0) (rows a,n) a) | 516 | res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a) |
517 | f # (subVector 0 n tau') # res #| st | 517 | f # (subVector 0 n tau') # res #| st |
518 | return res | 518 | return res |
519 | where | 519 | where |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index df56207..5163421 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -70,7 +70,7 @@ is1d (size->(r,c)) = r==1 || c==1 | |||
70 | {-# INLINE is1d #-} | 70 | {-# INLINE is1d #-} |
71 | 71 | ||
72 | -- data is not contiguous | 72 | -- data is not contiguous |
73 | isSlice m@(size->(r,c)) = (c < xRow m || r < xCol m) && min r c > 1 | 73 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
74 | {-# INLINE isSlice #-} | 74 | {-# INLINE isSlice #-} |
75 | 75 | ||
76 | orderOf :: Matrix t -> MatrixOrder | 76 | orderOf :: Matrix t -> MatrixOrder |
@@ -359,26 +359,16 @@ instance Element Z where | |||
359 | 359 | ||
360 | ------------------------------------------------------------------- | 360 | ------------------------------------------------------------------- |
361 | 361 | ||
362 | -- | reference to a rectangular slice of a matrix (no data copy) | ||
362 | subMatrix :: Element a | 363 | subMatrix :: Element a |
363 | => (Int,Int) -- ^ (r0,c0) starting position | ||
364 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
365 | -> Matrix a -- ^ input matrix | ||
366 | -> Matrix a -- ^ result | ||
367 | subMatrix (r0,c0) (rt,ct) m | ||
368 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | ||
369 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR RowMajor m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) | ||
370 | | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m | ||
371 | |||
372 | |||
373 | sliceMatrix :: Element a | ||
374 | => (Int,Int) -- ^ (r0,c0) starting position | 364 | => (Int,Int) -- ^ (r0,c0) starting position |
375 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 365 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
376 | -> Matrix a -- ^ input matrix | 366 | -> Matrix a -- ^ input matrix |
377 | -> Matrix a -- ^ result | 367 | -> Matrix a -- ^ result |
378 | sliceMatrix (r0,c0) (rt,ct) m | 368 | subMatrix (r0,c0) (rt,ct) m |
379 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | 369 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && |
380 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res | 370 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res |
381 | | otherwise = error $ "wrong sliceMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m | 371 | | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m |
382 | where | 372 | where |
383 | p = r0 * xRow m + c0 * xCol m | 373 | p = r0 * xRow m + c0 * xCol m |
384 | tot | rowOrder m = ct + (rt-1) * xRow m | 374 | tot | rowOrder m = ct + (rt-1) * xRow m |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 62dfddf..544c9e4 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -231,7 +231,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
231 | -- | r0 c0 height width | 231 | -- | r0 c0 height width |
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
233 | 233 | ||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m | 234 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m |
235 | 235 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
237 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | 237 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 258c3a3..4123e6c 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -849,7 +849,7 @@ viewBlock' r c m | |||
849 | m12 = subm (0,c) (r,ct-c) m | 849 | m12 = subm (0,c) (r,ct-c) m |
850 | m21 = subm (r,0) (rt-r,c) m | 850 | m21 = subm (r,0) (rt-r,c) m |
851 | m22 = subm (r,c) (rt-r,ct-c) m | 851 | m22 = subm (r,c) (rt-r,ct-c) m |
852 | subm = sliceMatrix | 852 | subm = subMatrix |
853 | 853 | ||
854 | viewBlock m = viewBlock' n n m | 854 | viewBlock m = viewBlock' n n m |
855 | where | 855 | where |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index f6fa92a..57a68e7 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -62,7 +62,7 @@ module Numeric.LinearAlgebra.Devel( | |||
62 | GMatrix(..), | 62 | GMatrix(..), |
63 | 63 | ||
64 | -- * Misc | 64 | -- * Misc |
65 | toByteString, fromByteString, sliceMatrix, showInternal | 65 | toByteString, fromByteString, showInternal |
66 | 66 | ||
67 | ) where | 67 | ) where |
68 | 68 | ||
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index 79cb769..d9bc9a0 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -536,7 +536,7 @@ sliceTest = utest "slice test" $ and | |||
536 | 536 | ||
537 | testSlice f x@(size->sz@(r,c)) = all (==f x) (map f (g y1 ++ g y2)) | 537 | testSlice f x@(size->sz@(r,c)) = all (==f x) (map f (g y1 ++ g y2)) |
538 | where | 538 | where |
539 | subm = sliceMatrix | 539 | subm = subMatrix |
540 | g y = [ subm (a*r,b*c) sz y | a <-[0..2], b <- [0..2]] | 540 | g y = [ subm (a*r,b*c) sz y | a <-[0..2], b <- [0..2]] |
541 | h z = fromBlocks (replicate 3 (replicate 3 z)) | 541 | h z = fromBlocks (replicate 3 (replicate 3 z)) |
542 | y1 = h x | 542 | y1 = h x |