summaryrefslogtreecommitdiff
path: root/lib/Data
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data')
-rw-r--r--lib/Data/Packed/ST.hs99
1 files changed, 99 insertions, 0 deletions
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs
new file mode 100644
index 0000000..3d94014
--- /dev/null
+++ b/lib/Data/Packed/ST.hs
@@ -0,0 +1,99 @@
1{-# OPTIONS -XTypeOperators -XRank2Types -XFlexibleContexts #-}
2
3-----------------------------------------------------------------------------
4-- |
5-- Module : Data.Packed.ST
6-- Copyright : (c) Alberto Ruiz 2008
7-- License : GPL-style
8--
9-- Maintainer : Alberto Ruiz <aruiz@um.es>
10-- Stability : provisional
11-- Portability : portable
12--
13-- In-place manipulation inside the ST monad.
14-- See examples/inplace.hs in the distribution.
15--
16-----------------------------------------------------------------------------
17
18module Data.Packed.ST (
19 STVector, thawVector, freezeVector, runSTVector,
20 readVector, writeVector, modifyVector, liftSTVector,
21 STMatrix, thawMatrix, freezeMatrix, runSTMatrix,
22 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix
23) where
24
25import Data.Packed.Internal
26import Data.Array.Storable
27import Control.Monad.ST
28import Data.Array.ST
29import Foreign
30
31
32ioReadV :: Storable t => Vector t -> Int -> IO t
33ioReadV v k = withForeignPtr (fptr v) $ \s -> peekElemOff s k
34
35ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
36ioWriteV v k x = withForeignPtr (fptr v) $ \s -> pokeElemOff s k x
37
38newtype STVector s t = Mut (Vector t)
39
40thawVector :: Storable t => Vector t -> ST s (STVector s t)
41thawVector = unsafeIOToST . fmap Mut . cloneVector
42
43unsafeFreezeVector (Mut x) = unsafeIOToST . return $ x
44
45runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
46runSTVector st = runST (st >>= unsafeFreezeVector)
47
48readVector :: Storable t => STVector s t -> Int -> ST s t
49readVector (Mut x) = unsafeIOToST . ioReadV x
50
51writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
52writeVector (Mut x) k = unsafeIOToST . ioWriteV x k
53
54modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
55modifyVector x k f = readVector x k >>= return . f >>= writeVector x k
56
57liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
58liftSTVector f (Mut x) = unsafeIOToST . fmap f . cloneVector $ x
59
60freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
61freezeVector v = liftSTVector id v
62
63-------------------------------------------------------------------------
64
65ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
66ioReadM (MC nr nc cv) r c = ioReadV cv (r*nc+c)
67ioReadM (MF nr nc fv) r c = ioReadV fv (c*nr+r)
68
69ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
70ioWriteM (MC nr nc cv) r c val = ioWriteV cv (r*nc+c) val
71ioWriteM (MF nr nc fv) r c val = ioWriteV fv (c*nr+r) val
72
73newtype STMatrix s t = STMatrix (Matrix t)
74
75thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
76thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
77
78unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
79
80runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
81runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
82
83readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
84readMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
85
86writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
87writeMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
88
89modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
90modifyMatrix x r c f = readMatrix x r c >>= return . f >>= writeMatrix x r c
91
92liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
93liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
94
95freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
96freezeMatrix m = liftSTMatrix id m
97
98cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c
99cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c