diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 74 | ||||
-rw-r--r-- | lib/Data/Packed/ST.hs | 11 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK.hs | 8 |
3 files changed, 39 insertions, 54 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index a39c0f0..28bebbc 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -18,7 +18,7 @@ | |||
18 | -- #hide | 18 | -- #hide |
19 | 19 | ||
20 | module Data.Packed.Internal.Matrix( | 20 | module Data.Packed.Internal.Matrix( |
21 | Matrix(..), rows, cols, | 21 | Matrix(..), rows, cols, cdat, fdat, |
22 | MatrixOrder(..), orderOf, | 22 | MatrixOrder(..), orderOf, |
23 | createMatrix, mat, | 23 | createMatrix, mat, |
24 | cmat, fmat, | 24 | cmat, fmat, |
@@ -82,21 +82,23 @@ import System.IO.Unsafe(unsafePerformIO) | |||
82 | 82 | ||
83 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 83 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
84 | 84 | ||
85 | transOrder RowMajor = ColumnMajor | ||
86 | transOrder ColumnMajor = RowMajor | ||
85 | {- | Matrix representation suitable for GSL and LAPACK computations. | 87 | {- | Matrix representation suitable for GSL and LAPACK computations. |
86 | 88 | ||
87 | The elements are stored in a continuous memory array. | 89 | The elements are stored in a continuous memory array. |
88 | 90 | ||
89 | -} | 91 | -} |
90 | data Matrix t = MC { irows :: {-# UNPACK #-} !Int | ||
91 | , icols :: {-# UNPACK #-} !Int | ||
92 | , cdat :: {-# UNPACK #-} !(Vector t) } | ||
93 | 92 | ||
94 | | MF { irows :: {-# UNPACK #-} !Int | 93 | data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int |
95 | , icols :: {-# UNPACK #-} !Int | 94 | , icols :: {-# UNPACK #-} !Int |
96 | , fdat :: {-# UNPACK #-} !(Vector t) } | 95 | , xdat :: {-# UNPACK #-} !(Vector t) |
96 | , order :: !MatrixOrder } | ||
97 | -- RowMajor: preferred by C, fdat may require a transposition | ||
98 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | ||
97 | 99 | ||
98 | -- MC: preferred by C, fdat may require a transposition | 100 | cdat = xdat |
99 | -- MF: preferred by LAPACK, cdat may require a transposition | 101 | fdat = xdat |
100 | 102 | ||
101 | rows :: Matrix t -> Int | 103 | rows :: Matrix t -> Int |
102 | rows = irows | 104 | rows = irows |
@@ -104,25 +106,21 @@ rows = irows | |||
104 | cols :: Matrix t -> Int | 106 | cols :: Matrix t -> Int |
105 | cols = icols | 107 | cols = icols |
106 | 108 | ||
107 | xdat MC {cdat = d } = d | ||
108 | xdat MF {fdat = d } = d | ||
109 | |||
110 | orderOf :: Matrix t -> MatrixOrder | 109 | orderOf :: Matrix t -> MatrixOrder |
111 | orderOf MF{} = ColumnMajor | 110 | orderOf = order |
112 | orderOf MC{} = RowMajor | 111 | |
113 | 112 | ||
114 | -- | Matrix transpose. | 113 | -- | Matrix transpose. |
115 | trans :: Matrix t -> Matrix t | 114 | trans :: Matrix t -> Matrix t |
116 | trans MC {irows = r, icols = c, cdat = d } = MF {irows = c, icols = r, fdat = d } | 115 | trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o} |
117 | trans MF {irows = r, icols = c, fdat = d } = MC {irows = c, icols = r, cdat = d } | ||
118 | 116 | ||
119 | cmat :: (Element t) => Matrix t -> Matrix t | 117 | cmat :: (Element t) => Matrix t -> Matrix t |
120 | cmat m@MC{} = m | 118 | cmat m@Matrix{order = RowMajor} = m |
121 | cmat MF {irows = r, icols = c, fdat = d } = MC {irows = r, icols = c, cdat = transdata r d c} | 119 | cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor} |
122 | 120 | ||
123 | fmat :: (Element t) => Matrix t -> Matrix t | 121 | fmat :: (Element t) => Matrix t -> Matrix t |
124 | fmat m@MF{} = m | 122 | fmat m@Matrix{order = ColumnMajor} = m |
125 | fmat MC {irows = r, icols = c, cdat = d } = MF {irows = r, icols = c, fdat = transdata c d r} | 123 | fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor} |
126 | 124 | ||
127 | -- C-Haskell matrix adapter | 125 | -- C-Haskell matrix adapter |
128 | -- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r | 126 | -- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r |
@@ -140,7 +138,7 @@ mat a f = | |||
140 | 9 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ | 138 | 9 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ |
141 | -} | 139 | -} |
142 | flatten :: Element t => Matrix t -> Vector t | 140 | flatten :: Element t => Matrix t -> Vector t |
143 | flatten = cdat . cmat | 141 | flatten = xdat . cmat |
144 | 142 | ||
145 | type Mt t s = Int -> Int -> Ptr t -> s | 143 | type Mt t s = Int -> Int -> Ptr t -> s |
146 | -- not yet admitted by my haddock version | 144 | -- not yet admitted by my haddock version |
@@ -186,32 +184,21 @@ infixl 9 @@> | |||
186 | -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | 184 | -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" |
187 | -- | otherwise = cdat m `at` (i*c+j) | 185 | -- | otherwise = cdat m `at` (i*c+j) |
188 | 186 | ||
189 | MC {irows = r, icols = c, cdat = v} @@> (i,j) | 187 | m@Matrix {irows = r, icols = c, xdat = v, order = o} @@> (i,j) |
190 | | safe = if i<0 || i>=r || j<0 || j>=c | ||
191 | then error "matrix indexing out of range" | ||
192 | else v `at` (i*c+j) | ||
193 | | otherwise = v `at` (i*c+j) | ||
194 | |||
195 | MF {irows = r, icols = c, fdat = v} @@> (i,j) | ||
196 | | safe = if i<0 || i>=r || j<0 || j>=c | 188 | | safe = if i<0 || i>=r || j<0 || j>=c |
197 | then error "matrix indexing out of range" | 189 | then error "matrix indexing out of range" |
198 | else v `at` (j*r+i) | 190 | else atM' m i j |
199 | | otherwise = v `at` (j*r+i) | 191 | | otherwise = atM' m i j |
200 | {-# INLINE (@@>) #-} | 192 | {-# INLINE (@@>) #-} |
201 | 193 | ||
202 | -- Unsafe matrix access without range checking | 194 | -- Unsafe matrix access without range checking |
203 | atM' MC {icols = c, cdat = v} i j = v `at'` (i*c+j) | 195 | atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j) |
204 | atM' MF {irows = r, fdat = v} i j = v `at'` (j*r+i) | 196 | atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) |
205 | {-# INLINE atM' #-} | 197 | {-# INLINE atM' #-} |
206 | 198 | ||
207 | ------------------------------------------------------------------ | 199 | ------------------------------------------------------------------ |
208 | 200 | ||
209 | matrixFromVector RowMajor c v = MC { irows = r, icols = c, cdat = v } | 201 | matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o } |
210 | where (d,m) = dim v `quotRem` c | ||
211 | r | m==0 = d | ||
212 | | otherwise = error "matrixFromVector" | ||
213 | |||
214 | matrixFromVector ColumnMajor c v = MF { irows = r, icols = c, fdat = v } | ||
215 | where (d,m) = dim v `quotRem` c | 202 | where (d,m) = dim v `quotRem` c |
216 | r | m==0 = d | 203 | r | m==0 = d |
217 | | otherwise = error "matrixFromVector" | 204 | | otherwise = error "matrixFromVector" |
@@ -239,16 +226,15 @@ singleton x = reshape 1 (fromList [x]) | |||
239 | 226 | ||
240 | -- | application of a vector function on the flattened matrix elements | 227 | -- | application of a vector function on the flattened matrix elements |
241 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 228 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
242 | liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) | 229 | liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d) |
243 | liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) | ||
244 | 230 | ||
245 | -- | application of a vector function on the flattened matrices elements | 231 | -- | application of a vector function on the flattened matrices elements |
246 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 232 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
247 | liftMatrix2 f m1 m2 | 233 | liftMatrix2 f m1 m2 |
248 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | 234 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" |
249 | | otherwise = case m1 of | 235 | | otherwise = case orderOf m1 of |
250 | MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (flatten m2)) | 236 | RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2)) |
251 | MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) ((fdat.fmat) m2)) | 237 | ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2)) |
252 | 238 | ||
253 | 239 | ||
254 | compat :: Matrix a -> Matrix b -> Bool | 240 | compat :: Matrix a -> Matrix b -> Bool |
@@ -427,7 +413,7 @@ subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do | |||
427 | go (rt-1) (ct-1) | 413 | go (rt-1) (ct-1) |
428 | return w | 414 | return w |
429 | 415 | ||
430 | subMatrix' (r0,c0) (rt,ct) (MC _r c v) = MC rt ct $ subMatrix'' (r0,c0) (rt,ct) c v | 416 | subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor |
431 | subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m) | 417 | subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m) |
432 | 418 | ||
433 | -------------------------------------------------------------------------- | 419 | -------------------------------------------------------------------------- |
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs index 00f5e78..c96a209 100644 --- a/lib/Data/Packed/ST.hs +++ b/lib/Data/Packed/ST.hs | |||
@@ -113,13 +113,13 @@ newVector x n = do | |||
113 | 113 | ||
114 | {-# INLINE ioReadM #-} | 114 | {-# INLINE ioReadM #-} |
115 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t | 115 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t |
116 | ioReadM (MC _ nc cv) r c = ioReadV cv (r*nc+c) | 116 | ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c) |
117 | ioReadM (MF nr _ fv) r c = ioReadV fv (c*nr+r) | 117 | ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r) |
118 | 118 | ||
119 | {-# INLINE ioWriteM #-} | 119 | {-# INLINE ioWriteM #-} |
120 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () | 120 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () |
121 | ioWriteM (MC _ nc cv) r c val = ioWriteV cv (r*nc+c) val | 121 | ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val |
122 | ioWriteM (MF nr _ fv) r c val = ioWriteV fv (c*nr+r) val | 122 | ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val |
123 | 123 | ||
124 | newtype STMatrix s t = STMatrix (Matrix t) | 124 | newtype STMatrix s t = STMatrix (Matrix t) |
125 | 125 | ||
@@ -153,8 +153,7 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | |||
153 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 153 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) |
154 | freezeMatrix m = liftSTMatrix id m | 154 | freezeMatrix m = liftSTMatrix id m |
155 | 155 | ||
156 | cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c | 156 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) |
157 | cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c | ||
158 | 157 | ||
159 | {-# INLINE safeIndexM #-} | 158 | {-# INLINE safeIndexM #-} |
160 | safeIndexM f (STMatrix m) r c | 159 | safeIndexM f (STMatrix m) r c |
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index d1aa564..349650c 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -58,11 +58,11 @@ foreign import ccall "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM | |||
58 | foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM | 58 | foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM |
59 | foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM | 59 | foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM |
60 | 60 | ||
61 | isT MF{} = 0 | 61 | isT Matrix{order = ColumnMajor} = 0 |
62 | isT MC{} = 1 | 62 | isT Matrix{order = RowMajor} = 1 |
63 | 63 | ||
64 | tt x@MF{} = x | 64 | tt x@Matrix{order = RowMajor} = x |
65 | tt x@MC{} = trans x | 65 | tt x@Matrix{order = ColumnMajor} = trans x |
66 | 66 | ||
67 | multiplyAux f st a b = unsafePerformIO $ do | 67 | multiplyAux f st a b = unsafePerformIO $ do |
68 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ | 68 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ |