summaryrefslogtreecommitdiff
path: root/packages/hmatrix/src/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/hmatrix/src/Data/Packed/Internal/Matrix.hs')
-rw-r--r--packages/hmatrix/src/Data/Packed/Internal/Matrix.hs491
1 files changed, 491 insertions, 0 deletions
diff --git a/packages/hmatrix/src/Data/Packed/Internal/Matrix.hs b/packages/hmatrix/src/Data/Packed/Internal/Matrix.hs
new file mode 100644
index 0000000..9719fc0
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Internal/Matrix.hs
@@ -0,0 +1,491 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-}
5-----------------------------------------------------------------------------
6-- |
7-- Module : Data.Packed.Internal.Matrix
8-- Copyright : (c) Alberto Ruiz 2007
9-- License : GPL-style
10--
11-- Maintainer : Alberto Ruiz <aruiz@um.es>
12-- Stability : provisional
13-- Portability : portable (uses FFI)
14--
15-- Internal matrix representation
16--
17-----------------------------------------------------------------------------
18-- #hide
19
20module Data.Packed.Internal.Matrix(
21 Matrix(..), rows, cols, cdat, fdat,
22 MatrixOrder(..), orderOf,
23 createMatrix, mat,
24 cmat, fmat,
25 toLists, flatten, reshape,
26 Element(..),
27 trans,
28 fromRows, toRows, fromColumns, toColumns,
29 matrixFromVector,
30 subMatrix,
31 liftMatrix, liftMatrix2,
32 (@@>), atM',
33 saveMatrix,
34 singleton,
35 emptyM,
36 size, shSize, conformVs, conformMs, conformVTo, conformMTo
37) where
38
39import Data.Packed.Internal.Common
40import Data.Packed.Internal.Signatures
41import Data.Packed.Internal.Vector
42
43import Foreign.Marshal.Alloc(alloca, free)
44import Foreign.Marshal.Array(newArray)
45import Foreign.Ptr(Ptr, castPtr)
46import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
47import Data.Complex(Complex)
48import Foreign.C.Types
49import Foreign.C.String(newCString)
50import System.IO.Unsafe(unsafePerformIO)
51import Control.DeepSeq
52
53-----------------------------------------------------------------
54
55{- Design considerations for the Matrix Type
56 -----------------------------------------
57
58- we must easily handle both row major and column major order,
59 for bindings to LAPACK and GSL/C
60
61- we'd like to simplify redundant matrix transposes:
62 - Some of them arise from the order requirements of some functions
63 - some functions (matrix product) admit transposed arguments
64
65- maybe we don't really need this kind of simplification:
66 - more complex code
67 - some computational overhead
68 - only appreciable gain in code with a lot of redundant transpositions
69 and cheap matrix computations
70
71- we could carry both the matrix and its (lazily computed) transpose.
72 This may save some transpositions, but it is necessary to keep track of the
73 data which is actually computed to be used by functions like the matrix product
74 which admit both orders.
75
76- but if we need the transposed data and it is not in the structure, we must make
77 sure that we touch the same foreignptr that is used in the computation.
78
79- a reasonable solution is using two constructors for a matrix. Transposition just
80 "flips" the constructor. Actual data transposition is not done if followed by a
81 matrix product or another transpose.
82
83-}
84
85data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
86
87transOrder RowMajor = ColumnMajor
88transOrder ColumnMajor = RowMajor
89{- | Matrix representation suitable for GSL and LAPACK computations.
90
91The elements are stored in a continuous memory array.
92
93-}
94
95data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
96 , icols :: {-# UNPACK #-} !Int
97 , xdat :: {-# UNPACK #-} !(Vector t)
98 , order :: !MatrixOrder }
99-- RowMajor: preferred by C, fdat may require a transposition
100-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
101
102cdat = xdat
103fdat = xdat
104
105rows :: Matrix t -> Int
106rows = irows
107
108cols :: Matrix t -> Int
109cols = icols
110
111orderOf :: Matrix t -> MatrixOrder
112orderOf = order
113
114
115-- | Matrix transpose.
116trans :: Matrix t -> Matrix t
117trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
118
119cmat :: (Element t) => Matrix t -> Matrix t
120cmat m@Matrix{order = RowMajor} = m
121cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
122
123fmat :: (Element t) => Matrix t -> Matrix t
124fmat m@Matrix{order = ColumnMajor} = m
125fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
126
127-- C-Haskell matrix adapter
128-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
129
130mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
131mat a f =
132 unsafeWith (xdat a) $ \p -> do
133 let m g = do
134 g (fi (rows a)) (fi (cols a)) p
135 f m
136
137{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
138
139>>> flatten (ident 3)
140fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
141
142-}
143flatten :: Element t => Matrix t -> Vector t
144flatten = xdat . cmat
145
146{-
147type Mt t s = Int -> Int -> Ptr t -> s
148
149infixr 6 ::>
150type t ::> s = Mt t s
151-}
152
153-- | the inverse of 'Data.Packed.Matrix.fromLists'
154toLists :: (Element t) => Matrix t -> [[t]]
155toLists m = splitEvery (cols m) . toList . flatten $ m
156
157-- | Create a matrix from a list of vectors.
158-- All vectors must have the same dimension,
159-- or dimension 1, which is are automatically expanded.
160fromRows :: Element t => [Vector t] -> Matrix t
161fromRows [] = emptyM 0 0
162fromRows vs = case compatdim (map dim vs) of
163 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
164 Just 0 -> emptyM r 0
165 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
166 where
167 r = length vs
168 adapt c v
169 | c == 0 = fromList[]
170 | dim v == c = v
171 | otherwise = constantD (v@>0) c
172
173-- | extracts the rows of a matrix as a list of vectors
174toRows :: Element t => Matrix t -> [Vector t]
175toRows m
176 | c == 0 = replicate r (fromList[])
177 | otherwise = toRows' 0
178 where
179 v = flatten m
180 r = rows m
181 c = cols m
182 toRows' k | k == r*c = []
183 | otherwise = subVector k c v : toRows' (k+c)
184
185-- | Creates a matrix from a list of vectors, as columns
186fromColumns :: Element t => [Vector t] -> Matrix t
187fromColumns m = trans . fromRows $ m
188
189-- | Creates a list of vectors from the columns of a matrix
190toColumns :: Element t => Matrix t -> [Vector t]
191toColumns m = toRows . trans $ m
192
193-- | Reads a matrix position.
194(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
195infixl 9 @@>
196m@Matrix {irows = r, icols = c} @@> (i,j)
197 | safe = if i<0 || i>=r || j<0 || j>=c
198 then error "matrix indexing out of range"
199 else atM' m i j
200 | otherwise = atM' m i j
201{-# INLINE (@@>) #-}
202
203-- Unsafe matrix access without range checking
204atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
205atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
206{-# INLINE atM' #-}
207
208------------------------------------------------------------------
209
210matrixFromVector o r c v
211 | r * c == dim v = m
212 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
213 where
214 m = Matrix { irows = r, icols = c, xdat = v, order = o }
215
216-- allocates memory for a new matrix
217createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
218createMatrix ord r c = do
219 p <- createVector (r*c)
220 return (matrixFromVector ord r c p)
221
222{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
223where r is the desired number of rows.)
224
225>>> reshape 4 (fromList [1..12])
226(3><4)
227 [ 1.0, 2.0, 3.0, 4.0
228 , 5.0, 6.0, 7.0, 8.0
229 , 9.0, 10.0, 11.0, 12.0 ]
230
231-}
232reshape :: Storable t => Int -> Vector t -> Matrix t
233reshape 0 v = matrixFromVector RowMajor 0 0 v
234reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
235
236singleton x = reshape 1 (fromList [x])
237
238-- | application of a vector function on the flattened matrix elements
239liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
240liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
241
242-- | application of a vector function on the flattened matrices elements
243liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
244liftMatrix2 f m1 m2
245 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
246 | otherwise = case orderOf m1 of
247 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
248 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
249
250
251compat :: Matrix a -> Matrix b -> Bool
252compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
253
254------------------------------------------------------------------
255
256{- | Supported matrix elements.
257
258 This class provides optimized internal
259 operations for selected element types.
260 It provides unoptimised defaults for any 'Storable' type,
261 so you can create instances simply as:
262 @instance Element Foo@.
263-}
264class (Storable a) => Element a where
265 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
266 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
267 -> Matrix a -> Matrix a
268 subMatrixD = subMatrix'
269 transdata :: Int -> Vector a -> Int -> Vector a
270 transdata = transdataP -- transdata'
271 constantD :: a -> Int -> Vector a
272 constantD = constantP -- constant'
273
274
275instance Element Float where
276 transdata = transdataAux ctransF
277 constantD = constantAux cconstantF
278
279instance Element Double where
280 transdata = transdataAux ctransR
281 constantD = constantAux cconstantR
282
283instance Element (Complex Float) where
284 transdata = transdataAux ctransQ
285 constantD = constantAux cconstantQ
286
287instance Element (Complex Double) where
288 transdata = transdataAux ctransC
289 constantD = constantAux cconstantC
290
291-------------------------------------------------------------------
292
293transdata' :: Storable a => Int -> Vector a -> Int -> Vector a
294transdata' c1 v c2 =
295 if noneed
296 then v
297 else unsafePerformIO $ do
298 w <- createVector (r2*c2)
299 unsafeWith v $ \p ->
300 unsafeWith w $ \q -> do
301 let go (-1) _ = return ()
302 go !i (-1) = go (i-1) (c1-1)
303 go !i !j = do x <- peekElemOff p (i*c1+j)
304 pokeElemOff q (j*c2+i) x
305 go i (j-1)
306 go (r1-1) (c1-1)
307 return w
308 where r1 = dim v `div` c1
309 r2 = dim v `div` c2
310 noneed = dim v == 0 || r1 == 1 || c1 == 1
311
312-- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-}
313-- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-}
314
315-- I don't know how to specialize...
316-- The above pragmas only seem to work on top level defs
317-- Fortunately everything seems to work using the above class
318
319-- C versions, still a little faster:
320
321transdataAux fun c1 d c2 =
322 if noneed
323 then d
324 else unsafePerformIO $ do
325 v <- createVector (dim d)
326 unsafeWith d $ \pd ->
327 unsafeWith v $ \pv ->
328 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
329 return v
330 where r1 = dim d `div` c1
331 r2 = dim d `div` c2
332 noneed = dim d == 0 || r1 == 1 || c1 == 1
333
334transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
335transdataP c1 d c2 =
336 if noneed
337 then d
338 else unsafePerformIO $ do
339 v <- createVector (dim d)
340 unsafeWith d $ \pd ->
341 unsafeWith v $ \pv ->
342 ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP"
343 return v
344 where r1 = dim d `div` c1
345 r2 = dim d `div` c2
346 sz = sizeOf (d @> 0)
347 noneed = dim d == 0 || r1 == 1 || c1 == 1
348
349foreign import ccall unsafe "transF" ctransF :: TFMFM
350foreign import ccall unsafe "transR" ctransR :: TMM
351foreign import ccall unsafe "transQ" ctransQ :: TQMQM
352foreign import ccall unsafe "transC" ctransC :: TCMCM
353foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
354
355----------------------------------------------------------------------
356
357constant' v n = unsafePerformIO $ do
358 w <- createVector n
359 unsafeWith w $ \p -> do
360 let go (-1) = return ()
361 go !k = pokeElemOff p k v >> go (k-1)
362 go (n-1)
363 return w
364
365-- C versions
366
367constantAux fun x n = unsafePerformIO $ do
368 v <- createVector n
369 px <- newArray [x]
370 app1 (fun px) vec v "constantAux"
371 free px
372 return v
373
374constantF :: Float -> Int -> Vector Float
375constantF = constantAux cconstantF
376foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF
377
378constantR :: Double -> Int -> Vector Double
379constantR = constantAux cconstantR
380foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV
381
382constantQ :: Complex Float -> Int -> Vector (Complex Float)
383constantQ = constantAux cconstantQ
384foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
385
386constantC :: Complex Double -> Int -> Vector (Complex Double)
387constantC = constantAux cconstantC
388foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
389
390constantP :: Storable a => a -> Int -> Vector a
391constantP a n = unsafePerformIO $ do
392 let sz = sizeOf a
393 v <- createVector n
394 unsafeWith v $ \p -> do
395 alloca $ \k -> do
396 poke k a
397 cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP"
398 return v
399foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
400
401----------------------------------------------------------------------
402
403-- | Extracts a submatrix from a matrix.
404subMatrix :: Element a
405 => (Int,Int) -- ^ (r0,c0) starting position
406 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
407 -> Matrix a -- ^ input matrix
408 -> Matrix a -- ^ result
409subMatrix (r0,c0) (rt,ct) m
410 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
411 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
412 | otherwise = error $ "wrong subMatrix "++
413 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
414
415subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
416 w <- createVector (rt*ct)
417 unsafeWith v $ \p ->
418 unsafeWith w $ \q -> do
419 let go (-1) _ = return ()
420 go !i (-1) = go (i-1) (ct-1)
421 go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0)
422 pokeElemOff q (i*ct+j) x
423 go i (j-1)
424 go (rt-1) (ct-1)
425 return w
426
427subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
428subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
429
430--------------------------------------------------------------------------
431
432-- | Saves a matrix as 2D ASCII table.
433saveMatrix :: FilePath
434 -> String -- ^ format (%f, %g, %e)
435 -> Matrix Double
436 -> IO ()
437saveMatrix filename fmt m = do
438 charname <- newCString filename
439 charfmt <- newCString fmt
440 let o = if orderOf m == RowMajor then 1 else 0
441 app1 (matrix_fprintf charname charfmt o) mat m "matrix_fprintf"
442 free charname
443 free charfmt
444
445foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM
446
447----------------------------------------------------------------------
448
449maxZ xs = if minimum xs == 0 then 0 else maximum xs
450
451conformMs ms = map (conformMTo (r,c)) ms
452 where
453 r = maxZ (map rows ms)
454 c = maxZ (map cols ms)
455
456
457conformVs vs = map (conformVTo n) vs
458 where
459 n = maxZ (map dim vs)
460
461conformMTo (r,c) m
462 | size m == (r,c) = m
463 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
464 | size m == (r,1) = repCols c m
465 | size m == (1,c) = repRows r m
466 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
467
468conformVTo n v
469 | dim v == n = v
470 | dim v == 1 = constantD (v@>0) n
471 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
472
473repRows n x = fromRows (replicate n (flatten x))
474repCols n x = fromColumns (replicate n (flatten x))
475
476size m = (rows m, cols m)
477
478shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
479
480emptyM r c = matrixFromVector RowMajor r c (fromList[])
481
482----------------------------------------------------------------------
483
484instance (Storable t, NFData t) => NFData (Matrix t)
485 where
486 rnf m | d > 0 = rnf (v @> 0)
487 | otherwise = ()
488 where
489 d = dim v
490 v = xdat m
491