summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorReiner Pope <reiner.pope@gmail.com>2012-01-07 11:47:06 +1100
committerReiner Pope <reiner.pope@gmail.com>2012-01-07 11:47:06 +1100
commitfdf8d8778d52cf14aec493ef5ab18d363b900ed7 (patch)
treeca761c585a7a5287eae45a673f10f58931c8353a /lib
parent4029bf2f48c7e0564fe23de8dc74409d1206ca0d (diff)
Make Matrix a product type
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs74
-rw-r--r--lib/Data/Packed/ST.hs11
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs8
3 files changed, 39 insertions, 54 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index a39c0f0..28bebbc 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -18,7 +18,7 @@
18-- #hide 18-- #hide
19 19
20module Data.Packed.Internal.Matrix( 20module Data.Packed.Internal.Matrix(
21 Matrix(..), rows, cols, 21 Matrix(..), rows, cols, cdat, fdat,
22 MatrixOrder(..), orderOf, 22 MatrixOrder(..), orderOf,
23 createMatrix, mat, 23 createMatrix, mat,
24 cmat, fmat, 24 cmat, fmat,
@@ -82,21 +82,23 @@ import System.IO.Unsafe(unsafePerformIO)
82 82
83data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 83data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
84 84
85transOrder RowMajor = ColumnMajor
86transOrder ColumnMajor = RowMajor
85{- | Matrix representation suitable for GSL and LAPACK computations. 87{- | Matrix representation suitable for GSL and LAPACK computations.
86 88
87The elements are stored in a continuous memory array. 89The elements are stored in a continuous memory array.
88 90
89-} 91-}
90data Matrix t = MC { irows :: {-# UNPACK #-} !Int
91 , icols :: {-# UNPACK #-} !Int
92 , cdat :: {-# UNPACK #-} !(Vector t) }
93 92
94 | MF { irows :: {-# UNPACK #-} !Int 93data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
95 , icols :: {-# UNPACK #-} !Int 94 , icols :: {-# UNPACK #-} !Int
96 , fdat :: {-# UNPACK #-} !(Vector t) } 95 , xdat :: {-# UNPACK #-} !(Vector t)
96 , order :: !MatrixOrder }
97-- RowMajor: preferred by C, fdat may require a transposition
98-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
97 99
98-- MC: preferred by C, fdat may require a transposition 100cdat = xdat
99-- MF: preferred by LAPACK, cdat may require a transposition 101fdat = xdat
100 102
101rows :: Matrix t -> Int 103rows :: Matrix t -> Int
102rows = irows 104rows = irows
@@ -104,25 +106,21 @@ rows = irows
104cols :: Matrix t -> Int 106cols :: Matrix t -> Int
105cols = icols 107cols = icols
106 108
107xdat MC {cdat = d } = d
108xdat MF {fdat = d } = d
109
110orderOf :: Matrix t -> MatrixOrder 109orderOf :: Matrix t -> MatrixOrder
111orderOf MF{} = ColumnMajor 110orderOf = order
112orderOf MC{} = RowMajor 111
113 112
114-- | Matrix transpose. 113-- | Matrix transpose.
115trans :: Matrix t -> Matrix t 114trans :: Matrix t -> Matrix t
116trans MC {irows = r, icols = c, cdat = d } = MF {irows = c, icols = r, fdat = d } 115trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
117trans MF {irows = r, icols = c, fdat = d } = MC {irows = c, icols = r, cdat = d }
118 116
119cmat :: (Element t) => Matrix t -> Matrix t 117cmat :: (Element t) => Matrix t -> Matrix t
120cmat m@MC{} = m 118cmat m@Matrix{order = RowMajor} = m
121cmat MF {irows = r, icols = c, fdat = d } = MC {irows = r, icols = c, cdat = transdata r d c} 119cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
122 120
123fmat :: (Element t) => Matrix t -> Matrix t 121fmat :: (Element t) => Matrix t -> Matrix t
124fmat m@MF{} = m 122fmat m@Matrix{order = ColumnMajor} = m
125fmat MC {irows = r, icols = c, cdat = d } = MF {irows = r, icols = c, fdat = transdata c d r} 123fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
126 124
127-- C-Haskell matrix adapter 125-- C-Haskell matrix adapter
128-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r 126-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
@@ -140,7 +138,7 @@ mat a f =
1409 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ 1389 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@
141-} 139-}
142flatten :: Element t => Matrix t -> Vector t 140flatten :: Element t => Matrix t -> Vector t
143flatten = cdat . cmat 141flatten = xdat . cmat
144 142
145type Mt t s = Int -> Int -> Ptr t -> s 143type Mt t s = Int -> Int -> Ptr t -> s
146-- not yet admitted by my haddock version 144-- not yet admitted by my haddock version
@@ -186,32 +184,21 @@ infixl 9 @@>
186-- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" 184-- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
187-- | otherwise = cdat m `at` (i*c+j) 185-- | otherwise = cdat m `at` (i*c+j)
188 186
189MC {irows = r, icols = c, cdat = v} @@> (i,j) 187m@Matrix {irows = r, icols = c, xdat = v, order = o} @@> (i,j)
190 | safe = if i<0 || i>=r || j<0 || j>=c
191 then error "matrix indexing out of range"
192 else v `at` (i*c+j)
193 | otherwise = v `at` (i*c+j)
194
195MF {irows = r, icols = c, fdat = v} @@> (i,j)
196 | safe = if i<0 || i>=r || j<0 || j>=c 188 | safe = if i<0 || i>=r || j<0 || j>=c
197 then error "matrix indexing out of range" 189 then error "matrix indexing out of range"
198 else v `at` (j*r+i) 190 else atM' m i j
199 | otherwise = v `at` (j*r+i) 191 | otherwise = atM' m i j
200{-# INLINE (@@>) #-} 192{-# INLINE (@@>) #-}
201 193
202-- Unsafe matrix access without range checking 194-- Unsafe matrix access without range checking
203atM' MC {icols = c, cdat = v} i j = v `at'` (i*c+j) 195atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
204atM' MF {irows = r, fdat = v} i j = v `at'` (j*r+i) 196atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
205{-# INLINE atM' #-} 197{-# INLINE atM' #-}
206 198
207------------------------------------------------------------------ 199------------------------------------------------------------------
208 200
209matrixFromVector RowMajor c v = MC { irows = r, icols = c, cdat = v } 201matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o }
210 where (d,m) = dim v `quotRem` c
211 r | m==0 = d
212 | otherwise = error "matrixFromVector"
213
214matrixFromVector ColumnMajor c v = MF { irows = r, icols = c, fdat = v }
215 where (d,m) = dim v `quotRem` c 202 where (d,m) = dim v `quotRem` c
216 r | m==0 = d 203 r | m==0 = d
217 | otherwise = error "matrixFromVector" 204 | otherwise = error "matrixFromVector"
@@ -239,16 +226,15 @@ singleton x = reshape 1 (fromList [x])
239 226
240-- | application of a vector function on the flattened matrix elements 227-- | application of a vector function on the flattened matrix elements
241liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 228liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
242liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) 229liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d)
243liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d)
244 230
245-- | application of a vector function on the flattened matrices elements 231-- | application of a vector function on the flattened matrices elements
246liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 232liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
247liftMatrix2 f m1 m2 233liftMatrix2 f m1 m2
248 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" 234 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
249 | otherwise = case m1 of 235 | otherwise = case orderOf m1 of
250 MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (flatten m2)) 236 RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2))
251 MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) ((fdat.fmat) m2)) 237 ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2))
252 238
253 239
254compat :: Matrix a -> Matrix b -> Bool 240compat :: Matrix a -> Matrix b -> Bool
@@ -427,7 +413,7 @@ subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
427 go (rt-1) (ct-1) 413 go (rt-1) (ct-1)
428 return w 414 return w
429 415
430subMatrix' (r0,c0) (rt,ct) (MC _r c v) = MC rt ct $ subMatrix'' (r0,c0) (rt,ct) c v 416subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
431subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m) 417subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
432 418
433-------------------------------------------------------------------------- 419--------------------------------------------------------------------------
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs
index 00f5e78..c96a209 100644
--- a/lib/Data/Packed/ST.hs
+++ b/lib/Data/Packed/ST.hs
@@ -113,13 +113,13 @@ newVector x n = do
113 113
114{-# INLINE ioReadM #-} 114{-# INLINE ioReadM #-}
115ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t 115ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
116ioReadM (MC _ nc cv) r c = ioReadV cv (r*nc+c) 116ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c)
117ioReadM (MF nr _ fv) r c = ioReadV fv (c*nr+r) 117ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r)
118 118
119{-# INLINE ioWriteM #-} 119{-# INLINE ioWriteM #-}
120ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () 120ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
121ioWriteM (MC _ nc cv) r c val = ioWriteV cv (r*nc+c) val 121ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val
122ioWriteM (MF nr _ fv) r c val = ioWriteV fv (c*nr+r) val 122ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val
123 123
124newtype STMatrix s t = STMatrix (Matrix t) 124newtype STMatrix s t = STMatrix (Matrix t)
125 125
@@ -153,8 +153,7 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
153freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 153freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
154freezeMatrix m = liftSTMatrix id m 154freezeMatrix m = liftSTMatrix id m
155 155
156cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c 156cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
157cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c
158 157
159{-# INLINE safeIndexM #-} 158{-# INLINE safeIndexM #-}
160safeIndexM f (STMatrix m) r c 159safeIndexM f (STMatrix m) r c
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index d1aa564..349650c 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -58,11 +58,11 @@ foreign import ccall "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM
58foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM 58foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM
59foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM 59foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM
60 60
61isT MF{} = 0 61isT Matrix{order = ColumnMajor} = 0
62isT MC{} = 1 62isT Matrix{order = RowMajor} = 1
63 63
64tt x@MF{} = x 64tt x@Matrix{order = RowMajor} = x
65tt x@MC{} = trans x 65tt x@Matrix{order = ColumnMajor} = trans x
66 66
67multiplyAux f st a b = unsafePerformIO $ do 67multiplyAux f st a b = unsafePerformIO $ do
68 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ 68 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++