summaryrefslogtreecommitdiff
path: root/packages/base/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs583
1 files changed, 0 insertions, 583 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
deleted file mode 100644
index 66eef15..0000000
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ /dev/null
@@ -1,583 +0,0 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-}
5
6-- |
7-- Module : Data.Packed.Internal.Matrix
8-- Copyright : (c) Alberto Ruiz 2007
9-- License : BSD3
10-- Maintainer : Alberto Ruiz
11-- Stability : provisional
12--
13-- Internal matrix representation
14--
15
16module Data.Packed.Internal.Matrix(
17 Matrix(..), rows, cols, cdat, fdat,
18 MatrixOrder(..), orderOf,
19 createMatrix, mat, omat,
20 cmat, fmat,
21 toLists, flatten, reshape,
22 Element(..),
23 trans,
24 fromRows, toRows, fromColumns, toColumns,
25 matrixFromVector,
26 subMatrix,
27 liftMatrix, liftMatrix2,
28 (@@>), atM',
29 singleton,
30 emptyM,
31 size, shSize, conformVs, conformMs, conformVTo, conformMTo
32) where
33
34import Data.Packed.Internal.Common
35import Data.Packed.Internal.Signatures
36import Data.Packed.Internal.Vector
37
38import Foreign.Marshal.Alloc(alloca, free)
39import Foreign.Marshal.Array(newArray)
40import Foreign.Ptr(Ptr, castPtr)
41import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
42import Data.Complex(Complex)
43import Foreign.C.Types
44import System.IO.Unsafe(unsafePerformIO)
45import Control.DeepSeq
46
47-----------------------------------------------------------------
48
49{- Design considerations for the Matrix Type
50 -----------------------------------------
51
52- we must easily handle both row major and column major order,
53 for bindings to LAPACK and GSL/C
54
55- we'd like to simplify redundant matrix transposes:
56 - Some of them arise from the order requirements of some functions
57 - some functions (matrix product) admit transposed arguments
58
59- maybe we don't really need this kind of simplification:
60 - more complex code
61 - some computational overhead
62 - only appreciable gain in code with a lot of redundant transpositions
63 and cheap matrix computations
64
65- we could carry both the matrix and its (lazily computed) transpose.
66 This may save some transpositions, but it is necessary to keep track of the
67 data which is actually computed to be used by functions like the matrix product
68 which admit both orders.
69
70- but if we need the transposed data and it is not in the structure, we must make
71 sure that we touch the same foreignptr that is used in the computation.
72
73- a reasonable solution is using two constructors for a matrix. Transposition just
74 "flips" the constructor. Actual data transposition is not done if followed by a
75 matrix product or another transpose.
76
77-}
78
79data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
80
81transOrder RowMajor = ColumnMajor
82transOrder ColumnMajor = RowMajor
83{- | Matrix representation suitable for BLAS\/LAPACK computations.
84
85The elements are stored in a continuous memory array.
86
87-}
88
89data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
90 , icols :: {-# UNPACK #-} !Int
91 , xdat :: {-# UNPACK #-} !(Vector t)
92 , order :: !MatrixOrder }
93-- RowMajor: preferred by C, fdat may require a transposition
94-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
95
96cdat = xdat
97fdat = xdat
98
99rows :: Matrix t -> Int
100rows = irows
101
102cols :: Matrix t -> Int
103cols = icols
104
105orderOf :: Matrix t -> MatrixOrder
106orderOf = order
107
108stepRow :: Matrix t -> CInt
109stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c
110stepRow _ = 1
111
112stepCol :: Matrix t -> CInt
113stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r
114stepCol _ = 1
115
116
117-- | Matrix transpose.
118trans :: Matrix t -> Matrix t
119trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
120
121cmat :: (Element t) => Matrix t -> Matrix t
122cmat m@Matrix{order = RowMajor} = m
123cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
124
125fmat :: (Element t) => Matrix t -> Matrix t
126fmat m@Matrix{order = ColumnMajor} = m
127fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
128
129-- C-Haskell matrix adapter
130-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
131
132mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
133mat a f =
134 unsafeWith (xdat a) $ \p -> do
135 let m g = do
136 g (fi (rows a)) (fi (cols a)) p
137 f m
138
139omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
140omat a f =
141 unsafeWith (xdat a) $ \p -> do
142 let m g = do
143 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p
144 f m
145
146
147{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
148
149>>> flatten (ident 3)
150fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
151
152-}
153flatten :: Element t => Matrix t -> Vector t
154flatten = xdat . cmat
155
156{-
157type Mt t s = Int -> Int -> Ptr t -> s
158
159infixr 6 ::>
160type t ::> s = Mt t s
161-}
162
163-- | the inverse of 'Data.Packed.Matrix.fromLists'
164toLists :: (Element t) => Matrix t -> [[t]]
165toLists m = splitEvery (cols m) . toList . flatten $ m
166
167-- | Create a matrix from a list of vectors.
168-- All vectors must have the same dimension,
169-- or dimension 1, which is are automatically expanded.
170fromRows :: Element t => [Vector t] -> Matrix t
171fromRows [] = emptyM 0 0
172fromRows vs = case compatdim (map dim vs) of
173 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
174 Just 0 -> emptyM r 0
175 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
176 where
177 r = length vs
178 adapt c v
179 | c == 0 = fromList[]
180 | dim v == c = v
181 | otherwise = constantD (v@>0) c
182
183-- | extracts the rows of a matrix as a list of vectors
184toRows :: Element t => Matrix t -> [Vector t]
185toRows m
186 | c == 0 = replicate r (fromList[])
187 | otherwise = toRows' 0
188 where
189 v = flatten m
190 r = rows m
191 c = cols m
192 toRows' k | k == r*c = []
193 | otherwise = subVector k c v : toRows' (k+c)
194
195-- | Creates a matrix from a list of vectors, as columns
196fromColumns :: Element t => [Vector t] -> Matrix t
197fromColumns m = trans . fromRows $ m
198
199-- | Creates a list of vectors from the columns of a matrix
200toColumns :: Element t => Matrix t -> [Vector t]
201toColumns m = toRows . trans $ m
202
203-- | Reads a matrix position.
204(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
205infixl 9 @@>
206m@Matrix {irows = r, icols = c} @@> (i,j)
207 | safe = if i<0 || i>=r || j<0 || j>=c
208 then error "matrix indexing out of range"
209 else atM' m i j
210 | otherwise = atM' m i j
211{-# INLINE (@@>) #-}
212
213-- Unsafe matrix access without range checking
214atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
215atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
216{-# INLINE atM' #-}
217
218------------------------------------------------------------------
219
220matrixFromVector o r c v
221 | r * c == dim v = m
222 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
223 where
224 m = Matrix { irows = r, icols = c, xdat = v, order = o }
225
226-- allocates memory for a new matrix
227createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
228createMatrix ord r c = do
229 p <- createVector (r*c)
230 return (matrixFromVector ord r c p)
231
232{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
233where r is the desired number of rows.)
234
235>>> reshape 4 (fromList [1..12])
236(3><4)
237 [ 1.0, 2.0, 3.0, 4.0
238 , 5.0, 6.0, 7.0, 8.0
239 , 9.0, 10.0, 11.0, 12.0 ]
240
241-}
242reshape :: Storable t => Int -> Vector t -> Matrix t
243reshape 0 v = matrixFromVector RowMajor 0 0 v
244reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
245
246singleton x = reshape 1 (fromList [x])
247
248-- | application of a vector function on the flattened matrix elements
249liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
250liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
251
252-- | application of a vector function on the flattened matrices elements
253liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
254liftMatrix2 f m1 m2
255 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
256 | otherwise = case orderOf m1 of
257 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
258 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
259
260
261compat :: Matrix a -> Matrix b -> Bool
262compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
263
264------------------------------------------------------------------
265
266{- | Supported matrix elements.
267
268 This class provides optimized internal
269 operations for selected element types.
270 It provides unoptimised defaults for any 'Storable' type,
271 so you can create instances simply as:
272
273 >instance Element Foo
274-}
275class (Storable a) => Element a where
276 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
277 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
278 -> Matrix a -> Matrix a
279 subMatrixD = subMatrix'
280 transdata :: Int -> Vector a -> Int -> Vector a
281 transdata = transdataP -- transdata'
282 constantD :: a -> Int -> Vector a
283 constantD = constantP -- constant'
284 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a
285 sortI :: Ord a => Vector a -> Vector CInt
286 sortV :: Ord a => Vector a -> Vector a
287 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
288 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
289 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
290
291
292instance Element Float where
293 transdata = transdataAux ctransF
294 constantD = constantAux cconstantF
295 extractR = extractAux c_extractF
296 sortI = sortIdxF
297 sortV = sortValF
298 compareV = compareF
299 selectV = selectF
300 remapM = remapF
301
302instance Element Double where
303 transdata = transdataAux ctransR
304 constantD = constantAux cconstantR
305 extractR = extractAux c_extractD
306 sortI = sortIdxD
307 sortV = sortValD
308 compareV = compareD
309 selectV = selectD
310 remapM = remapD
311
312
313instance Element (Complex Float) where
314 transdata = transdataAux ctransQ
315 constantD = constantAux cconstantQ
316 extractR = extractAux c_extractQ
317 sortI = undefined
318 sortV = undefined
319 compareV = undefined
320 selectV = selectQ
321 remapM = remapQ
322
323
324instance Element (Complex Double) where
325 transdata = transdataAux ctransC
326 constantD = constantAux cconstantC
327 extractR = extractAux c_extractC
328 sortI = undefined
329 sortV = undefined
330 compareV = undefined
331 selectV = selectC
332 remapM = remapC
333
334instance Element (CInt) where
335 transdata = transdataAux ctransI
336 constantD = constantAux cconstantI
337 extractR = extractAux c_extractI
338 sortI = sortIdxI
339 sortV = sortValI
340 compareV = compareI
341 selectV = selectI
342 remapM = remapI
343
344-------------------------------------------------------------------
345
346transdataAux fun c1 d c2 =
347 if noneed
348 then d
349 else unsafePerformIO $ do
350 -- putStrLn "T"
351 v <- createVector (dim d)
352 unsafeWith d $ \pd ->
353 unsafeWith v $ \pv ->
354 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
355 return v
356 where r1 = dim d `div` c1
357 r2 = dim d `div` c2
358 noneed = dim d == 0 || r1 == 1 || c1 == 1
359
360transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
361transdataP c1 d c2 =
362 if noneed
363 then d
364 else unsafePerformIO $ do
365 v <- createVector (dim d)
366 unsafeWith d $ \pd ->
367 unsafeWith v $ \pv ->
368 ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP"
369 return v
370 where r1 = dim d `div` c1
371 r2 = dim d `div` c2
372 sz = sizeOf (d @> 0)
373 noneed = dim d == 0 || r1 == 1 || c1 == 1
374
375foreign import ccall unsafe "transF" ctransF :: TFMFM
376foreign import ccall unsafe "transR" ctransR :: TMM
377foreign import ccall unsafe "transQ" ctransQ :: TQMQM
378foreign import ccall unsafe "transC" ctransC :: TCMCM
379foreign import ccall unsafe "transI" ctransI :: CM CInt (CM CInt (IO CInt))
380foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
381
382----------------------------------------------------------------------
383
384constantAux fun x n = unsafePerformIO $ do
385 v <- createVector n
386 px <- newArray [x]
387 app1 (fun px) vec v "constantAux"
388 free px
389 return v
390
391foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF
392
393foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV
394
395foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
396
397foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
398
399foreign import ccall unsafe "constantI" cconstantI :: Ptr CInt -> CV CInt (IO CInt)
400
401constantP :: Storable a => a -> Int -> Vector a
402constantP a n = unsafePerformIO $ do
403 let sz = sizeOf a
404 v <- createVector n
405 unsafeWith v $ \p -> do
406 alloca $ \k -> do
407 poke k a
408 cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP"
409 return v
410foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
411
412----------------------------------------------------------------------
413
414-- | Extracts a submatrix from a matrix.
415subMatrix :: Element a
416 => (Int,Int) -- ^ (r0,c0) starting position
417 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
418 -> Matrix a -- ^ input matrix
419 -> Matrix a -- ^ result
420subMatrix (r0,c0) (rt,ct) m
421 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
422 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
423 | otherwise = error $ "wrong subMatrix "++
424 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
425
426subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
427 w <- createVector (rt*ct)
428 unsafeWith v $ \p ->
429 unsafeWith w $ \q -> do
430 let go (-1) _ = return ()
431 go !i (-1) = go (i-1) (ct-1)
432 go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0)
433 pokeElemOff q (i*ct+j) x
434 go i (j-1)
435 go (rt-1) (ct-1)
436 return w
437
438subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
439subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
440
441--------------------------------------------------------------------------
442
443maxZ xs = if minimum xs == 0 then 0 else maximum xs
444
445conformMs ms = map (conformMTo (r,c)) ms
446 where
447 r = maxZ (map rows ms)
448 c = maxZ (map cols ms)
449
450
451conformVs vs = map (conformVTo n) vs
452 where
453 n = maxZ (map dim vs)
454
455conformMTo (r,c) m
456 | size m == (r,c) = m
457 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
458 | size m == (r,1) = repCols c m
459 | size m == (1,c) = repRows r m
460 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
461
462conformVTo n v
463 | dim v == n = v
464 | dim v == 1 = constantD (v@>0) n
465 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
466
467repRows n x = fromRows (replicate n (flatten x))
468repCols n x = fromColumns (replicate n (flatten x))
469
470size m = (rows m, cols m)
471
472shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
473
474emptyM r c = matrixFromVector RowMajor r c (fromList[])
475
476----------------------------------------------------------------------
477
478instance (Storable t, NFData t) => NFData (Matrix t)
479 where
480 rnf m | d > 0 = rnf (v @> 0)
481 | otherwise = ()
482 where
483 d = dim v
484 v = xdat m
485
486---------------------------------------------------------------
487
488extractAux f m moder vr modec vc = unsafePerformIO $ do
489 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
490 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
491 r <- createMatrix RowMajor nr nc
492 app4 (f moder modec) vec vr vec vc omat m omat r "extractAux"
493 return r
494
495type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
496
497foreign import ccall unsafe "extractD" c_extractD :: Extr Double
498foreign import ccall unsafe "extractF" c_extractF :: Extr Float
499foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
500foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
501foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
502
503--------------------------------------------------------------------------------
504
505sortG f v = unsafePerformIO $ do
506 r <- createVector (dim v)
507 app2 f vec v vec r "sortG"
508 return r
509
510sortIdxD = sortG c_sort_indexD
511sortIdxF = sortG c_sort_indexF
512sortIdxI = sortG c_sort_indexI
513
514sortValD = sortG c_sort_valD
515sortValF = sortG c_sort_valF
516sortValI = sortG c_sort_valI
517
518foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
519foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
520foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
521
522foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
523foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
524foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
525
526--------------------------------------------------------------------------------
527
528compareG f u v = unsafePerformIO $ do
529 r <- createVector (dim v)
530 app3 f vec u vec v vec r "compareG"
531 return r
532
533compareD = compareG c_compareD
534compareF = compareG c_compareF
535compareI = compareG c_compareI
536
537foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
538foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
539foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
540
541--------------------------------------------------------------------------------
542
543selectG f c u v w = unsafePerformIO $ do
544 r <- createVector (dim v)
545 app5 f vec c vec u vec v vec w vec r "selectG"
546 return r
547
548selectD = selectG c_selectD
549selectF = selectG c_selectF
550selectI = selectG c_selectI
551selectC = selectG c_selectC
552selectQ = selectG c_selectQ
553
554type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
555
556foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
557foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
558foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
559foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
560foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
561
562---------------------------------------------------------------------------
563
564remapG f i j m = unsafePerformIO $ do
565 r <- createMatrix RowMajor (rows i) (cols i)
566 app4 f omat i omat j omat m omat r "remapG"
567 return r
568
569remapD = remapG c_remapD
570remapF = remapG c_remapF
571remapI = remapG c_remapI
572remapC = remapG c_remapC
573remapQ = remapG c_remapQ
574
575type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
576
577foreign import ccall unsafe "remapD" c_remapD :: Rem Double
578foreign import ccall unsafe "remapF" c_remapF :: Rem Float
579foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
580foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
581foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
582
583