summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs422
1 files changed, 422 insertions, 0 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
new file mode 100644
index 0000000..9b831cc
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -0,0 +1,422 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-}
5
6-- |
7-- Module : Data.Packed.Internal.Matrix
8-- Copyright : (c) Alberto Ruiz 2007
9-- License : BSD3
10-- Maintainer : Alberto Ruiz
11-- Stability : provisional
12--
13-- Internal matrix representation
14--
15
16module Data.Packed.Internal.Matrix(
17 Matrix(..), rows, cols, cdat, fdat,
18 MatrixOrder(..), orderOf,
19 createMatrix, mat,
20 cmat, fmat,
21 toLists, flatten, reshape,
22 Element(..),
23 trans,
24 fromRows, toRows, fromColumns, toColumns,
25 matrixFromVector,
26 subMatrix,
27 liftMatrix, liftMatrix2,
28 (@@>), atM',
29 singleton,
30 emptyM,
31 size, shSize, conformVs, conformMs, conformVTo, conformMTo
32) where
33
34import Data.Packed.Internal.Common
35import Data.Packed.Internal.Signatures
36import Data.Packed.Internal.Vector
37
38import Foreign.Marshal.Alloc(alloca, free)
39import Foreign.Marshal.Array(newArray)
40import Foreign.Ptr(Ptr, castPtr)
41import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
42import Data.Complex(Complex)
43import Foreign.C.Types
44import System.IO.Unsafe(unsafePerformIO)
45import Control.DeepSeq
46
47-----------------------------------------------------------------
48
49{- Design considerations for the Matrix Type
50 -----------------------------------------
51
52- we must easily handle both row major and column major order,
53 for bindings to LAPACK and GSL/C
54
55- we'd like to simplify redundant matrix transposes:
56 - Some of them arise from the order requirements of some functions
57 - some functions (matrix product) admit transposed arguments
58
59- maybe we don't really need this kind of simplification:
60 - more complex code
61 - some computational overhead
62 - only appreciable gain in code with a lot of redundant transpositions
63 and cheap matrix computations
64
65- we could carry both the matrix and its (lazily computed) transpose.
66 This may save some transpositions, but it is necessary to keep track of the
67 data which is actually computed to be used by functions like the matrix product
68 which admit both orders.
69
70- but if we need the transposed data and it is not in the structure, we must make
71 sure that we touch the same foreignptr that is used in the computation.
72
73- a reasonable solution is using two constructors for a matrix. Transposition just
74 "flips" the constructor. Actual data transposition is not done if followed by a
75 matrix product or another transpose.
76
77-}
78
79data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
80
81transOrder RowMajor = ColumnMajor
82transOrder ColumnMajor = RowMajor
83{- | Matrix representation suitable for GSL and LAPACK computations.
84
85The elements are stored in a continuous memory array.
86
87-}
88
89data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
90 , icols :: {-# UNPACK #-} !Int
91 , xdat :: {-# UNPACK #-} !(Vector t)
92 , order :: !MatrixOrder }
93-- RowMajor: preferred by C, fdat may require a transposition
94-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
95
96cdat = xdat
97fdat = xdat
98
99rows :: Matrix t -> Int
100rows = irows
101
102cols :: Matrix t -> Int
103cols = icols
104
105orderOf :: Matrix t -> MatrixOrder
106orderOf = order
107
108
109-- | Matrix transpose.
110trans :: Matrix t -> Matrix t
111trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
112
113cmat :: (Element t) => Matrix t -> Matrix t
114cmat m@Matrix{order = RowMajor} = m
115cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
116
117fmat :: (Element t) => Matrix t -> Matrix t
118fmat m@Matrix{order = ColumnMajor} = m
119fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
120
121-- C-Haskell matrix adapter
122-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
123
124mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
125mat a f =
126 unsafeWith (xdat a) $ \p -> do
127 let m g = do
128 g (fi (rows a)) (fi (cols a)) p
129 f m
130
131{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
132
133>>> flatten (ident 3)
134fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
135
136-}
137flatten :: Element t => Matrix t -> Vector t
138flatten = xdat . cmat
139
140{-
141type Mt t s = Int -> Int -> Ptr t -> s
142
143infixr 6 ::>
144type t ::> s = Mt t s
145-}
146
147-- | the inverse of 'Data.Packed.Matrix.fromLists'
148toLists :: (Element t) => Matrix t -> [[t]]
149toLists m = splitEvery (cols m) . toList . flatten $ m
150
151-- | Create a matrix from a list of vectors.
152-- All vectors must have the same dimension,
153-- or dimension 1, which is are automatically expanded.
154fromRows :: Element t => [Vector t] -> Matrix t
155fromRows [] = emptyM 0 0
156fromRows vs = case compatdim (map dim vs) of
157 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
158 Just 0 -> emptyM r 0
159 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
160 where
161 r = length vs
162 adapt c v
163 | c == 0 = fromList[]
164 | dim v == c = v
165 | otherwise = constantD (v@>0) c
166
167-- | extracts the rows of a matrix as a list of vectors
168toRows :: Element t => Matrix t -> [Vector t]
169toRows m
170 | c == 0 = replicate r (fromList[])
171 | otherwise = toRows' 0
172 where
173 v = flatten m
174 r = rows m
175 c = cols m
176 toRows' k | k == r*c = []
177 | otherwise = subVector k c v : toRows' (k+c)
178
179-- | Creates a matrix from a list of vectors, as columns
180fromColumns :: Element t => [Vector t] -> Matrix t
181fromColumns m = trans . fromRows $ m
182
183-- | Creates a list of vectors from the columns of a matrix
184toColumns :: Element t => Matrix t -> [Vector t]
185toColumns m = toRows . trans $ m
186
187-- | Reads a matrix position.
188(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
189infixl 9 @@>
190m@Matrix {irows = r, icols = c} @@> (i,j)
191 | safe = if i<0 || i>=r || j<0 || j>=c
192 then error "matrix indexing out of range"
193 else atM' m i j
194 | otherwise = atM' m i j
195{-# INLINE (@@>) #-}
196
197-- Unsafe matrix access without range checking
198atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
199atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
200{-# INLINE atM' #-}
201
202------------------------------------------------------------------
203
204matrixFromVector o r c v
205 | r * c == dim v = m
206 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
207 where
208 m = Matrix { irows = r, icols = c, xdat = v, order = o }
209
210-- allocates memory for a new matrix
211createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
212createMatrix ord r c = do
213 p <- createVector (r*c)
214 return (matrixFromVector ord r c p)
215
216{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
217where r is the desired number of rows.)
218
219>>> reshape 4 (fromList [1..12])
220(3><4)
221 [ 1.0, 2.0, 3.0, 4.0
222 , 5.0, 6.0, 7.0, 8.0
223 , 9.0, 10.0, 11.0, 12.0 ]
224
225-}
226reshape :: Storable t => Int -> Vector t -> Matrix t
227reshape 0 v = matrixFromVector RowMajor 0 0 v
228reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
229
230singleton x = reshape 1 (fromList [x])
231
232-- | application of a vector function on the flattened matrix elements
233liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
234liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
235
236-- | application of a vector function on the flattened matrices elements
237liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
238liftMatrix2 f m1 m2
239 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
240 | otherwise = case orderOf m1 of
241 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
242 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
243
244
245compat :: Matrix a -> Matrix b -> Bool
246compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
247
248------------------------------------------------------------------
249
250{- | Supported matrix elements.
251
252 This class provides optimized internal
253 operations for selected element types.
254 It provides unoptimised defaults for any 'Storable' type,
255 so you can create instances simply as:
256 @instance Element Foo@.
257-}
258class (Storable a) => Element a where
259 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
260 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
261 -> Matrix a -> Matrix a
262 subMatrixD = subMatrix'
263 transdata :: Int -> Vector a -> Int -> Vector a
264 transdata = transdataP -- transdata'
265 constantD :: a -> Int -> Vector a
266 constantD = constantP -- constant'
267
268
269instance Element Float where
270 transdata = transdataAux ctransF
271 constantD = constantAux cconstantF
272
273instance Element Double where
274 transdata = transdataAux ctransR
275 constantD = constantAux cconstantR
276
277instance Element (Complex Float) where
278 transdata = transdataAux ctransQ
279 constantD = constantAux cconstantQ
280
281instance Element (Complex Double) where
282 transdata = transdataAux ctransC
283 constantD = constantAux cconstantC
284
285-------------------------------------------------------------------
286
287transdataAux fun c1 d c2 =
288 if noneed
289 then d
290 else unsafePerformIO $ do
291 v <- createVector (dim d)
292 unsafeWith d $ \pd ->
293 unsafeWith v $ \pv ->
294 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
295 return v
296 where r1 = dim d `div` c1
297 r2 = dim d `div` c2
298 noneed = dim d == 0 || r1 == 1 || c1 == 1
299
300transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
301transdataP c1 d c2 =
302 if noneed
303 then d
304 else unsafePerformIO $ do
305 v <- createVector (dim d)
306 unsafeWith d $ \pd ->
307 unsafeWith v $ \pv ->
308 ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP"
309 return v
310 where r1 = dim d `div` c1
311 r2 = dim d `div` c2
312 sz = sizeOf (d @> 0)
313 noneed = dim d == 0 || r1 == 1 || c1 == 1
314
315foreign import ccall unsafe "transF" ctransF :: TFMFM
316foreign import ccall unsafe "transR" ctransR :: TMM
317foreign import ccall unsafe "transQ" ctransQ :: TQMQM
318foreign import ccall unsafe "transC" ctransC :: TCMCM
319foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
320
321----------------------------------------------------------------------
322
323constantAux fun x n = unsafePerformIO $ do
324 v <- createVector n
325 px <- newArray [x]
326 app1 (fun px) vec v "constantAux"
327 free px
328 return v
329
330foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF
331
332foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV
333
334foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
335
336foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
337
338constantP :: Storable a => a -> Int -> Vector a
339constantP a n = unsafePerformIO $ do
340 let sz = sizeOf a
341 v <- createVector n
342 unsafeWith v $ \p -> do
343 alloca $ \k -> do
344 poke k a
345 cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP"
346 return v
347foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
348
349----------------------------------------------------------------------
350
351-- | Extracts a submatrix from a matrix.
352subMatrix :: Element a
353 => (Int,Int) -- ^ (r0,c0) starting position
354 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
355 -> Matrix a -- ^ input matrix
356 -> Matrix a -- ^ result
357subMatrix (r0,c0) (rt,ct) m
358 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
359 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
360 | otherwise = error $ "wrong subMatrix "++
361 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
362
363subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
364 w <- createVector (rt*ct)
365 unsafeWith v $ \p ->
366 unsafeWith w $ \q -> do
367 let go (-1) _ = return ()
368 go !i (-1) = go (i-1) (ct-1)
369 go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0)
370 pokeElemOff q (i*ct+j) x
371 go i (j-1)
372 go (rt-1) (ct-1)
373 return w
374
375subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
376subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
377
378--------------------------------------------------------------------------
379
380maxZ xs = if minimum xs == 0 then 0 else maximum xs
381
382conformMs ms = map (conformMTo (r,c)) ms
383 where
384 r = maxZ (map rows ms)
385 c = maxZ (map cols ms)
386
387
388conformVs vs = map (conformVTo n) vs
389 where
390 n = maxZ (map dim vs)
391
392conformMTo (r,c) m
393 | size m == (r,c) = m
394 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
395 | size m == (r,1) = repCols c m
396 | size m == (1,c) = repRows r m
397 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
398
399conformVTo n v
400 | dim v == n = v
401 | dim v == 1 = constantD (v@>0) n
402 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
403
404repRows n x = fromRows (replicate n (flatten x))
405repCols n x = fromColumns (replicate n (flatten x))
406
407size m = (rows m, cols m)
408
409shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
410
411emptyM r c = matrixFromVector RowMajor r c (fromList[])
412
413----------------------------------------------------------------------
414
415instance (Storable t, NFData t) => NFData (Matrix t)
416 where
417 rnf m | d > 0 = rnf (v @> 0)
418 | otherwise = ()
419 where
420 d = dim v
421 v = xdat m
422