summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-12 20:58:13 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-12 20:58:13 +0200
commit4b3e29097aa272d429f8005fe17b459cf0c049c8 (patch)
treedf01591ec7bdffe61f68062cc09e95f69e745a90 /packages/base/src/Internal/Matrix.hs
parent0396adb9f10f5b337e54d64fec365c9cb01e9745 (diff)
row ops in ST
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs32
1 files changed, 29 insertions, 3 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 8de06ce..fa1aad6 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -20,6 +20,7 @@ import Internal.Vector
20import Internal.Devel 20import Internal.Devel
21import Internal.Vectorized 21import Internal.Vectorized
22import Foreign.Marshal.Alloc ( free ) 22import Foreign.Marshal.Alloc ( free )
23import Foreign.Marshal.Array(newArray)
23import Foreign.Ptr ( Ptr ) 24import Foreign.Ptr ( Ptr )
24import Foreign.Storable ( Storable ) 25import Foreign.Storable ( Storable )
25import Data.Complex ( Complex ) 26import Data.Complex ( Complex )
@@ -273,12 +274,13 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
273class (Storable a) => Element a where 274class (Storable a) => Element a where
274 transdata :: Int -> Vector a -> Int -> Vector a 275 transdata :: Int -> Vector a -> Int -> Vector a
275 constantD :: a -> Int -> Vector a 276 constantD :: a -> Int -> Vector a
276 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a 277 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
277 sortI :: Ord a => Vector a -> Vector CInt 278 sortI :: Ord a => Vector a -> Vector CInt
278 sortV :: Ord a => Vector a -> Vector a 279 sortV :: Ord a => Vector a -> Vector a
279 compareV :: Ord a => Vector a -> Vector a -> Vector CInt 280 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
280 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 281 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
281 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 282 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
283 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
282 284
283 285
284instance Element Float where 286instance Element Float where
@@ -290,6 +292,7 @@ instance Element Float where
290 compareV = compareF 292 compareV = compareF
291 selectV = selectF 293 selectV = selectF
292 remapM = remapF 294 remapM = remapF
295 rowOp = rowOpAux c_rowOpF
293 296
294instance Element Double where 297instance Element Double where
295 transdata = transdataAux ctransR 298 transdata = transdataAux ctransR
@@ -300,6 +303,7 @@ instance Element Double where
300 compareV = compareD 303 compareV = compareD
301 selectV = selectD 304 selectV = selectD
302 remapM = remapD 305 remapM = remapD
306 rowOp = rowOpAux c_rowOpD
303 307
304 308
305instance Element (Complex Float) where 309instance Element (Complex Float) where
@@ -311,6 +315,7 @@ instance Element (Complex Float) where
311 compareV = undefined 315 compareV = undefined
312 selectV = selectQ 316 selectV = selectQ
313 remapM = remapQ 317 remapM = remapQ
318 rowOp = rowOpAux c_rowOpQ
314 319
315 320
316instance Element (Complex Double) where 321instance Element (Complex Double) where
@@ -322,6 +327,7 @@ instance Element (Complex Double) where
322 compareV = undefined 327 compareV = undefined
323 selectV = selectC 328 selectV = selectC
324 remapM = remapC 329 remapM = remapC
330 rowOp = rowOpAux c_rowOpC
325 331
326instance Element (CInt) where 332instance Element (CInt) where
327 transdata = transdataAux ctransI 333 transdata = transdataAux ctransI
@@ -332,6 +338,7 @@ instance Element (CInt) where
332 compareV = compareI 338 compareV = compareI
333 selectV = selectI 339 selectV = selectI
334 remapM = remapI 340 remapM = remapI
341 rowOp = rowOpAux c_rowOpI
335 342
336instance Element Z where 343instance Element Z where
337 transdata = transdataAux ctransL 344 transdata = transdataAux ctransL
@@ -342,6 +349,7 @@ instance Element Z where
342 compareV = compareL 349 compareV = compareL
343 selectV = selectL 350 selectV = selectL
344 remapM = remapL 351 remapM = remapL
352 rowOp = rowOpAux c_rowOpL
345 353
346------------------------------------------------------------------- 354-------------------------------------------------------------------
347 355
@@ -379,7 +387,7 @@ subMatrix :: Element a
379 -> Matrix a -- ^ result 387 -> Matrix a -- ^ result
380subMatrix (r0,c0) (rt,ct) m 388subMatrix (r0,c0) (rt,ct) m
381 | 0 <= r0 && 0 <= rt && r0+rt <= rows m && 389 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
382 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) 390 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1])
383 | otherwise = error $ "wrong subMatrix "++ 391 | otherwise = error $ "wrong subMatrix "++
384 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) 392 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
385 393
@@ -430,7 +438,7 @@ instance (Storable t, NFData t) => NFData (Matrix t)
430 438
431--------------------------------------------------------------- 439---------------------------------------------------------------
432 440
433extractAux f m moder vr modec vc = unsafePerformIO $ do 441extractAux f m moder vr modec vc = do
434 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 442 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
435 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 443 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
436 r <- createMatrix RowMajor nr nc 444 r <- createMatrix RowMajor nr nc
@@ -538,6 +546,24 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
538 546
539-------------------------------------------------------------------------------- 547--------------------------------------------------------------------------------
540 548
549rowOpAux f c x i1 i2 j1 j2 m = do
550 px <- newArray [x]
551 app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp"
552 free px
553
554type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
555
556foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
557foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
558foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
559foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
560foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
561foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
562foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
563foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
564
565--------------------------------------------------------------------------------
566
541foreign import ccall unsafe "saveMatrix" c_saveMatrix 567foreign import ccall unsafe "saveMatrix" c_saveMatrix
542 :: CString -> CString -> Double ..> Ok 568 :: CString -> CString -> Double ..> Ok
543 569