diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 83 |
1 files changed, 32 insertions, 51 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index db0a609..c0d1318 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -3,7 +3,9 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE BangPatterns #-} | 4 | {-# LANGUAGE BangPatterns #-} |
5 | {-# LANGUAGE TypeOperators #-} | 5 | {-# LANGUAGE TypeOperators #-} |
6 | {-# LANGUAGE TypeFamilies #-} | 6 | {-# LANGUAGE TypeFamilies #-} |
7 | {-# LANGUAGE ViewPatterns #-} | ||
8 | |||
7 | 9 | ||
8 | 10 | ||
9 | -- | | 11 | -- | |
@@ -74,10 +76,14 @@ The elements are stored in a continuous memory array. | |||
74 | 76 | ||
75 | -} | 77 | -} |
76 | 78 | ||
77 | data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int | 79 | data Matrix t = Matrix |
78 | , icols :: {-# UNPACK #-} !Int | 80 | { irows :: {-# UNPACK #-} !Int |
79 | , xdat :: {-# UNPACK #-} !(Vector t) | 81 | , icols :: {-# UNPACK #-} !Int |
80 | , order :: !MatrixOrder } | 82 | , xRow :: {-# UNPACK #-} !CInt |
83 | , xCol :: {-# UNPACK #-} !CInt | ||
84 | -- , rowOrder :: {-# UNPACK #-} !Bool | ||
85 | , xdat :: {-# UNPACK #-} !(Vector t) | ||
86 | } | ||
81 | -- RowMajor: preferred by C, fdat may require a transposition | 87 | -- RowMajor: preferred by C, fdat may require a transposition |
82 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | 88 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition |
83 | 89 | ||
@@ -88,49 +94,32 @@ rows = irows | |||
88 | cols :: Matrix t -> Int | 94 | cols :: Matrix t -> Int |
89 | cols = icols | 95 | cols = icols |
90 | 96 | ||
91 | orderOf :: Matrix t -> MatrixOrder | 97 | rowOrder m = xRow m > 1 |
92 | orderOf = order | 98 | {-# INLINE rowOrder #-} |
93 | |||
94 | stepRow :: Matrix t -> CInt | ||
95 | stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c | ||
96 | stepRow _ = 1 | ||
97 | 99 | ||
98 | stepCol :: Matrix t -> CInt | 100 | orderOf :: Matrix t -> MatrixOrder |
99 | stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r | 101 | orderOf m = if rowOrder m then RowMajor else ColumnMajor |
100 | stepCol _ = 1 | ||
101 | 102 | ||
102 | 103 | ||
103 | -- | Matrix transpose. | 104 | -- | Matrix transpose. |
104 | trans :: Matrix t -> Matrix t | 105 | trans :: Matrix t -> Matrix t |
105 | trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o} | 106 | trans m@Matrix { irows = r, icols = c } | rowOrder m = |
107 | m { irows = c, icols = r, xRow = 1, xCol = fi c } | ||
108 | trans m@Matrix { irows = r, icols = c } = | ||
109 | m { irows = c, icols = r, xRow = fi r, xCol = 1 } | ||
106 | 110 | ||
107 | cmat :: (Element t) => Matrix t -> Matrix t | 111 | cmat :: (Element t) => Matrix t -> Matrix t |
108 | cmat m@Matrix{order = RowMajor} = m | 112 | cmat m | rowOrder m = m |
109 | cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor} | 113 | cmat m@Matrix { irows = r, icols = c, xdat = d } = |
114 | m { xdat = transdata r d c, xRow = fi c, xCol = 1 } | ||
110 | 115 | ||
111 | fmat :: (Element t) => Matrix t -> Matrix t | 116 | fmat :: (Element t) => Matrix t -> Matrix t |
112 | fmat m@Matrix{order = ColumnMajor} = m | 117 | fmat m | not (rowOrder m) = m |
113 | fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor} | 118 | fmat m@Matrix { irows = r, icols = c, xdat = d} = |
114 | 119 | m { xdat = transdata c d r, xRow = 1, xCol = fi r } | |
115 | -- C-Haskell matrix adapter | ||
116 | -- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r | ||
117 | |||
118 | mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b | ||
119 | mat a f = | ||
120 | unsafeWith (xdat a) $ \p -> do | ||
121 | let m g = do | ||
122 | g (fi (rows a)) (fi (cols a)) p | ||
123 | f m | ||
124 | |||
125 | omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b | ||
126 | omat a f = | ||
127 | unsafeWith (xdat a) $ \p -> do | ||
128 | let m g = do | ||
129 | g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p | ||
130 | f m | ||
131 | 120 | ||
132 | -------------------------------------------------------------------------------- | ||
133 | 121 | ||
122 | -- C-Haskell matrix adapters | ||
134 | {-# INLINE amatr #-} | 123 | {-# INLINE amatr #-} |
135 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | 124 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b |
136 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | 125 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) |
@@ -144,14 +133,8 @@ amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | |||
144 | where | 133 | where |
145 | r = fromIntegral (rows x) | 134 | r = fromIntegral (rows x) |
146 | c = fromIntegral (cols x) | 135 | c = fromIntegral (cols x) |
147 | sr = stepRow x | 136 | sr = xRow x |
148 | sc = stepCol x | 137 | sc = xCol x |
149 | |||
150 | {-# INLINE arrmat #-} | ||
151 | arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b | ||
152 | arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p))) | ||
153 | where | ||
154 | s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x] | ||
155 | 138 | ||
156 | 139 | ||
157 | instance Storable t => TransArray (Matrix t) | 140 | instance Storable t => TransArray (Matrix t) |
@@ -163,8 +146,6 @@ instance Storable t => TransArray (Matrix t) | |||
163 | {-# INLINE apply #-} | 146 | {-# INLINE apply #-} |
164 | applyRaw = amatr | 147 | applyRaw = amatr |
165 | {-# INLINE applyRaw #-} | 148 | {-# INLINE applyRaw #-} |
166 | applyArray = arrmat | ||
167 | {-# INLINE applyArray #-} | ||
168 | 149 | ||
169 | infixl 1 # | 150 | infixl 1 # |
170 | a # b = apply a b | 151 | a # b = apply a b |
@@ -246,8 +227,7 @@ m@Matrix {irows = r, icols = c} @@> (i,j) | |||
246 | {-# INLINE (@@>) #-} | 227 | {-# INLINE (@@>) #-} |
247 | 228 | ||
248 | -- Unsafe matrix access without range checking | 229 | -- Unsafe matrix access without range checking |
249 | atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j) | 230 | atM' m i j = xdat m `at'` (i * (ti $ xRow m) + j * (ti $ xCol m)) |
250 | atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | ||
251 | {-# INLINE atM' #-} | 231 | {-# INLINE atM' #-} |
252 | 232 | ||
253 | ------------------------------------------------------------------ | 233 | ------------------------------------------------------------------ |
@@ -256,7 +236,8 @@ matrixFromVector o r c v | |||
256 | | r * c == dim v = m | 236 | | r * c == dim v = m |
257 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | 237 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
258 | where | 238 | where |
259 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | 239 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = fi c, xCol = 1 } |
240 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1 , xCol = fi r } | ||
260 | 241 | ||
261 | -- allocates memory for a new matrix | 242 | -- allocates memory for a new matrix |
262 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | 243 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) |
@@ -282,7 +263,7 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | |||
282 | 263 | ||
283 | -- | application of a vector function on the flattened matrix elements | 264 | -- | application of a vector function on the flattened matrix elements |
284 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 265 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
285 | liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) | 266 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d) |
286 | 267 | ||
287 | -- | application of a vector function on the flattened matrices elements | 268 | -- | application of a vector function on the flattened matrices elements |
288 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 269 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |