summaryrefslogtreecommitdiff
path: root/lib/Data/Packed
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2008-06-06 15:41:48 +0000
committerAlberto Ruiz <aruiz@um.es>2008-06-06 15:41:48 +0000
commita407c0e101f8f6db44fcf731ebb8460d0f691196 (patch)
tree8bfefb55548279305d2005ef8e5fbfba23b426cd /lib/Data/Packed
parentfa7d2f17cbba1de2e900432e07bf4e1e7da2caab (diff)
range checking and other additions to the mutable interface
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r--lib/Data/Packed/ST.hs97
1 files changed, 78 insertions, 19 deletions
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs
index 3d94014..1311ff9 100644
--- a/lib/Data/Packed/ST.hs
+++ b/lib/Data/Packed/ST.hs
@@ -16,10 +16,17 @@
16----------------------------------------------------------------------------- 16-----------------------------------------------------------------------------
17 17
18module Data.Packed.ST ( 18module Data.Packed.ST (
19 STVector, thawVector, freezeVector, runSTVector, 19 -- * Mutable Vectors
20 STVector, newVector, thawVector, freezeVector, runSTVector,
20 readVector, writeVector, modifyVector, liftSTVector, 21 readVector, writeVector, modifyVector, liftSTVector,
21 STMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 -- * Mutable Matrices
22 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix 23 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
24 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
25 -- * Unsafe functions
26 unsafeReadVector, unsafeWriteVector,
27 unsafeThawVector, unsafeFreezeVector,
28 unsafeReadMatrix, unsafeWriteMatrix,
29 unsafeThawMatrix, unsafeFreezeMatrix
23) where 30) where
24 31
25import Data.Packed.Internal 32import Data.Packed.Internal
@@ -28,44 +35,71 @@ import Control.Monad.ST
28import Data.Array.ST 35import Data.Array.ST
29import Foreign 36import Foreign
30 37
31 38{-# INLINE ioReadV #-}
32ioReadV :: Storable t => Vector t -> Int -> IO t 39ioReadV :: Storable t => Vector t -> Int -> IO t
33ioReadV v k = withForeignPtr (fptr v) $ \s -> peekElemOff s k 40ioReadV v k = withForeignPtr (fptr v) $ \s -> peekElemOff s k
34 41
42{-# INLINE ioWriteV #-}
35ioWriteV :: Storable t => Vector t -> Int -> t -> IO () 43ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
36ioWriteV v k x = withForeignPtr (fptr v) $ \s -> pokeElemOff s k x 44ioWriteV v k x = withForeignPtr (fptr v) $ \s -> pokeElemOff s k x
37 45
38newtype STVector s t = Mut (Vector t) 46newtype STVector s t = STVector (Vector t)
39 47
40thawVector :: Storable t => Vector t -> ST s (STVector s t) 48thawVector :: Storable t => Vector t -> ST s (STVector s t)
41thawVector = unsafeIOToST . fmap Mut . cloneVector 49thawVector = unsafeIOToST . fmap STVector . cloneVector
42 50
43unsafeFreezeVector (Mut x) = unsafeIOToST . return $ x 51unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
52unsafeThawVector = unsafeIOToST . return . STVector
44 53
45runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t 54runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
46runSTVector st = runST (st >>= unsafeFreezeVector) 55runSTVector st = runST (st >>= unsafeFreezeVector)
47 56
48readVector :: Storable t => STVector s t -> Int -> ST s t 57{-# INLINE unsafeReadVector #-}
49readVector (Mut x) = unsafeIOToST . ioReadV x 58unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
59unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x
50 60
51writeVector :: Storable t => STVector s t -> Int -> t -> ST s () 61{-# INLINE unsafeWriteVector #-}
52writeVector (Mut x) k = unsafeIOToST . ioWriteV x k 62unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
63unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
53 64
65{-# INLINE modifyVector #-}
54modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () 66modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
55modifyVector x k f = readVector x k >>= return . f >>= writeVector x k 67modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
56 68
57liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a 69liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
58liftSTVector f (Mut x) = unsafeIOToST . fmap f . cloneVector $ x 70liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
59 71
60freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 72freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
61freezeVector v = liftSTVector id v 73freezeVector v = liftSTVector id v
62 74
75unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
76unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
77
78{-# INLINE safeIndexV #-}
79safeIndexV f (STVector v) k
80 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
81 ++show (dim v)++", pos="++show k++")"
82 | otherwise = f (STVector v) k
83
84{-# INLINE readVector #-}
85readVector :: Storable t => STVector s t -> Int -> ST s t
86readVector = safeIndexV unsafeReadVector
87
88{-# INLINE writeVector #-}
89writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
90writeVector = safeIndexV unsafeWriteVector
91
92newVector :: Element t => t -> Int -> ST s (STVector s t)
93newVector v = unsafeThawVector . constant v
94
63------------------------------------------------------------------------- 95-------------------------------------------------------------------------
64 96
97{-# INLINE ioReadM #-}
65ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t 98ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
66ioReadM (MC nr nc cv) r c = ioReadV cv (r*nc+c) 99ioReadM (MC nr nc cv) r c = ioReadV cv (r*nc+c)
67ioReadM (MF nr nc fv) r c = ioReadV fv (c*nr+r) 100ioReadM (MF nr nc fv) r c = ioReadV fv (c*nr+r)
68 101
102{-# INLINE ioWriteM #-}
69ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () 103ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
70ioWriteM (MC nr nc cv) r c val = ioWriteV cv (r*nc+c) val 104ioWriteM (MC nr nc cv) r c val = ioWriteV cv (r*nc+c) val
71ioWriteM (MF nr nc fv) r c val = ioWriteV fv (c*nr+r) val 105ioWriteM (MF nr nc fv) r c val = ioWriteV fv (c*nr+r) val
@@ -75,25 +109,50 @@ newtype STMatrix s t = STMatrix (Matrix t)
75thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) 109thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
76thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix 110thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
77 111
78unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 112unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
113unsafeThawMatrix = unsafeIOToST . return . STMatrix
79 114
80runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t 115runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
81runSTMatrix st = runST (st >>= unsafeFreezeMatrix) 116runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
82 117
83readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t 118{-# INLINE unsafeReadMatrix #-}
84readMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r 119unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
120unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
85 121
86writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () 122{-# INLINE unsafeWriteMatrix #-}
87writeMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c 123unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
124unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
88 125
126{-# INLINE modifyMatrix #-}
89modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 127modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
90modifyMatrix x r c f = readMatrix x r c >>= return . f >>= writeMatrix x r c 128modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
91 129
92liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a 130liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
93liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 131liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
94 132
133unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
134unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
135
95freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 136freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
96freezeMatrix m = liftSTMatrix id m 137freezeMatrix m = liftSTMatrix id m
97 138
98cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c 139cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c
99cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c 140cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c
141
142{-# INLINE safeIndexM #-}
143safeIndexM f (STMatrix m) r c
144 | r<0 || r>=rows m ||
145 c<0 || c>=cols m = error $ "out of range error in matrix (size="
146 ++show (rows m,cols m)++", pos="++show (r,c)++")"
147 | otherwise = f (STMatrix m) r c
148
149{-# INLINE readMatrix #-}
150readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
151readMatrix = safeIndexM unsafeReadMatrix
152
153{-# INLINE writeMatrix #-}
154writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
155writeMatrix = safeIndexM unsafeWriteMatrix
156
157newMatrix :: Element t => t -> Int -> Int -> ST s (STMatrix s t)
158newMatrix v r c = unsafeThawMatrix . reshape c . constant v $ r*c