diff options
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/ST.hs | 97 |
1 files changed, 78 insertions, 19 deletions
diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs index 3d94014..1311ff9 100644 --- a/lib/Data/Packed/ST.hs +++ b/lib/Data/Packed/ST.hs | |||
@@ -16,10 +16,17 @@ | |||
16 | ----------------------------------------------------------------------------- | 16 | ----------------------------------------------------------------------------- |
17 | 17 | ||
18 | module Data.Packed.ST ( | 18 | module Data.Packed.ST ( |
19 | STVector, thawVector, freezeVector, runSTVector, | 19 | -- * Mutable Vectors |
20 | STVector, newVector, thawVector, freezeVector, runSTVector, | ||
20 | readVector, writeVector, modifyVector, liftSTVector, | 21 | readVector, writeVector, modifyVector, liftSTVector, |
21 | STMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 22 | -- * Mutable Matrices |
22 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix | 23 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
24 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | ||
25 | -- * Unsafe functions | ||
26 | unsafeReadVector, unsafeWriteVector, | ||
27 | unsafeThawVector, unsafeFreezeVector, | ||
28 | unsafeReadMatrix, unsafeWriteMatrix, | ||
29 | unsafeThawMatrix, unsafeFreezeMatrix | ||
23 | ) where | 30 | ) where |
24 | 31 | ||
25 | import Data.Packed.Internal | 32 | import Data.Packed.Internal |
@@ -28,44 +35,71 @@ import Control.Monad.ST | |||
28 | import Data.Array.ST | 35 | import Data.Array.ST |
29 | import Foreign | 36 | import Foreign |
30 | 37 | ||
31 | 38 | {-# INLINE ioReadV #-} | |
32 | ioReadV :: Storable t => Vector t -> Int -> IO t | 39 | ioReadV :: Storable t => Vector t -> Int -> IO t |
33 | ioReadV v k = withForeignPtr (fptr v) $ \s -> peekElemOff s k | 40 | ioReadV v k = withForeignPtr (fptr v) $ \s -> peekElemOff s k |
34 | 41 | ||
42 | {-# INLINE ioWriteV #-} | ||
35 | ioWriteV :: Storable t => Vector t -> Int -> t -> IO () | 43 | ioWriteV :: Storable t => Vector t -> Int -> t -> IO () |
36 | ioWriteV v k x = withForeignPtr (fptr v) $ \s -> pokeElemOff s k x | 44 | ioWriteV v k x = withForeignPtr (fptr v) $ \s -> pokeElemOff s k x |
37 | 45 | ||
38 | newtype STVector s t = Mut (Vector t) | 46 | newtype STVector s t = STVector (Vector t) |
39 | 47 | ||
40 | thawVector :: Storable t => Vector t -> ST s (STVector s t) | 48 | thawVector :: Storable t => Vector t -> ST s (STVector s t) |
41 | thawVector = unsafeIOToST . fmap Mut . cloneVector | 49 | thawVector = unsafeIOToST . fmap STVector . cloneVector |
42 | 50 | ||
43 | unsafeFreezeVector (Mut x) = unsafeIOToST . return $ x | 51 | unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t) |
52 | unsafeThawVector = unsafeIOToST . return . STVector | ||
44 | 53 | ||
45 | runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t | 54 | runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t |
46 | runSTVector st = runST (st >>= unsafeFreezeVector) | 55 | runSTVector st = runST (st >>= unsafeFreezeVector) |
47 | 56 | ||
48 | readVector :: Storable t => STVector s t -> Int -> ST s t | 57 | {-# INLINE unsafeReadVector #-} |
49 | readVector (Mut x) = unsafeIOToST . ioReadV x | 58 | unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t |
59 | unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x | ||
50 | 60 | ||
51 | writeVector :: Storable t => STVector s t -> Int -> t -> ST s () | 61 | {-# INLINE unsafeWriteVector #-} |
52 | writeVector (Mut x) k = unsafeIOToST . ioWriteV x k | 62 | unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s () |
63 | unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k | ||
53 | 64 | ||
65 | {-# INLINE modifyVector #-} | ||
54 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () | 66 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () |
55 | modifyVector x k f = readVector x k >>= return . f >>= writeVector x k | 67 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k |
56 | 68 | ||
57 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a | 69 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a |
58 | liftSTVector f (Mut x) = unsafeIOToST . fmap f . cloneVector $ x | 70 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x |
59 | 71 | ||
60 | freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | 72 | freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) |
61 | freezeVector v = liftSTVector id v | 73 | freezeVector v = liftSTVector id v |
62 | 74 | ||
75 | unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | ||
76 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | ||
77 | |||
78 | {-# INLINE safeIndexV #-} | ||
79 | safeIndexV f (STVector v) k | ||
80 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" | ||
81 | ++show (dim v)++", pos="++show k++")" | ||
82 | | otherwise = f (STVector v) k | ||
83 | |||
84 | {-# INLINE readVector #-} | ||
85 | readVector :: Storable t => STVector s t -> Int -> ST s t | ||
86 | readVector = safeIndexV unsafeReadVector | ||
87 | |||
88 | {-# INLINE writeVector #-} | ||
89 | writeVector :: Storable t => STVector s t -> Int -> t -> ST s () | ||
90 | writeVector = safeIndexV unsafeWriteVector | ||
91 | |||
92 | newVector :: Element t => t -> Int -> ST s (STVector s t) | ||
93 | newVector v = unsafeThawVector . constant v | ||
94 | |||
63 | ------------------------------------------------------------------------- | 95 | ------------------------------------------------------------------------- |
64 | 96 | ||
97 | {-# INLINE ioReadM #-} | ||
65 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t | 98 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t |
66 | ioReadM (MC nr nc cv) r c = ioReadV cv (r*nc+c) | 99 | ioReadM (MC nr nc cv) r c = ioReadV cv (r*nc+c) |
67 | ioReadM (MF nr nc fv) r c = ioReadV fv (c*nr+r) | 100 | ioReadM (MF nr nc fv) r c = ioReadV fv (c*nr+r) |
68 | 101 | ||
102 | {-# INLINE ioWriteM #-} | ||
69 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () | 103 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () |
70 | ioWriteM (MC nr nc cv) r c val = ioWriteV cv (r*nc+c) val | 104 | ioWriteM (MC nr nc cv) r c val = ioWriteV cv (r*nc+c) val |
71 | ioWriteM (MF nr nc fv) r c val = ioWriteV fv (c*nr+r) val | 105 | ioWriteM (MF nr nc fv) r c val = ioWriteV fv (c*nr+r) val |
@@ -75,25 +109,50 @@ newtype STMatrix s t = STMatrix (Matrix t) | |||
75 | thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | 109 | thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) |
76 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix | 110 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix |
77 | 111 | ||
78 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 112 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) |
113 | unsafeThawMatrix = unsafeIOToST . return . STMatrix | ||
79 | 114 | ||
80 | runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t | 115 | runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t |
81 | runSTMatrix st = runST (st >>= unsafeFreezeMatrix) | 116 | runSTMatrix st = runST (st >>= unsafeFreezeMatrix) |
82 | 117 | ||
83 | readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t | 118 | {-# INLINE unsafeReadMatrix #-} |
84 | readMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r | 119 | unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t |
120 | unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r | ||
85 | 121 | ||
86 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | 122 | {-# INLINE unsafeWriteMatrix #-} |
87 | writeMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | 123 | unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () |
124 | unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | ||
88 | 125 | ||
126 | {-# INLINE modifyMatrix #-} | ||
89 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () | 127 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () |
90 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= writeMatrix x r c | 128 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c |
91 | 129 | ||
92 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a | 130 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a |
93 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | 131 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x |
94 | 132 | ||
133 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | ||
134 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | ||
135 | |||
95 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 136 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) |
96 | freezeMatrix m = liftSTMatrix id m | 137 | freezeMatrix m = liftSTMatrix id m |
97 | 138 | ||
98 | cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c | 139 | cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c |
99 | cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c | 140 | cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c |
141 | |||
142 | {-# INLINE safeIndexM #-} | ||
143 | safeIndexM f (STMatrix m) r c | ||
144 | | r<0 || r>=rows m || | ||
145 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" | ||
146 | ++show (rows m,cols m)++", pos="++show (r,c)++")" | ||
147 | | otherwise = f (STMatrix m) r c | ||
148 | |||
149 | {-# INLINE readMatrix #-} | ||
150 | readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t | ||
151 | readMatrix = safeIndexM unsafeReadMatrix | ||
152 | |||
153 | {-# INLINE writeMatrix #-} | ||
154 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | ||
155 | writeMatrix = safeIndexM unsafeWriteMatrix | ||
156 | |||
157 | newMatrix :: Element t => t -> Int -> Int -> ST s (STMatrix s t) | ||
158 | newMatrix v r c = unsafeThawMatrix . reshape c . constant v $ r*c | ||