summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs52
1 files changed, 46 insertions, 6 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
index 82a9d8f..ddeddae 100644
--- a/packages/base/src/Data/Packed/Internal/Matrix.hs
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -105,6 +105,14 @@ cols = icols
105orderOf :: Matrix t -> MatrixOrder 105orderOf :: Matrix t -> MatrixOrder
106orderOf = order 106orderOf = order
107 107
108stepRow :: Matrix t -> CInt
109stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c
110stepRow _ = 1
111
112stepCol :: Matrix t -> CInt
113stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r
114stepCol _ = 1
115
108 116
109-- | Matrix transpose. 117-- | Matrix transpose.
110trans :: Matrix t -> Matrix t 118trans :: Matrix t -> Matrix t
@@ -128,6 +136,14 @@ mat a f =
128 g (fi (rows a)) (fi (cols a)) p 136 g (fi (rows a)) (fi (cols a)) p
129 f m 137 f m
130 138
139omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
140omat a f =
141 unsafeWith (xdat a) $ \p -> do
142 let m g = do
143 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p
144 f m
145
146
131{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 147{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
132 148
133>>> flatten (ident 3) 149>>> flatten (ident 3)
@@ -257,7 +273,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
257 >instance Element Foo 273 >instance Element Foo
258-} 274-}
259class (Storable a) => Element a where 275class (Storable a) => Element a where
260 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position 276 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
261 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 277 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
262 -> Matrix a -> Matrix a 278 -> Matrix a -> Matrix a
263 subMatrixD = subMatrix' 279 subMatrixD = subMatrix'
@@ -270,6 +286,7 @@ class (Storable a) => Element a where
270 sortV :: Ord a => Vector a -> Vector a 286 sortV :: Ord a => Vector a -> Vector a
271 compareV :: Ord a => Vector a -> Vector a -> Vector CInt 287 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
272 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 288 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
289 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
273 290
274 291
275instance Element Float where 292instance Element Float where
@@ -280,7 +297,7 @@ instance Element Float where
280 sortV = sortValF 297 sortV = sortValF
281 compareV = compareF 298 compareV = compareF
282 selectV = selectF 299 selectV = selectF
283 300 remapM = remapF
284 301
285instance Element Double where 302instance Element Double where
286 transdata = transdataAux ctransR 303 transdata = transdataAux ctransR
@@ -290,6 +307,7 @@ instance Element Double where
290 sortV = sortValD 307 sortV = sortValD
291 compareV = compareD 308 compareV = compareD
292 selectV = selectD 309 selectV = selectD
310 remapM = remapD
293 311
294 312
295instance Element (Complex Float) where 313instance Element (Complex Float) where
@@ -300,6 +318,7 @@ instance Element (Complex Float) where
300 sortV = undefined 318 sortV = undefined
301 compareV = undefined 319 compareV = undefined
302 selectV = selectQ 320 selectV = selectQ
321 remapM = remapQ
303 322
304 323
305instance Element (Complex Double) where 324instance Element (Complex Double) where
@@ -310,8 +329,8 @@ instance Element (Complex Double) where
310 sortV = undefined 329 sortV = undefined
311 compareV = undefined 330 compareV = undefined
312 selectV = selectC 331 selectV = selectC
332 remapM = remapC
313 333
314
315instance Element (CInt) where 334instance Element (CInt) where
316 transdata = transdataAux ctransI 335 transdata = transdataAux ctransI
317 constantD = constantAux cconstantI 336 constantD = constantAux cconstantI
@@ -320,7 +339,7 @@ instance Element (CInt) where
320 sortV = sortValI 339 sortV = sortValI
321 compareV = compareI 340 compareV = compareI
322 selectV = selectI 341 selectV = selectI
323 342 remapM = remapI
324 343
325------------------------------------------------------------------- 344-------------------------------------------------------------------
326 345
@@ -394,7 +413,7 @@ foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -
394 413
395-- | Extracts a submatrix from a matrix. 414-- | Extracts a submatrix from a matrix.
396subMatrix :: Element a 415subMatrix :: Element a
397 => (Int,Int) -- ^ (r0,c0) starting position 416 => (Int,Int) -- ^ (r0,c0) starting position
398 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 417 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
399 -> Matrix a -- ^ input matrix 418 -> Matrix a -- ^ input matrix
400 -> Matrix a -- ^ result 419 -> Matrix a -- ^ result
@@ -427,7 +446,7 @@ conformMs ms = map (conformMTo (r,c)) ms
427 where 446 where
428 r = maxZ (map rows ms) 447 r = maxZ (map rows ms)
429 c = maxZ (map cols ms) 448 c = maxZ (map cols ms)
430 449
431 450
432conformVs vs = map (conformVTo n) vs 451conformVs vs = map (conformVTo n) vs
433 where 452 where
@@ -554,4 +573,25 @@ foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
554foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) 573foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
555foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) 574foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
556 575
576---------------------------------------------------------------------------
577
578remapG f i j m = unsafePerformIO $ do
579 r <- createMatrix RowMajor (rows i) (cols i)
580 app4 f omat i omat j omat m omat r "remapG"
581 return r
582
583remapD = remapG c_remapD
584remapF = remapG c_remapF
585remapI = remapG c_remapI
586remapC = remapG c_remapC
587remapQ = remapG c_remapQ
588
589type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
590
591foreign import ccall unsafe "remapD" c_remapD :: Rem Double
592foreign import ccall unsafe "remapF" c_remapF :: Rem Float
593foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
594foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
595foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
596
557 597