diff options
author | Alberto Ruiz <aruiz@um.es> | 2016-10-07 20:42:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-07 20:42:10 +0200 |
commit | 58205ccd5bd4daa0e0098fcd43fde9b82765151f (patch) | |
tree | b95f05bc88eb6b811d1e77fbde9ae8ddb1ac9aa0 /packages/base/src/Internal/Matrix.hs | |
parent | 2f773c0148a1a50b84226f69852997d53b0653fb (diff) | |
parent | 59cb364ebd7bff09a19f5f83104752a14f6a5177 (diff) |
Merge pull request #199 from exFalso/fix-use-after-free
Redefine (#), fixes #198
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 35 |
1 files changed, 19 insertions, 16 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index c47c625..0135288 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -22,7 +22,7 @@ module Internal.Matrix where | |||
22 | 22 | ||
23 | import Internal.Vector | 23 | import Internal.Vector |
24 | import Internal.Devel | 24 | import Internal.Devel |
25 | import Internal.Vectorized hiding ((#)) | 25 | import Internal.Vectorized hiding ((#), (#!)) |
26 | import Foreign.Marshal.Alloc ( free ) | 26 | import Foreign.Marshal.Alloc ( free ) |
27 | import Foreign.Marshal.Array(newArray) | 27 | import Foreign.Marshal.Array(newArray) |
28 | import Foreign.Ptr ( Ptr ) | 28 | import Foreign.Ptr ( Ptr ) |
@@ -110,15 +110,15 @@ fmat m | |||
110 | 110 | ||
111 | -- C-Haskell matrix adapters | 111 | -- C-Haskell matrix adapters |
112 | {-# INLINE amatr #-} | 112 | {-# INLINE amatr #-} |
113 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | 113 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r |
114 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | 114 | amatr x f g = unsafeWith (xdat x) (f . g r c) |
115 | where | 115 | where |
116 | r = fi (rows x) | 116 | r = fi (rows x) |
117 | c = fi (cols x) | 117 | c = fi (cols x) |
118 | 118 | ||
119 | {-# INLINE amat #-} | 119 | {-# INLINE amat #-} |
120 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | 120 | amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r |
121 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | 121 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) |
122 | where | 122 | where |
123 | r = fi (rows x) | 123 | r = fi (rows x) |
124 | c = fi (cols x) | 124 | c = fi (cols x) |
@@ -135,10 +135,13 @@ instance Storable t => TransArray (Matrix t) | |||
135 | applyRaw = amatr | 135 | applyRaw = amatr |
136 | {-# INLINE applyRaw #-} | 136 | {-# INLINE applyRaw #-} |
137 | 137 | ||
138 | infixl 1 # | 138 | infixr 1 # |
139 | a # b = apply a b | 139 | a # b = apply a b |
140 | {-# INLINE (#) #-} | 140 | {-# INLINE (#) #-} |
141 | 141 | ||
142 | a #! b = a # b # id | ||
143 | {-# INLINE (#!) #-} | ||
144 | |||
142 | -------------------------------------------------------------------------------- | 145 | -------------------------------------------------------------------------------- |
143 | 146 | ||
144 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 147 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
@@ -426,7 +429,8 @@ extractAux f ord m moder vr modec vc = do | |||
426 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 429 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
427 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 430 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
428 | r <- createMatrix ord nr nc | 431 | r <- createMatrix ord nr nc |
429 | f moder modec # vr # vc # m # r #|"extract" | 432 | (vr # vc # m #! r) (f moder modec) #|"extract" |
433 | |||
430 | return r | 434 | return r |
431 | 435 | ||
432 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | 436 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) |
@@ -440,7 +444,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
440 | 444 | ||
441 | --------------------------------------------------------------- | 445 | --------------------------------------------------------------- |
442 | 446 | ||
443 | setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" | 447 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" |
444 | 448 | ||
445 | type SetRect x = I -> I -> x ::> x::> Ok | 449 | type SetRect x = I -> I -> x ::> x::> Ok |
446 | 450 | ||
@@ -455,7 +459,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
455 | 459 | ||
456 | sortG f v = unsafePerformIO $ do | 460 | sortG f v = unsafePerformIO $ do |
457 | r <- createVector (dim v) | 461 | r <- createVector (dim v) |
458 | f # v # r #|"sortG" | 462 | (v #! r) f #|"sortG" |
459 | return r | 463 | return r |
460 | 464 | ||
461 | sortIdxD = sortG c_sort_indexD | 465 | sortIdxD = sortG c_sort_indexD |
@@ -482,7 +486,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | |||
482 | 486 | ||
483 | compareG f u v = unsafePerformIO $ do | 487 | compareG f u v = unsafePerformIO $ do |
484 | r <- createVector (dim v) | 488 | r <- createVector (dim v) |
485 | f # u # v # r #|"compareG" | 489 | (u # v #! r) f #|"compareG" |
486 | return r | 490 | return r |
487 | 491 | ||
488 | compareD = compareG c_compareD | 492 | compareD = compareG c_compareD |
@@ -499,7 +503,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | |||
499 | 503 | ||
500 | selectG f c u v w = unsafePerformIO $ do | 504 | selectG f c u v w = unsafePerformIO $ do |
501 | r <- createVector (dim v) | 505 | r <- createVector (dim v) |
502 | f # c # u # v # w # r #|"selectG" | 506 | (c # u # v # w #! r) f #|"selectG" |
503 | return r | 507 | return r |
504 | 508 | ||
505 | selectD = selectG c_selectD | 509 | selectD = selectG c_selectD |
@@ -522,7 +526,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
522 | 526 | ||
523 | remapG f i j m = unsafePerformIO $ do | 527 | remapG f i j m = unsafePerformIO $ do |
524 | r <- createMatrix RowMajor (rows i) (cols i) | 528 | r <- createMatrix RowMajor (rows i) (cols i) |
525 | f # i # j # m # r #|"remapG" | 529 | (i # j # m #! r) f #|"remapG" |
526 | return r | 530 | return r |
527 | 531 | ||
528 | remapD = remapG c_remapD | 532 | remapD = remapG c_remapD |
@@ -545,7 +549,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
545 | 549 | ||
546 | rowOpAux f c x i1 i2 j1 j2 m = do | 550 | rowOpAux f c x i1 i2 j1 j2 m = do |
547 | px <- newArray [x] | 551 | px <- newArray [x] |
548 | f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" | 552 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" |
549 | free px | 553 | free px |
550 | 554 | ||
551 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | 555 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok |
@@ -561,7 +565,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
561 | 565 | ||
562 | -------------------------------------------------------------------------------- | 566 | -------------------------------------------------------------------------------- |
563 | 567 | ||
564 | gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" | 568 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" |
565 | 569 | ||
566 | type Tgemm x = x :> x ::> x ::> x ::> Ok | 570 | type Tgemm x = x :> x ::> x ::> x ::> Ok |
567 | 571 | ||
@@ -589,10 +593,9 @@ saveMatrix | |||
589 | saveMatrix name format m = do | 593 | saveMatrix name format m = do |
590 | cname <- newCString name | 594 | cname <- newCString name |
591 | cformat <- newCString format | 595 | cformat <- newCString format |
592 | c_saveMatrix cname cformat # m #|"saveMatrix" | 596 | (m # id) (c_saveMatrix cname cformat) #|"saveMatrix" |
593 | free cname | 597 | free cname |
594 | free cformat | 598 | free cformat |
595 | return () | 599 | return () |
596 | 600 | ||
597 | -------------------------------------------------------------------------------- | 601 | -------------------------------------------------------------------------------- |
598 | |||