diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 18 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 2 |
5 files changed, 8 insertions, 18 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 13340f2..f2fc68d 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -513,7 +513,7 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co | |||
513 | qrgrC = qrgrAux zungqr "qrgrC" | 513 | qrgrC = qrgrAux zungqr "qrgrC" |
514 | 514 | ||
515 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 515 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
516 | res <- copy ColumnMajor (sliceMatrix (0,0) (rows a,n) a) | 516 | res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a) |
517 | f # (subVector 0 n tau') # res #| st | 517 | f # (subVector 0 n tau') # res #| st |
518 | return res | 518 | return res |
519 | where | 519 | where |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index df56207..5163421 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -70,7 +70,7 @@ is1d (size->(r,c)) = r==1 || c==1 | |||
70 | {-# INLINE is1d #-} | 70 | {-# INLINE is1d #-} |
71 | 71 | ||
72 | -- data is not contiguous | 72 | -- data is not contiguous |
73 | isSlice m@(size->(r,c)) = (c < xRow m || r < xCol m) && min r c > 1 | 73 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
74 | {-# INLINE isSlice #-} | 74 | {-# INLINE isSlice #-} |
75 | 75 | ||
76 | orderOf :: Matrix t -> MatrixOrder | 76 | orderOf :: Matrix t -> MatrixOrder |
@@ -359,26 +359,16 @@ instance Element Z where | |||
359 | 359 | ||
360 | ------------------------------------------------------------------- | 360 | ------------------------------------------------------------------- |
361 | 361 | ||
362 | -- | reference to a rectangular slice of a matrix (no data copy) | ||
362 | subMatrix :: Element a | 363 | subMatrix :: Element a |
363 | => (Int,Int) -- ^ (r0,c0) starting position | ||
364 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
365 | -> Matrix a -- ^ input matrix | ||
366 | -> Matrix a -- ^ result | ||
367 | subMatrix (r0,c0) (rt,ct) m | ||
368 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | ||
369 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR RowMajor m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) | ||
370 | | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m | ||
371 | |||
372 | |||
373 | sliceMatrix :: Element a | ||
374 | => (Int,Int) -- ^ (r0,c0) starting position | 364 | => (Int,Int) -- ^ (r0,c0) starting position |
375 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 365 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
376 | -> Matrix a -- ^ input matrix | 366 | -> Matrix a -- ^ input matrix |
377 | -> Matrix a -- ^ result | 367 | -> Matrix a -- ^ result |
378 | sliceMatrix (r0,c0) (rt,ct) m | 368 | subMatrix (r0,c0) (rt,ct) m |
379 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | 369 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && |
380 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res | 370 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res |
381 | | otherwise = error $ "wrong sliceMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m | 371 | | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m |
382 | where | 372 | where |
383 | p = r0 * xRow m + c0 * xCol m | 373 | p = r0 * xRow m + c0 * xCol m |
384 | tot | rowOrder m = ct + (rt-1) * xRow m | 374 | tot | rowOrder m = ct + (rt-1) * xRow m |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 62dfddf..544c9e4 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -231,7 +231,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
231 | -- | r0 c0 height width | 231 | -- | r0 c0 height width |
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
233 | 233 | ||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m | 234 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m |
235 | 235 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
237 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | 237 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 258c3a3..4123e6c 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -849,7 +849,7 @@ viewBlock' r c m | |||
849 | m12 = subm (0,c) (r,ct-c) m | 849 | m12 = subm (0,c) (r,ct-c) m |
850 | m21 = subm (r,0) (rt-r,c) m | 850 | m21 = subm (r,0) (rt-r,c) m |
851 | m22 = subm (r,c) (rt-r,ct-c) m | 851 | m22 = subm (r,c) (rt-r,ct-c) m |
852 | subm = sliceMatrix | 852 | subm = subMatrix |
853 | 853 | ||
854 | viewBlock m = viewBlock' n n m | 854 | viewBlock m = viewBlock' n n m |
855 | where | 855 | where |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index f6fa92a..57a68e7 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -62,7 +62,7 @@ module Numeric.LinearAlgebra.Devel( | |||
62 | GMatrix(..), | 62 | GMatrix(..), |
63 | 63 | ||
64 | -- * Misc | 64 | -- * Misc |
65 | toByteString, fromByteString, sliceMatrix, showInternal | 65 | toByteString, fromByteString, showInternal |
66 | 66 | ||
67 | ) where | 67 | ) where |
68 | 68 | ||