summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/src/Internal/Chain.hs2
-rw-r--r--packages/base/src/Internal/LAPACK.hs9
-rw-r--r--packages/base/src/Internal/Matrix.hs83
-rw-r--r--packages/base/src/Internal/ST.hs10
4 files changed, 43 insertions, 61 deletions
diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs
index fa518d1..f87eb02 100644
--- a/packages/base/src/Internal/Chain.hs
+++ b/packages/base/src/Internal/Chain.hs
@@ -22,7 +22,7 @@ module Internal.Chain (
22 22
23import Data.Maybe 23import Data.Maybe
24 24
25import Internal.Matrix hiding (order) 25import Internal.Matrix
26import Internal.Numeric 26import Internal.Numeric
27 27
28import qualified Data.Array.IArray as A 28import qualified Data.Array.IArray as A
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 3a9abbb..fc9e3ad 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -1,4 +1,5 @@
1{-# LANGUAGE TypeOperators #-} 1{-# LANGUAGE TypeOperators #-}
2{-# LANGUAGE ViewPatterns #-}
2 3
3----------------------------------------------------------------------------- 4-----------------------------------------------------------------------------
4-- | 5-- |
@@ -49,11 +50,11 @@ foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q
49foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok 50foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok
50foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok 51foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok
51 52
52isT Matrix{order = ColumnMajor} = 0 53isT (rowOrder -> False) = 0
53isT Matrix{order = RowMajor} = 1 54isT _ = 1
54 55
55tt x@Matrix{order = ColumnMajor} = x 56tt x@(rowOrder -> False) = x
56tt x@Matrix{order = RowMajor} = trans x 57tt x = trans x
57 58
58multiplyAux f st a b = unsafePerformIO $ do 59multiplyAux f st a b = unsafePerformIO $ do
59 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ 60 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index db0a609..c0d1318 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -3,7 +3,9 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-} 4{-# LANGUAGE BangPatterns #-}
5{-# LANGUAGE TypeOperators #-} 5{-# LANGUAGE TypeOperators #-}
6{-# LANGUAGE TypeFamilies #-} 6{-# LANGUAGE TypeFamilies #-}
7{-# LANGUAGE ViewPatterns #-}
8
7 9
8 10
9-- | 11-- |
@@ -74,10 +76,14 @@ The elements are stored in a continuous memory array.
74 76
75-} 77-}
76 78
77data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int 79data Matrix t = Matrix
78 , icols :: {-# UNPACK #-} !Int 80 { irows :: {-# UNPACK #-} !Int
79 , xdat :: {-# UNPACK #-} !(Vector t) 81 , icols :: {-# UNPACK #-} !Int
80 , order :: !MatrixOrder } 82 , xRow :: {-# UNPACK #-} !CInt
83 , xCol :: {-# UNPACK #-} !CInt
84-- , rowOrder :: {-# UNPACK #-} !Bool
85 , xdat :: {-# UNPACK #-} !(Vector t)
86 }
81-- RowMajor: preferred by C, fdat may require a transposition 87-- RowMajor: preferred by C, fdat may require a transposition
82-- ColumnMajor: preferred by LAPACK, cdat may require a transposition 88-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
83 89
@@ -88,49 +94,32 @@ rows = irows
88cols :: Matrix t -> Int 94cols :: Matrix t -> Int
89cols = icols 95cols = icols
90 96
91orderOf :: Matrix t -> MatrixOrder 97rowOrder m = xRow m > 1
92orderOf = order 98{-# INLINE rowOrder #-}
93
94stepRow :: Matrix t -> CInt
95stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c
96stepRow _ = 1
97 99
98stepCol :: Matrix t -> CInt 100orderOf :: Matrix t -> MatrixOrder
99stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r 101orderOf m = if rowOrder m then RowMajor else ColumnMajor
100stepCol _ = 1
101 102
102 103
103-- | Matrix transpose. 104-- | Matrix transpose.
104trans :: Matrix t -> Matrix t 105trans :: Matrix t -> Matrix t
105trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o} 106trans m@Matrix { irows = r, icols = c } | rowOrder m =
107 m { irows = c, icols = r, xRow = 1, xCol = fi c }
108trans m@Matrix { irows = r, icols = c } =
109 m { irows = c, icols = r, xRow = fi r, xCol = 1 }
106 110
107cmat :: (Element t) => Matrix t -> Matrix t 111cmat :: (Element t) => Matrix t -> Matrix t
108cmat m@Matrix{order = RowMajor} = m 112cmat m | rowOrder m = m
109cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor} 113cmat m@Matrix { irows = r, icols = c, xdat = d } =
114 m { xdat = transdata r d c, xRow = fi c, xCol = 1 }
110 115
111fmat :: (Element t) => Matrix t -> Matrix t 116fmat :: (Element t) => Matrix t -> Matrix t
112fmat m@Matrix{order = ColumnMajor} = m 117fmat m | not (rowOrder m) = m
113fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor} 118fmat m@Matrix { irows = r, icols = c, xdat = d} =
114 119 m { xdat = transdata c d r, xRow = 1, xCol = fi r }
115-- C-Haskell matrix adapter
116-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
117
118mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
119mat a f =
120 unsafeWith (xdat a) $ \p -> do
121 let m g = do
122 g (fi (rows a)) (fi (cols a)) p
123 f m
124
125omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
126omat a f =
127 unsafeWith (xdat a) $ \p -> do
128 let m g = do
129 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p
130 f m
131 120
132--------------------------------------------------------------------------------
133 121
122-- C-Haskell matrix adapters
134{-# INLINE amatr #-} 123{-# INLINE amatr #-}
135amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b 124amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
136amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) 125amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c))
@@ -144,14 +133,8 @@ amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc))
144 where 133 where
145 r = fromIntegral (rows x) 134 r = fromIntegral (rows x)
146 c = fromIntegral (cols x) 135 c = fromIntegral (cols x)
147 sr = stepRow x 136 sr = xRow x
148 sc = stepCol x 137 sc = xCol x
149
150{-# INLINE arrmat #-}
151arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b
152arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p)))
153 where
154 s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x]
155 138
156 139
157instance Storable t => TransArray (Matrix t) 140instance Storable t => TransArray (Matrix t)
@@ -163,8 +146,6 @@ instance Storable t => TransArray (Matrix t)
163 {-# INLINE apply #-} 146 {-# INLINE apply #-}
164 applyRaw = amatr 147 applyRaw = amatr
165 {-# INLINE applyRaw #-} 148 {-# INLINE applyRaw #-}
166 applyArray = arrmat
167 {-# INLINE applyArray #-}
168 149
169infixl 1 # 150infixl 1 #
170a # b = apply a b 151a # b = apply a b
@@ -246,8 +227,7 @@ m@Matrix {irows = r, icols = c} @@> (i,j)
246{-# INLINE (@@>) #-} 227{-# INLINE (@@>) #-}
247 228
248-- Unsafe matrix access without range checking 229-- Unsafe matrix access without range checking
249atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j) 230atM' m i j = xdat m `at'` (i * (ti $ xRow m) + j * (ti $ xCol m))
250atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
251{-# INLINE atM' #-} 231{-# INLINE atM' #-}
252 232
253------------------------------------------------------------------ 233------------------------------------------------------------------
@@ -256,7 +236,8 @@ matrixFromVector o r c v
256 | r * c == dim v = m 236 | r * c == dim v = m
257 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m 237 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
258 where 238 where
259 m = Matrix { irows = r, icols = c, xdat = v, order = o } 239 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = fi c, xCol = 1 }
240 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1 , xCol = fi r }
260 241
261-- allocates memory for a new matrix 242-- allocates memory for a new matrix
262createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) 243createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
@@ -282,7 +263,7 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
282 263
283-- | application of a vector function on the flattened matrix elements 264-- | application of a vector function on the flattened matrix elements
284liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 265liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
285liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) 266liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d)
286 267
287-- | application of a vector function on the flattened matrices elements 268-- | application of a vector function on the flattened matrices elements
288liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 269liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index d1defda..c98ff0e 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -109,13 +109,13 @@ newVector x n = do
109 109
110{-# INLINE ioReadM #-} 110{-# INLINE ioReadM #-}
111ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t 111ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
112ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c) 112ioReadM m r c = ioReadV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m))
113ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r) 113
114 114
115{-# INLINE ioWriteM #-} 115{-# INLINE ioWriteM #-}
116ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () 116ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
117ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val 117ioWriteM m r c val = ioWriteV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) val
118ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val 118
119 119
120newtype STMatrix s t = STMatrix (Matrix t) 120newtype STMatrix s t = STMatrix (Matrix t)
121 121
@@ -150,7 +150,7 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) 150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
151freezeMatrix m = liftSTMatrix id m 151freezeMatrix m = liftSTMatrix id m
152 152
153cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) 153cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'})
154 154
155{-# INLINE safeIndexM #-} 155{-# INLINE safeIndexM #-}
156safeIndexM f (STMatrix m) r c 156safeIndexM f (STMatrix m) r c