summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2016-10-07 20:42:10 +0200
committerGitHub <noreply@github.com>2016-10-07 20:42:10 +0200
commit58205ccd5bd4daa0e0098fcd43fde9b82765151f (patch)
treeb95f05bc88eb6b811d1e77fbde9ae8ddb1ac9aa0 /packages/base/src/Internal/Matrix.hs
parent2f773c0148a1a50b84226f69852997d53b0653fb (diff)
parent59cb364ebd7bff09a19f5f83104752a14f6a5177 (diff)
Merge pull request #199 from exFalso/fix-use-after-free
Redefine (#), fixes #198
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs35
1 files changed, 19 insertions, 16 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index c47c625..0135288 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -22,7 +22,7 @@ module Internal.Matrix where
22 22
23import Internal.Vector 23import Internal.Vector
24import Internal.Devel 24import Internal.Devel
25import Internal.Vectorized hiding ((#)) 25import Internal.Vectorized hiding ((#), (#!))
26import Foreign.Marshal.Alloc ( free ) 26import Foreign.Marshal.Alloc ( free )
27import Foreign.Marshal.Array(newArray) 27import Foreign.Marshal.Array(newArray)
28import Foreign.Ptr ( Ptr ) 28import Foreign.Ptr ( Ptr )
@@ -110,15 +110,15 @@ fmat m
110 110
111-- C-Haskell matrix adapters 111-- C-Haskell matrix adapters
112{-# INLINE amatr #-} 112{-# INLINE amatr #-}
113amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b 113amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
114amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) 114amatr x f g = unsafeWith (xdat x) (f . g r c)
115 where 115 where
116 r = fi (rows x) 116 r = fi (rows x)
117 c = fi (cols x) 117 c = fi (cols x)
118 118
119{-# INLINE amat #-} 119{-# INLINE amat #-}
120amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b 120amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
121amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) 121amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
122 where 122 where
123 r = fi (rows x) 123 r = fi (rows x)
124 c = fi (cols x) 124 c = fi (cols x)
@@ -135,10 +135,13 @@ instance Storable t => TransArray (Matrix t)
135 applyRaw = amatr 135 applyRaw = amatr
136 {-# INLINE applyRaw #-} 136 {-# INLINE applyRaw #-}
137 137
138infixl 1 # 138infixr 1 #
139a # b = apply a b 139a # b = apply a b
140{-# INLINE (#) #-} 140{-# INLINE (#) #-}
141 141
142a #! b = a # b # id
143{-# INLINE (#!) #-}
144
142-------------------------------------------------------------------------------- 145--------------------------------------------------------------------------------
143 146
144copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 147copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
@@ -426,7 +429,8 @@ extractAux f ord m moder vr modec vc = do
426 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 429 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
427 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 430 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
428 r <- createMatrix ord nr nc 431 r <- createMatrix ord nr nc
429 f moder modec # vr # vc # m # r #|"extract" 432 (vr # vc # m #! r) (f moder modec) #|"extract"
433
430 return r 434 return r
431 435
432type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 436type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
@@ -440,7 +444,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
440 444
441--------------------------------------------------------------- 445---------------------------------------------------------------
442 446
443setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" 447setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
444 448
445type SetRect x = I -> I -> x ::> x::> Ok 449type SetRect x = I -> I -> x ::> x::> Ok
446 450
@@ -455,7 +459,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
455 459
456sortG f v = unsafePerformIO $ do 460sortG f v = unsafePerformIO $ do
457 r <- createVector (dim v) 461 r <- createVector (dim v)
458 f # v # r #|"sortG" 462 (v #! r) f #|"sortG"
459 return r 463 return r
460 464
461sortIdxD = sortG c_sort_indexD 465sortIdxD = sortG c_sort_indexD
@@ -482,7 +486,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
482 486
483compareG f u v = unsafePerformIO $ do 487compareG f u v = unsafePerformIO $ do
484 r <- createVector (dim v) 488 r <- createVector (dim v)
485 f # u # v # r #|"compareG" 489 (u # v #! r) f #|"compareG"
486 return r 490 return r
487 491
488compareD = compareG c_compareD 492compareD = compareG c_compareD
@@ -499,7 +503,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
499 503
500selectG f c u v w = unsafePerformIO $ do 504selectG f c u v w = unsafePerformIO $ do
501 r <- createVector (dim v) 505 r <- createVector (dim v)
502 f # c # u # v # w # r #|"selectG" 506 (c # u # v # w #! r) f #|"selectG"
503 return r 507 return r
504 508
505selectD = selectG c_selectD 509selectD = selectG c_selectD
@@ -522,7 +526,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
522 526
523remapG f i j m = unsafePerformIO $ do 527remapG f i j m = unsafePerformIO $ do
524 r <- createMatrix RowMajor (rows i) (cols i) 528 r <- createMatrix RowMajor (rows i) (cols i)
525 f # i # j # m # r #|"remapG" 529 (i # j # m #! r) f #|"remapG"
526 return r 530 return r
527 531
528remapD = remapG c_remapD 532remapD = remapG c_remapD
@@ -545,7 +549,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
545 549
546rowOpAux f c x i1 i2 j1 j2 m = do 550rowOpAux f c x i1 i2 j1 j2 m = do
547 px <- newArray [x] 551 px <- newArray [x]
548 f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" 552 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
549 free px 553 free px
550 554
551type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok 555type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
@@ -561,7 +565,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
561 565
562-------------------------------------------------------------------------------- 566--------------------------------------------------------------------------------
563 567
564gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" 568gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
565 569
566type Tgemm x = x :> x ::> x ::> x ::> Ok 570type Tgemm x = x :> x ::> x ::> x ::> Ok
567 571
@@ -589,10 +593,9 @@ saveMatrix
589saveMatrix name format m = do 593saveMatrix name format m = do
590 cname <- newCString name 594 cname <- newCString name
591 cformat <- newCString format 595 cformat <- newCString format
592 c_saveMatrix cname cformat # m #|"saveMatrix" 596 (m # id) (c_saveMatrix cname cformat) #|"saveMatrix"
593 free cname 597 free cname
594 free cformat 598 free cformat
595 return () 599 return ()
596 600
597-------------------------------------------------------------------------------- 601--------------------------------------------------------------------------------
598