summaryrefslogtreecommitdiff
path: root/packages/hmatrix/src/Data/Packed/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/hmatrix/src/Data/Packed/ST.hs')
-rw-r--r--packages/hmatrix/src/Data/Packed/ST.hs179
1 files changed, 0 insertions, 179 deletions
diff --git a/packages/hmatrix/src/Data/Packed/ST.hs b/packages/hmatrix/src/Data/Packed/ST.hs
deleted file mode 100644
index 1cef296..0000000
--- a/packages/hmatrix/src/Data/Packed/ST.hs
+++ /dev/null
@@ -1,179 +0,0 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeOperators #-}
3{-# LANGUAGE Rank2Types #-}
4{-# LANGUAGE BangPatterns #-}
5-----------------------------------------------------------------------------
6-- |
7-- Module : Data.Packed.ST
8-- Copyright : (c) Alberto Ruiz 2008
9-- License : GPL-style
10--
11-- Maintainer : Alberto Ruiz <aruiz@um.es>
12-- Stability : provisional
13-- Portability : portable
14--
15-- In-place manipulation inside the ST monad.
16-- See examples/inplace.hs in the distribution.
17--
18-----------------------------------------------------------------------------
19{-# OPTIONS_HADDOCK hide #-}
20
21module Data.Packed.ST (
22 -- * Mutable Vectors
23 STVector, newVector, thawVector, freezeVector, runSTVector,
24 readVector, writeVector, modifyVector, liftSTVector,
25 -- * Mutable Matrices
26 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
27 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
28 -- * Unsafe functions
29 newUndefinedVector,
30 unsafeReadVector, unsafeWriteVector,
31 unsafeThawVector, unsafeFreezeVector,
32 newUndefinedMatrix,
33 unsafeReadMatrix, unsafeWriteMatrix,
34 unsafeThawMatrix, unsafeFreezeMatrix
35) where
36
37import Data.Packed.Internal
38
39import Control.Monad.ST(ST, runST)
40import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
41
42#if MIN_VERSION_base(4,4,0)
43import Control.Monad.ST.Unsafe(unsafeIOToST)
44#else
45import Control.Monad.ST(unsafeIOToST)
46#endif
47
48{-# INLINE ioReadV #-}
49ioReadV :: Storable t => Vector t -> Int -> IO t
50ioReadV v k = unsafeWith v $ \s -> peekElemOff s k
51
52{-# INLINE ioWriteV #-}
53ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
54ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x
55
56newtype STVector s t = STVector (Vector t)
57
58thawVector :: Storable t => Vector t -> ST s (STVector s t)
59thawVector = unsafeIOToST . fmap STVector . cloneVector
60
61unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
62unsafeThawVector = unsafeIOToST . return . STVector
63
64runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
65runSTVector st = runST (st >>= unsafeFreezeVector)
66
67{-# INLINE unsafeReadVector #-}
68unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
69unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x
70
71{-# INLINE unsafeWriteVector #-}
72unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
73unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
74
75{-# INLINE modifyVector #-}
76modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
77modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
78
79liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
80liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
81
82freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
83freezeVector v = liftSTVector id v
84
85unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
86unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
87
88{-# INLINE safeIndexV #-}
89safeIndexV f (STVector v) k
90 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
91 ++show (dim v)++", pos="++show k++")"
92 | otherwise = f (STVector v) k
93
94{-# INLINE readVector #-}
95readVector :: Storable t => STVector s t -> Int -> ST s t
96readVector = safeIndexV unsafeReadVector
97
98{-# INLINE writeVector #-}
99writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
100writeVector = safeIndexV unsafeWriteVector
101
102newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
103newUndefinedVector = unsafeIOToST . fmap STVector . createVector
104
105{-# INLINE newVector #-}
106newVector :: Storable t => t -> Int -> ST s (STVector s t)
107newVector x n = do
108 v <- newUndefinedVector n
109 let go (-1) = return v
110 go !k = unsafeWriteVector v k x >> go (k-1 :: Int)
111 go (n-1)
112
113-------------------------------------------------------------------------
114
115{-# INLINE ioReadM #-}
116ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
117ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c)
118ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r)
119
120{-# INLINE ioWriteM #-}
121ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
122ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val
123ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val
124
125newtype STMatrix s t = STMatrix (Matrix t)
126
127thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
128thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
129
130unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
131unsafeThawMatrix = unsafeIOToST . return . STMatrix
132
133runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
134runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
135
136{-# INLINE unsafeReadMatrix #-}
137unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
138unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
139
140{-# INLINE unsafeWriteMatrix #-}
141unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
142unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
143
144{-# INLINE modifyMatrix #-}
145modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
146modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
147
148liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
149liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
150
151unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
152unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
153
154freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
155freezeMatrix m = liftSTMatrix id m
156
157cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
158
159{-# INLINE safeIndexM #-}
160safeIndexM f (STMatrix m) r c
161 | r<0 || r>=rows m ||
162 c<0 || c>=cols m = error $ "out of range error in matrix (size="
163 ++show (rows m,cols m)++", pos="++show (r,c)++")"
164 | otherwise = f (STMatrix m) r c
165
166{-# INLINE readMatrix #-}
167readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
168readMatrix = safeIndexM unsafeReadMatrix
169
170{-# INLINE writeMatrix #-}
171writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
172writeMatrix = safeIndexM unsafeWriteMatrix
173
174newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
175newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
176
177{-# NOINLINE newMatrix #-}
178newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
179newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)