diff options
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 51 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 4 |
2 files changed, 37 insertions, 18 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index c0d1318..f76b9dc 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -79,9 +79,8 @@ The elements are stored in a continuous memory array. | |||
79 | data Matrix t = Matrix | 79 | data Matrix t = Matrix |
80 | { irows :: {-# UNPACK #-} !Int | 80 | { irows :: {-# UNPACK #-} !Int |
81 | , icols :: {-# UNPACK #-} !Int | 81 | , icols :: {-# UNPACK #-} !Int |
82 | , xRow :: {-# UNPACK #-} !CInt | 82 | , xRow :: {-# UNPACK #-} !Int |
83 | , xCol :: {-# UNPACK #-} !CInt | 83 | , xCol :: {-# UNPACK #-} !Int |
84 | -- , rowOrder :: {-# UNPACK #-} !Bool | ||
85 | , xdat :: {-# UNPACK #-} !(Vector t) | 84 | , xdat :: {-# UNPACK #-} !(Vector t) |
86 | } | 85 | } |
87 | -- RowMajor: preferred by C, fdat may require a transposition | 86 | -- RowMajor: preferred by C, fdat may require a transposition |
@@ -90,13 +89,18 @@ data Matrix t = Matrix | |||
90 | 89 | ||
91 | rows :: Matrix t -> Int | 90 | rows :: Matrix t -> Int |
92 | rows = irows | 91 | rows = irows |
92 | {-# INLINE rows #-} | ||
93 | 93 | ||
94 | cols :: Matrix t -> Int | 94 | cols :: Matrix t -> Int |
95 | cols = icols | 95 | cols = icols |
96 | {-# INLINE cols #-} | ||
96 | 97 | ||
97 | rowOrder m = xRow m > 1 | 98 | rowOrder m = xRow m > 1 |
98 | {-# INLINE rowOrder #-} | 99 | {-# INLINE rowOrder #-} |
99 | 100 | ||
101 | isSlice m = cols m < xRow m || rows m < xCol m | ||
102 | {-# INLINE isSlice #-} | ||
103 | |||
100 | orderOf :: Matrix t -> MatrixOrder | 104 | orderOf :: Matrix t -> MatrixOrder |
101 | orderOf m = if rowOrder m then RowMajor else ColumnMajor | 105 | orderOf m = if rowOrder m then RowMajor else ColumnMajor |
102 | 106 | ||
@@ -104,19 +108,19 @@ orderOf m = if rowOrder m then RowMajor else ColumnMajor | |||
104 | -- | Matrix transpose. | 108 | -- | Matrix transpose. |
105 | trans :: Matrix t -> Matrix t | 109 | trans :: Matrix t -> Matrix t |
106 | trans m@Matrix { irows = r, icols = c } | rowOrder m = | 110 | trans m@Matrix { irows = r, icols = c } | rowOrder m = |
107 | m { irows = c, icols = r, xRow = 1, xCol = fi c } | 111 | m { irows = c, icols = r, xRow = 1, xCol = c } |
108 | trans m@Matrix { irows = r, icols = c } = | 112 | trans m@Matrix { irows = r, icols = c } = |
109 | m { irows = c, icols = r, xRow = fi r, xCol = 1 } | 113 | m { irows = c, icols = r, xRow = r, xCol = 1 } |
110 | 114 | ||
111 | cmat :: (Element t) => Matrix t -> Matrix t | 115 | cmat :: (Element t) => Matrix t -> Matrix t |
112 | cmat m | rowOrder m = m | 116 | cmat m | rowOrder m = m |
113 | cmat m@Matrix { irows = r, icols = c, xdat = d } = | 117 | cmat m@Matrix { irows = r, icols = c, xdat = d } = |
114 | m { xdat = transdata r d c, xRow = fi c, xCol = 1 } | 118 | m { xdat = transdata r d c, xRow = c, xCol = 1 } |
115 | 119 | ||
116 | fmat :: (Element t) => Matrix t -> Matrix t | 120 | fmat :: (Element t) => Matrix t -> Matrix t |
117 | fmat m | not (rowOrder m) = m | 121 | fmat m | not (rowOrder m) = m |
118 | fmat m@Matrix { irows = r, icols = c, xdat = d} = | 122 | fmat m@Matrix { irows = r, icols = c, xdat = d} = |
119 | m { xdat = transdata c d r, xRow = 1, xCol = fi r } | 123 | m { xdat = transdata c d r, xRow = 1, xCol = r } |
120 | 124 | ||
121 | 125 | ||
122 | -- C-Haskell matrix adapters | 126 | -- C-Haskell matrix adapters |
@@ -124,17 +128,17 @@ fmat m@Matrix { irows = r, icols = c, xdat = d} = | |||
124 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | 128 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b |
125 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | 129 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) |
126 | where | 130 | where |
127 | r = fromIntegral (rows x) | 131 | r = fi (rows x) |
128 | c = fromIntegral (cols x) | 132 | c = fi (cols x) |
129 | 133 | ||
130 | {-# INLINE amat #-} | 134 | {-# INLINE amat #-} |
131 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | 135 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b |
132 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | 136 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) |
133 | where | 137 | where |
134 | r = fromIntegral (rows x) | 138 | r = fi (rows x) |
135 | c = fromIntegral (cols x) | 139 | c = fi (cols x) |
136 | sr = xRow x | 140 | sr = fi (xRow x) |
137 | sc = xCol x | 141 | sc = fi (xCol x) |
138 | 142 | ||
139 | 143 | ||
140 | instance Storable t => TransArray (Matrix t) | 144 | instance Storable t => TransArray (Matrix t) |
@@ -227,7 +231,7 @@ m@Matrix {irows = r, icols = c} @@> (i,j) | |||
227 | {-# INLINE (@@>) #-} | 231 | {-# INLINE (@@>) #-} |
228 | 232 | ||
229 | -- Unsafe matrix access without range checking | 233 | -- Unsafe matrix access without range checking |
230 | atM' m i j = xdat m `at'` (i * (ti $ xRow m) + j * (ti $ xCol m)) | 234 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) |
231 | {-# INLINE atM' #-} | 235 | {-# INLINE atM' #-} |
232 | 236 | ||
233 | ------------------------------------------------------------------ | 237 | ------------------------------------------------------------------ |
@@ -236,8 +240,8 @@ matrixFromVector o r c v | |||
236 | | r * c == dim v = m | 240 | | r * c == dim v = m |
237 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | 241 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
238 | where | 242 | where |
239 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = fi c, xCol = 1 } | 243 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } |
240 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1 , xCol = fi r } | 244 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } |
241 | 245 | ||
242 | -- allocates memory for a new matrix | 246 | -- allocates memory for a new matrix |
243 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | 247 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) |
@@ -411,6 +415,21 @@ subMatrix (r0,c0) (rt,ct) m | |||
411 | | otherwise = error $ "wrong subMatrix "++ | 415 | | otherwise = error $ "wrong subMatrix "++ |
412 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | 416 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) |
413 | 417 | ||
418 | |||
419 | sliceMatrix :: Element a | ||
420 | => (Int,Int) -- ^ (r0,c0) starting position | ||
421 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
422 | -> Matrix a -- ^ input matrix | ||
423 | -> Matrix a -- ^ result | ||
424 | sliceMatrix (r0,c0) (rt,ct) m | ||
425 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | ||
426 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res | ||
427 | | otherwise = error $ "wrong sliceMatrix "++ | ||
428 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | ||
429 | where | ||
430 | t = r0 * xRow m + c0 * xCol m | ||
431 | res = m { irows = rt, icols = ct, xdat = subVector t (rt*ct) (xdat m) } | ||
432 | |||
414 | -------------------------------------------------------------------------- | 433 | -------------------------------------------------------------------------- |
415 | 434 | ||
416 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 435 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index c98ff0e..92654e4 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -109,12 +109,12 @@ newVector x n = do | |||
109 | 109 | ||
110 | {-# INLINE ioReadM #-} | 110 | {-# INLINE ioReadM #-} |
111 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t | 111 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t |
112 | ioReadM m r c = ioReadV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) | 112 | ioReadM m r c = ioReadV (xdat m) (r * xRow m + c * xCol m) |
113 | 113 | ||
114 | 114 | ||
115 | {-# INLINE ioWriteM #-} | 115 | {-# INLINE ioWriteM #-} |
116 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () | 116 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () |
117 | ioWriteM m r c val = ioWriteV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) val | 117 | ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val |
118 | 118 | ||
119 | 119 | ||
120 | newtype STMatrix s t = STMatrix (Matrix t) | 120 | newtype STMatrix s t = STMatrix (Matrix t) |