diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/Chain.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 9 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 83 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 10 |
4 files changed, 43 insertions, 61 deletions
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs index fa518d1..f87eb02 100644 --- a/packages/base/src/Internal/Chain.hs +++ b/packages/base/src/Internal/Chain.hs | |||
@@ -22,7 +22,7 @@ module Internal.Chain ( | |||
22 | 22 | ||
23 | import Data.Maybe | 23 | import Data.Maybe |
24 | 24 | ||
25 | import Internal.Matrix hiding (order) | 25 | import Internal.Matrix |
26 | import Internal.Numeric | 26 | import Internal.Numeric |
27 | 27 | ||
28 | import qualified Data.Array.IArray as A | 28 | import qualified Data.Array.IArray as A |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 3a9abbb..fc9e3ad 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -1,4 +1,5 @@ | |||
1 | {-# LANGUAGE TypeOperators #-} | 1 | {-# LANGUAGE TypeOperators #-} |
2 | {-# LANGUAGE ViewPatterns #-} | ||
2 | 3 | ||
3 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
4 | -- | | 5 | -- | |
@@ -49,11 +50,11 @@ foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q | |||
49 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok | 50 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok |
50 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok | 51 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok |
51 | 52 | ||
52 | isT Matrix{order = ColumnMajor} = 0 | 53 | isT (rowOrder -> False) = 0 |
53 | isT Matrix{order = RowMajor} = 1 | 54 | isT _ = 1 |
54 | 55 | ||
55 | tt x@Matrix{order = ColumnMajor} = x | 56 | tt x@(rowOrder -> False) = x |
56 | tt x@Matrix{order = RowMajor} = trans x | 57 | tt x = trans x |
57 | 58 | ||
58 | multiplyAux f st a b = unsafePerformIO $ do | 59 | multiplyAux f st a b = unsafePerformIO $ do |
59 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ | 60 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index db0a609..c0d1318 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -3,7 +3,9 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE BangPatterns #-} | 4 | {-# LANGUAGE BangPatterns #-} |
5 | {-# LANGUAGE TypeOperators #-} | 5 | {-# LANGUAGE TypeOperators #-} |
6 | {-# LANGUAGE TypeFamilies #-} | 6 | {-# LANGUAGE TypeFamilies #-} |
7 | {-# LANGUAGE ViewPatterns #-} | ||
8 | |||
7 | 9 | ||
8 | 10 | ||
9 | -- | | 11 | -- | |
@@ -74,10 +76,14 @@ The elements are stored in a continuous memory array. | |||
74 | 76 | ||
75 | -} | 77 | -} |
76 | 78 | ||
77 | data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int | 79 | data Matrix t = Matrix |
78 | , icols :: {-# UNPACK #-} !Int | 80 | { irows :: {-# UNPACK #-} !Int |
79 | , xdat :: {-# UNPACK #-} !(Vector t) | 81 | , icols :: {-# UNPACK #-} !Int |
80 | , order :: !MatrixOrder } | 82 | , xRow :: {-# UNPACK #-} !CInt |
83 | , xCol :: {-# UNPACK #-} !CInt | ||
84 | -- , rowOrder :: {-# UNPACK #-} !Bool | ||
85 | , xdat :: {-# UNPACK #-} !(Vector t) | ||
86 | } | ||
81 | -- RowMajor: preferred by C, fdat may require a transposition | 87 | -- RowMajor: preferred by C, fdat may require a transposition |
82 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | 88 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition |
83 | 89 | ||
@@ -88,49 +94,32 @@ rows = irows | |||
88 | cols :: Matrix t -> Int | 94 | cols :: Matrix t -> Int |
89 | cols = icols | 95 | cols = icols |
90 | 96 | ||
91 | orderOf :: Matrix t -> MatrixOrder | 97 | rowOrder m = xRow m > 1 |
92 | orderOf = order | 98 | {-# INLINE rowOrder #-} |
93 | |||
94 | stepRow :: Matrix t -> CInt | ||
95 | stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c | ||
96 | stepRow _ = 1 | ||
97 | 99 | ||
98 | stepCol :: Matrix t -> CInt | 100 | orderOf :: Matrix t -> MatrixOrder |
99 | stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r | 101 | orderOf m = if rowOrder m then RowMajor else ColumnMajor |
100 | stepCol _ = 1 | ||
101 | 102 | ||
102 | 103 | ||
103 | -- | Matrix transpose. | 104 | -- | Matrix transpose. |
104 | trans :: Matrix t -> Matrix t | 105 | trans :: Matrix t -> Matrix t |
105 | trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o} | 106 | trans m@Matrix { irows = r, icols = c } | rowOrder m = |
107 | m { irows = c, icols = r, xRow = 1, xCol = fi c } | ||
108 | trans m@Matrix { irows = r, icols = c } = | ||
109 | m { irows = c, icols = r, xRow = fi r, xCol = 1 } | ||
106 | 110 | ||
107 | cmat :: (Element t) => Matrix t -> Matrix t | 111 | cmat :: (Element t) => Matrix t -> Matrix t |
108 | cmat m@Matrix{order = RowMajor} = m | 112 | cmat m | rowOrder m = m |
109 | cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor} | 113 | cmat m@Matrix { irows = r, icols = c, xdat = d } = |
114 | m { xdat = transdata r d c, xRow = fi c, xCol = 1 } | ||
110 | 115 | ||
111 | fmat :: (Element t) => Matrix t -> Matrix t | 116 | fmat :: (Element t) => Matrix t -> Matrix t |
112 | fmat m@Matrix{order = ColumnMajor} = m | 117 | fmat m | not (rowOrder m) = m |
113 | fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor} | 118 | fmat m@Matrix { irows = r, icols = c, xdat = d} = |
114 | 119 | m { xdat = transdata c d r, xRow = 1, xCol = fi r } | |
115 | -- C-Haskell matrix adapter | ||
116 | -- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r | ||
117 | |||
118 | mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b | ||
119 | mat a f = | ||
120 | unsafeWith (xdat a) $ \p -> do | ||
121 | let m g = do | ||
122 | g (fi (rows a)) (fi (cols a)) p | ||
123 | f m | ||
124 | |||
125 | omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b | ||
126 | omat a f = | ||
127 | unsafeWith (xdat a) $ \p -> do | ||
128 | let m g = do | ||
129 | g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p | ||
130 | f m | ||
131 | 120 | ||
132 | -------------------------------------------------------------------------------- | ||
133 | 121 | ||
122 | -- C-Haskell matrix adapters | ||
134 | {-# INLINE amatr #-} | 123 | {-# INLINE amatr #-} |
135 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | 124 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b |
136 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | 125 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) |
@@ -144,14 +133,8 @@ amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | |||
144 | where | 133 | where |
145 | r = fromIntegral (rows x) | 134 | r = fromIntegral (rows x) |
146 | c = fromIntegral (cols x) | 135 | c = fromIntegral (cols x) |
147 | sr = stepRow x | 136 | sr = xRow x |
148 | sc = stepCol x | 137 | sc = xCol x |
149 | |||
150 | {-# INLINE arrmat #-} | ||
151 | arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b | ||
152 | arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p))) | ||
153 | where | ||
154 | s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x] | ||
155 | 138 | ||
156 | 139 | ||
157 | instance Storable t => TransArray (Matrix t) | 140 | instance Storable t => TransArray (Matrix t) |
@@ -163,8 +146,6 @@ instance Storable t => TransArray (Matrix t) | |||
163 | {-# INLINE apply #-} | 146 | {-# INLINE apply #-} |
164 | applyRaw = amatr | 147 | applyRaw = amatr |
165 | {-# INLINE applyRaw #-} | 148 | {-# INLINE applyRaw #-} |
166 | applyArray = arrmat | ||
167 | {-# INLINE applyArray #-} | ||
168 | 149 | ||
169 | infixl 1 # | 150 | infixl 1 # |
170 | a # b = apply a b | 151 | a # b = apply a b |
@@ -246,8 +227,7 @@ m@Matrix {irows = r, icols = c} @@> (i,j) | |||
246 | {-# INLINE (@@>) #-} | 227 | {-# INLINE (@@>) #-} |
247 | 228 | ||
248 | -- Unsafe matrix access without range checking | 229 | -- Unsafe matrix access without range checking |
249 | atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j) | 230 | atM' m i j = xdat m `at'` (i * (ti $ xRow m) + j * (ti $ xCol m)) |
250 | atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | ||
251 | {-# INLINE atM' #-} | 231 | {-# INLINE atM' #-} |
252 | 232 | ||
253 | ------------------------------------------------------------------ | 233 | ------------------------------------------------------------------ |
@@ -256,7 +236,8 @@ matrixFromVector o r c v | |||
256 | | r * c == dim v = m | 236 | | r * c == dim v = m |
257 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | 237 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
258 | where | 238 | where |
259 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | 239 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = fi c, xCol = 1 } |
240 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1 , xCol = fi r } | ||
260 | 241 | ||
261 | -- allocates memory for a new matrix | 242 | -- allocates memory for a new matrix |
262 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | 243 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) |
@@ -282,7 +263,7 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | |||
282 | 263 | ||
283 | -- | application of a vector function on the flattened matrix elements | 264 | -- | application of a vector function on the flattened matrix elements |
284 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 265 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
285 | liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) | 266 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d) |
286 | 267 | ||
287 | -- | application of a vector function on the flattened matrices elements | 268 | -- | application of a vector function on the flattened matrices elements |
288 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 269 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index d1defda..c98ff0e 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -109,13 +109,13 @@ newVector x n = do | |||
109 | 109 | ||
110 | {-# INLINE ioReadM #-} | 110 | {-# INLINE ioReadM #-} |
111 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t | 111 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t |
112 | ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c) | 112 | ioReadM m r c = ioReadV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) |
113 | ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r) | 113 | |
114 | 114 | ||
115 | {-# INLINE ioWriteM #-} | 115 | {-# INLINE ioWriteM #-} |
116 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () | 116 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () |
117 | ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val | 117 | ioWriteM m r c val = ioWriteV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) val |
118 | ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val | 118 | |
119 | 119 | ||
120 | newtype STMatrix s t = STMatrix (Matrix t) | 120 | newtype STMatrix s t = STMatrix (Matrix t) |
121 | 121 | ||
@@ -150,7 +150,7 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | |||
150 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) | 150 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
151 | freezeMatrix m = liftSTMatrix id m | 151 | freezeMatrix m = liftSTMatrix id m |
152 | 152 | ||
153 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) | 153 | cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'}) |
154 | 154 | ||
155 | {-# INLINE safeIndexM #-} | 155 | {-# INLINE safeIndexM #-} |
156 | safeIndexM f (STMatrix m) r c | 156 | safeIndexM f (STMatrix m) r c |