diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 32 |
1 files changed, 29 insertions, 3 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8de06ce..fa1aad6 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -20,6 +20,7 @@ import Internal.Vector | |||
20 | import Internal.Devel | 20 | import Internal.Devel |
21 | import Internal.Vectorized | 21 | import Internal.Vectorized |
22 | import Foreign.Marshal.Alloc ( free ) | 22 | import Foreign.Marshal.Alloc ( free ) |
23 | import Foreign.Marshal.Array(newArray) | ||
23 | import Foreign.Ptr ( Ptr ) | 24 | import Foreign.Ptr ( Ptr ) |
24 | import Foreign.Storable ( Storable ) | 25 | import Foreign.Storable ( Storable ) |
25 | import Data.Complex ( Complex ) | 26 | import Data.Complex ( Complex ) |
@@ -273,12 +274,13 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | |||
273 | class (Storable a) => Element a where | 274 | class (Storable a) => Element a where |
274 | transdata :: Int -> Vector a -> Int -> Vector a | 275 | transdata :: Int -> Vector a -> Int -> Vector a |
275 | constantD :: a -> Int -> Vector a | 276 | constantD :: a -> Int -> Vector a |
276 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a | 277 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) |
277 | sortI :: Ord a => Vector a -> Vector CInt | 278 | sortI :: Ord a => Vector a -> Vector CInt |
278 | sortV :: Ord a => Vector a -> Vector a | 279 | sortV :: Ord a => Vector a -> Vector a |
279 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | 280 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt |
280 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | 281 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a |
281 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 282 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
283 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
282 | 284 | ||
283 | 285 | ||
284 | instance Element Float where | 286 | instance Element Float where |
@@ -290,6 +292,7 @@ instance Element Float where | |||
290 | compareV = compareF | 292 | compareV = compareF |
291 | selectV = selectF | 293 | selectV = selectF |
292 | remapM = remapF | 294 | remapM = remapF |
295 | rowOp = rowOpAux c_rowOpF | ||
293 | 296 | ||
294 | instance Element Double where | 297 | instance Element Double where |
295 | transdata = transdataAux ctransR | 298 | transdata = transdataAux ctransR |
@@ -300,6 +303,7 @@ instance Element Double where | |||
300 | compareV = compareD | 303 | compareV = compareD |
301 | selectV = selectD | 304 | selectV = selectD |
302 | remapM = remapD | 305 | remapM = remapD |
306 | rowOp = rowOpAux c_rowOpD | ||
303 | 307 | ||
304 | 308 | ||
305 | instance Element (Complex Float) where | 309 | instance Element (Complex Float) where |
@@ -311,6 +315,7 @@ instance Element (Complex Float) where | |||
311 | compareV = undefined | 315 | compareV = undefined |
312 | selectV = selectQ | 316 | selectV = selectQ |
313 | remapM = remapQ | 317 | remapM = remapQ |
318 | rowOp = rowOpAux c_rowOpQ | ||
314 | 319 | ||
315 | 320 | ||
316 | instance Element (Complex Double) where | 321 | instance Element (Complex Double) where |
@@ -322,6 +327,7 @@ instance Element (Complex Double) where | |||
322 | compareV = undefined | 327 | compareV = undefined |
323 | selectV = selectC | 328 | selectV = selectC |
324 | remapM = remapC | 329 | remapM = remapC |
330 | rowOp = rowOpAux c_rowOpC | ||
325 | 331 | ||
326 | instance Element (CInt) where | 332 | instance Element (CInt) where |
327 | transdata = transdataAux ctransI | 333 | transdata = transdataAux ctransI |
@@ -332,6 +338,7 @@ instance Element (CInt) where | |||
332 | compareV = compareI | 338 | compareV = compareI |
333 | selectV = selectI | 339 | selectV = selectI |
334 | remapM = remapI | 340 | remapM = remapI |
341 | rowOp = rowOpAux c_rowOpI | ||
335 | 342 | ||
336 | instance Element Z where | 343 | instance Element Z where |
337 | transdata = transdataAux ctransL | 344 | transdata = transdataAux ctransL |
@@ -342,6 +349,7 @@ instance Element Z where | |||
342 | compareV = compareL | 349 | compareV = compareL |
343 | selectV = selectL | 350 | selectV = selectL |
344 | remapM = remapL | 351 | remapM = remapL |
352 | rowOp = rowOpAux c_rowOpL | ||
345 | 353 | ||
346 | ------------------------------------------------------------------- | 354 | ------------------------------------------------------------------- |
347 | 355 | ||
@@ -379,7 +387,7 @@ subMatrix :: Element a | |||
379 | -> Matrix a -- ^ result | 387 | -> Matrix a -- ^ result |
380 | subMatrix (r0,c0) (rt,ct) m | 388 | subMatrix (r0,c0) (rt,ct) m |
381 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | 389 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && |
382 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) | 390 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) |
383 | | otherwise = error $ "wrong subMatrix "++ | 391 | | otherwise = error $ "wrong subMatrix "++ |
384 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | 392 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) |
385 | 393 | ||
@@ -430,7 +438,7 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
430 | 438 | ||
431 | --------------------------------------------------------------- | 439 | --------------------------------------------------------------- |
432 | 440 | ||
433 | extractAux f m moder vr modec vc = unsafePerformIO $ do | 441 | extractAux f m moder vr modec vc = do |
434 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 442 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
435 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 443 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
436 | r <- createMatrix RowMajor nr nc | 444 | r <- createMatrix RowMajor nr nc |
@@ -538,6 +546,24 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
538 | 546 | ||
539 | -------------------------------------------------------------------------------- | 547 | -------------------------------------------------------------------------------- |
540 | 548 | ||
549 | rowOpAux f c x i1 i2 j1 j2 m = do | ||
550 | px <- newArray [x] | ||
551 | app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" | ||
552 | free px | ||
553 | |||
554 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | ||
555 | |||
556 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | ||
557 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | ||
558 | foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C | ||
559 | foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) | ||
560 | foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I | ||
561 | foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z | ||
562 | foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I | ||
563 | foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | ||
564 | |||
565 | -------------------------------------------------------------------------------- | ||
566 | |||
541 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | 567 | foreign import ccall unsafe "saveMatrix" c_saveMatrix |
542 | :: CString -> CString -> Double ..> Ok | 568 | :: CString -> CString -> Double ..> Ok |
543 | 569 | ||