summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-05 16:59:29 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-05 16:59:29 +0200
commit42cc173794ee2cb52008ab642b9ab28fe6ad3eba (patch)
treeb2b4da410a8d2b91d3e6f5e47297616b7bcf831c /packages/base/src/Internal/Matrix.hs
parent9f0b5e4e0800cf3effb7a7f3b0ef2662c72fdb57 (diff)
move internal matrix
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs526
1 files changed, 526 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
new file mode 100644
index 0000000..44365d0
--- /dev/null
+++ b/packages/base/src/Internal/Matrix.hs
@@ -0,0 +1,526 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-}
5{-# LANGUAGE TypeOperators #-}
6
7-- |
8-- Module : Internal.Matrix
9-- Copyright : (c) Alberto Ruiz 2007-15
10-- License : BSD3
11-- Maintainer : Alberto Ruiz
12-- Stability : provisional
13--
14-- Internal matrix representation
15--
16
17module Internal.Matrix where
18
19
20import Internal.Tools ( splitEvery, fi, compatdim, (//) )
21import Internal.Vector
22import Internal.Devel
23import Internal.Vectorized
24import Data.Vector.Storable ( unsafeWith, fromList )
25import Foreign.Marshal.Alloc ( free )
26import Foreign.Ptr ( Ptr )
27import Foreign.Storable ( Storable )
28import Data.Complex ( Complex )
29import Foreign.C.Types ( CInt(..) )
30import Foreign.C.String ( CString, newCString )
31import System.IO.Unsafe ( unsafePerformIO )
32import Control.DeepSeq ( NFData(..) )
33
34
35-----------------------------------------------------------------
36
37{- Design considerations for the Matrix Type
38 -----------------------------------------
39
40- we must easily handle both row major and column major order,
41 for bindings to LAPACK and GSL/C
42
43- we'd like to simplify redundant matrix transposes:
44 - Some of them arise from the order requirements of some functions
45 - some functions (matrix product) admit transposed arguments
46
47- maybe we don't really need this kind of simplification:
48 - more complex code
49 - some computational overhead
50 - only appreciable gain in code with a lot of redundant transpositions
51 and cheap matrix computations
52
53- we could carry both the matrix and its (lazily computed) transpose.
54 This may save some transpositions, but it is necessary to keep track of the
55 data which is actually computed to be used by functions like the matrix product
56 which admit both orders.
57
58- but if we need the transposed data and it is not in the structure, we must make
59 sure that we touch the same foreignptr that is used in the computation.
60
61- a reasonable solution is using two constructors for a matrix. Transposition just
62 "flips" the constructor. Actual data transposition is not done if followed by a
63 matrix product or another transpose.
64
65-}
66
67data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
68
69transOrder RowMajor = ColumnMajor
70transOrder ColumnMajor = RowMajor
71{- | Matrix representation suitable for BLAS\/LAPACK computations.
72
73The elements are stored in a continuous memory array.
74
75-}
76
77data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
78 , icols :: {-# UNPACK #-} !Int
79 , xdat :: {-# UNPACK #-} !(Vector t)
80 , order :: !MatrixOrder }
81-- RowMajor: preferred by C, fdat may require a transposition
82-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
83
84--cdat = xdat
85--fdat = xdat
86
87rows :: Matrix t -> Int
88rows = irows
89
90cols :: Matrix t -> Int
91cols = icols
92
93orderOf :: Matrix t -> MatrixOrder
94orderOf = order
95
96stepRow :: Matrix t -> CInt
97stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c
98stepRow _ = 1
99
100stepCol :: Matrix t -> CInt
101stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r
102stepCol _ = 1
103
104
105-- | Matrix transpose.
106trans :: Matrix t -> Matrix t
107trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
108
109cmat :: (Element t) => Matrix t -> Matrix t
110cmat m@Matrix{order = RowMajor} = m
111cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
112
113fmat :: (Element t) => Matrix t -> Matrix t
114fmat m@Matrix{order = ColumnMajor} = m
115fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
116
117-- C-Haskell matrix adapter
118-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
119
120mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
121mat a f =
122 unsafeWith (xdat a) $ \p -> do
123 let m g = do
124 g (fi (rows a)) (fi (cols a)) p
125 f m
126
127omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
128omat a f =
129 unsafeWith (xdat a) $ \p -> do
130 let m g = do
131 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p
132 f m
133
134
135{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
136
137>>> flatten (ident 3)
138fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
139
140-}
141flatten :: Element t => Matrix t -> Vector t
142flatten = xdat . cmat
143
144{-
145type Mt t s = Int -> Int -> Ptr t -> s
146
147infixr 6 ::>
148type t ::> s = Mt t s
149-}
150
151-- | the inverse of 'Data.Packed.Matrix.fromLists'
152toLists :: (Element t) => Matrix t -> [[t]]
153toLists m = splitEvery (cols m) . toList . flatten $ m
154
155-- | Create a matrix from a list of vectors.
156-- All vectors must have the same dimension,
157-- or dimension 1, which is are automatically expanded.
158fromRows :: Element t => [Vector t] -> Matrix t
159fromRows [] = emptyM 0 0
160fromRows vs = case compatdim (map dim vs) of
161 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
162 Just 0 -> emptyM r 0
163 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
164 where
165 r = length vs
166 adapt c v
167 | c == 0 = fromList[]
168 | dim v == c = v
169 | otherwise = constantD (v@>0) c
170
171-- | extracts the rows of a matrix as a list of vectors
172toRows :: Element t => Matrix t -> [Vector t]
173toRows m
174 | c == 0 = replicate r (fromList[])
175 | otherwise = toRows' 0
176 where
177 v = flatten m
178 r = rows m
179 c = cols m
180 toRows' k | k == r*c = []
181 | otherwise = subVector k c v : toRows' (k+c)
182
183-- | Creates a matrix from a list of vectors, as columns
184fromColumns :: Element t => [Vector t] -> Matrix t
185fromColumns m = trans . fromRows $ m
186
187-- | Creates a list of vectors from the columns of a matrix
188toColumns :: Element t => Matrix t -> [Vector t]
189toColumns m = toRows . trans $ m
190
191-- | Reads a matrix position.
192(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
193infixl 9 @@>
194m@Matrix {irows = r, icols = c} @@> (i,j)
195 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
196 | otherwise = atM' m i j
197{-# INLINE (@@>) #-}
198
199-- Unsafe matrix access without range checking
200atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
201atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
202{-# INLINE atM' #-}
203
204------------------------------------------------------------------
205
206matrixFromVector o r c v
207 | r * c == dim v = m
208 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
209 where
210 m = Matrix { irows = r, icols = c, xdat = v, order = o }
211
212-- allocates memory for a new matrix
213createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
214createMatrix ord r c = do
215 p <- createVector (r*c)
216 return (matrixFromVector ord r c p)
217
218{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
219where r is the desired number of rows.)
220
221>>> reshape 4 (fromList [1..12])
222(3><4)
223 [ 1.0, 2.0, 3.0, 4.0
224 , 5.0, 6.0, 7.0, 8.0
225 , 9.0, 10.0, 11.0, 12.0 ]
226
227-}
228reshape :: Storable t => Int -> Vector t -> Matrix t
229reshape 0 v = matrixFromVector RowMajor 0 0 v
230reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
231
232--singleton x = reshape 1 (fromList [x])
233
234-- | application of a vector function on the flattened matrix elements
235liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
236liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
237
238-- | application of a vector function on the flattened matrices elements
239liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
240liftMatrix2 f m1 m2
241 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
242 | otherwise = case orderOf m1 of
243 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
244 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
245
246
247compat :: Matrix a -> Matrix b -> Bool
248compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
249
250------------------------------------------------------------------
251
252{- | Supported matrix elements.
253
254 This class provides optimized internal
255 operations for selected element types.
256 It provides unoptimised defaults for any 'Storable' type,
257 so you can create instances simply as:
258
259 >instance Element Foo
260-}
261class (Storable a) => Element a where
262 transdata :: Int -> Vector a -> Int -> Vector a
263 constantD :: a -> Int -> Vector a
264 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a
265 sortI :: Ord a => Vector a -> Vector CInt
266 sortV :: Ord a => Vector a -> Vector a
267 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
268 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
269 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
270
271
272instance Element Float where
273 transdata = transdataAux ctransF
274 constantD = constantAux cconstantF
275 extractR = extractAux c_extractF
276 sortI = sortIdxF
277 sortV = sortValF
278 compareV = compareF
279 selectV = selectF
280 remapM = remapF
281
282instance Element Double where
283 transdata = transdataAux ctransR
284 constantD = constantAux cconstantR
285 extractR = extractAux c_extractD
286 sortI = sortIdxD
287 sortV = sortValD
288 compareV = compareD
289 selectV = selectD
290 remapM = remapD
291
292
293instance Element (Complex Float) where
294 transdata = transdataAux ctransQ
295 constantD = constantAux cconstantQ
296 extractR = extractAux c_extractQ
297 sortI = undefined
298 sortV = undefined
299 compareV = undefined
300 selectV = selectQ
301 remapM = remapQ
302
303
304instance Element (Complex Double) where
305 transdata = transdataAux ctransC
306 constantD = constantAux cconstantC
307 extractR = extractAux c_extractC
308 sortI = undefined
309 sortV = undefined
310 compareV = undefined
311 selectV = selectC
312 remapM = remapC
313
314instance Element (CInt) where
315 transdata = transdataAux ctransI
316 constantD = constantAux cconstantI
317 extractR = extractAux c_extractI
318 sortI = sortIdxI
319 sortV = sortValI
320 compareV = compareI
321 selectV = selectI
322 remapM = remapI
323
324-------------------------------------------------------------------
325
326transdataAux fun c1 d c2 =
327 if noneed
328 then d
329 else unsafePerformIO $ do
330 -- putStrLn "T"
331 v <- createVector (dim d)
332 unsafeWith d $ \pd ->
333 unsafeWith v $ \pv ->
334 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
335 return v
336 where r1 = dim d `div` c1
337 r2 = dim d `div` c2
338 noneed = dim d == 0 || r1 == 1 || c1 == 1
339
340
341type TMM t = t ..> t ..> Ok
342
343foreign import ccall unsafe "transF" ctransF :: TMM Float
344foreign import ccall unsafe "transR" ctransR :: TMM Double
345foreign import ccall unsafe "transQ" ctransQ :: TMM (Complex Float)
346foreign import ccall unsafe "transC" ctransC :: TMM (Complex Double)
347foreign import ccall unsafe "transI" ctransI :: TMM CInt
348
349----------------------------------------------------------------------
350
351-- | Extracts a submatrix from a matrix.
352subMatrix :: Element a
353 => (Int,Int) -- ^ (r0,c0) starting position
354 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
355 -> Matrix a -- ^ input matrix
356 -> Matrix a -- ^ result
357subMatrix (r0,c0) (rt,ct) m
358 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
359 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1])
360 | otherwise = error $ "wrong subMatrix "++
361 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
362
363--------------------------------------------------------------------------
364
365maxZ xs = if minimum xs == 0 then 0 else maximum xs
366
367conformMs ms = map (conformMTo (r,c)) ms
368 where
369 r = maxZ (map rows ms)
370 c = maxZ (map cols ms)
371
372
373conformVs vs = map (conformVTo n) vs
374 where
375 n = maxZ (map dim vs)
376
377conformMTo (r,c) m
378 | size m == (r,c) = m
379 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
380 | size m == (r,1) = repCols c m
381 | size m == (1,c) = repRows r m
382 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
383
384conformVTo n v
385 | dim v == n = v
386 | dim v == 1 = constantD (v@>0) n
387 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
388
389repRows n x = fromRows (replicate n (flatten x))
390repCols n x = fromColumns (replicate n (flatten x))
391
392size m = (rows m, cols m)
393
394shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
395
396emptyM r c = matrixFromVector RowMajor r c (fromList[])
397
398----------------------------------------------------------------------
399
400instance (Storable t, NFData t) => NFData (Matrix t)
401 where
402 rnf m | d > 0 = rnf (v @> 0)
403 | otherwise = ()
404 where
405 d = dim v
406 v = xdat m
407
408---------------------------------------------------------------
409
410extractAux f m moder vr modec vc = unsafePerformIO $ do
411 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
412 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
413 r <- createMatrix RowMajor nr nc
414 app4 (f moder modec) vec vr vec vc omat m omat r "extractAux"
415 return r
416
417type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
418
419foreign import ccall unsafe "extractD" c_extractD :: Extr Double
420foreign import ccall unsafe "extractF" c_extractF :: Extr Float
421foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
422foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
423foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
424
425--------------------------------------------------------------------------------
426
427sortG f v = unsafePerformIO $ do
428 r <- createVector (dim v)
429 app2 f vec v vec r "sortG"
430 return r
431
432sortIdxD = sortG c_sort_indexD
433sortIdxF = sortG c_sort_indexF
434sortIdxI = sortG c_sort_indexI
435
436sortValD = sortG c_sort_valD
437sortValF = sortG c_sort_valF
438sortValI = sortG c_sort_valI
439
440foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
441foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
442foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
443
444foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
445foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
446foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
447
448--------------------------------------------------------------------------------
449
450compareG f u v = unsafePerformIO $ do
451 r <- createVector (dim v)
452 app3 f vec u vec v vec r "compareG"
453 return r
454
455compareD = compareG c_compareD
456compareF = compareG c_compareF
457compareI = compareG c_compareI
458
459foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
460foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
461foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
462
463--------------------------------------------------------------------------------
464
465selectG f c u v w = unsafePerformIO $ do
466 r <- createVector (dim v)
467 app5 f vec c vec u vec v vec w vec r "selectG"
468 return r
469
470selectD = selectG c_selectD
471selectF = selectG c_selectF
472selectI = selectG c_selectI
473selectC = selectG c_selectC
474selectQ = selectG c_selectQ
475
476type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
477
478foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
479foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
480foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
481foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
482foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
483
484---------------------------------------------------------------------------
485
486remapG f i j m = unsafePerformIO $ do
487 r <- createMatrix RowMajor (rows i) (cols i)
488 app4 f omat i omat j omat m omat r "remapG"
489 return r
490
491remapD = remapG c_remapD
492remapF = remapG c_remapF
493remapI = remapG c_remapI
494remapC = remapG c_remapC
495remapQ = remapG c_remapQ
496
497type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
498
499foreign import ccall unsafe "remapD" c_remapD :: Rem Double
500foreign import ccall unsafe "remapF" c_remapF :: Rem Float
501foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
502foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
503foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
504
505--------------------------------------------------------------------------------
506
507foreign import ccall unsafe "saveMatrix" c_saveMatrix
508 :: CString -> CString -> Double ..> Ok
509
510{- | save a matrix as a 2D ASCII table
511-}
512saveMatrix
513 :: FilePath
514 -> String -- ^ \"printf\" format (e.g. \"%.2f\", \"%g\", etc.)
515 -> Matrix Double
516 -> IO ()
517saveMatrix name format m = do
518 cname <- newCString name
519 cformat <- newCString format
520 app1 (c_saveMatrix cname cformat) mat m "saveMatrix"
521 free cname
522 free cformat
523 return ()
524
525--------------------------------------------------------------------------------
526