diff options
Diffstat (limited to 'packages/base/src/Internal/Vector.hs')
-rw-r--r-- | packages/base/src/Internal/Vector.hs | 447 |
1 files changed, 447 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs new file mode 100644 index 0000000..27ee13c --- /dev/null +++ b/packages/base/src/Internal/Vector.hs | |||
@@ -0,0 +1,447 @@ | |||
1 | {-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-} | ||
2 | {-# LANGUAGE TypeSynonymInstances #-} | ||
3 | |||
4 | |||
5 | -- | | ||
6 | -- Module : Internal.Vector | ||
7 | -- Copyright : (c) Alberto Ruiz 2007-15 | ||
8 | -- License : BSD3 | ||
9 | -- Maintainer : Alberto Ruiz | ||
10 | -- Stability : provisional | ||
11 | -- | ||
12 | |||
13 | module Internal.Vector where | ||
14 | |||
15 | import Internal.Tools | ||
16 | import Foreign.Marshal.Array ( peekArray, copyArray, advancePtr ) | ||
17 | import Foreign.ForeignPtr ( ForeignPtr, castForeignPtr ) | ||
18 | import Foreign.Ptr ( Ptr ) | ||
19 | import Foreign.Storable | ||
20 | ( Storable, peekElemOff, pokeElemOff, sizeOf ) | ||
21 | import Foreign.C.Types ( CInt ) | ||
22 | import Data.Complex ( Complex ) | ||
23 | import System.IO.Unsafe ( unsafePerformIO ) | ||
24 | import GHC.ForeignPtr ( mallocPlainForeignPtrBytes ) | ||
25 | import GHC.Base ( realWorld#, IO(IO), when ) | ||
26 | import qualified Data.Vector.Storable as Vector | ||
27 | ( Vector, slice, length ) | ||
28 | import Data.Vector.Storable | ||
29 | ( fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith ) | ||
30 | |||
31 | |||
32 | #ifdef BINARY | ||
33 | |||
34 | import Data.Binary | ||
35 | import Control.Monad(replicateM) | ||
36 | import qualified Data.ByteString.Internal as BS | ||
37 | import Data.Vector.Storable.Internal(updPtr) | ||
38 | import Foreign.Ptr(plusPtr) | ||
39 | |||
40 | #endif | ||
41 | |||
42 | |||
43 | |||
44 | type Vector = Vector.Vector | ||
45 | |||
46 | -- | Number of elements | ||
47 | dim :: (Storable t) => Vector t -> Int | ||
48 | dim = Vector.length | ||
49 | |||
50 | |||
51 | -- C-Haskell vector adapter | ||
52 | -- vec :: Adapt (CInt -> Ptr t -> r) (Vector t) r | ||
53 | vec :: (Storable t) => Vector t -> (((CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b | ||
54 | vec x f = unsafeWith x $ \p -> do | ||
55 | let v g = do | ||
56 | g (fi $ dim x) p | ||
57 | f v | ||
58 | {-# INLINE vec #-} | ||
59 | |||
60 | |||
61 | -- allocates memory for a new vector | ||
62 | createVector :: Storable a => Int -> IO (Vector a) | ||
63 | createVector n = do | ||
64 | when (n < 0) $ error ("trying to createVector of negative dim: "++show n) | ||
65 | fp <- doMalloc undefined | ||
66 | return $ unsafeFromForeignPtr fp 0 n | ||
67 | where | ||
68 | -- | ||
69 | -- Use the much cheaper Haskell heap allocated storage | ||
70 | -- for foreign pointer space we control | ||
71 | -- | ||
72 | doMalloc :: Storable b => b -> IO (ForeignPtr b) | ||
73 | doMalloc dummy = do | ||
74 | mallocPlainForeignPtrBytes (n * sizeOf dummy) | ||
75 | |||
76 | {- | creates a Vector from a list: | ||
77 | |||
78 | @> fromList [2,3,5,7] | ||
79 | 4 |> [2.0,3.0,5.0,7.0]@ | ||
80 | |||
81 | -} | ||
82 | |||
83 | safeRead v = inlinePerformIO . unsafeWith v | ||
84 | {-# INLINE safeRead #-} | ||
85 | |||
86 | inlinePerformIO :: IO a -> a | ||
87 | inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r | ||
88 | {-# INLINE inlinePerformIO #-} | ||
89 | |||
90 | {- extracts the Vector elements to a list | ||
91 | |||
92 | >>> toList (linspace 5 (1,10)) | ||
93 | [1.0,3.25,5.5,7.75,10.0] | ||
94 | |||
95 | -} | ||
96 | toList :: Storable a => Vector a -> [a] | ||
97 | toList v = safeRead v $ peekArray (dim v) | ||
98 | |||
99 | {- | Create a vector from a list of elements and explicit dimension. The input | ||
100 | list is truncated if it is too long, so it may safely | ||
101 | be used, for instance, with infinite lists. | ||
102 | |||
103 | >>> 5 |> [1..] | ||
104 | fromList [1.0,2.0,3.0,4.0,5.0] | ||
105 | |||
106 | -} | ||
107 | (|>) :: (Storable a) => Int -> [a] -> Vector a | ||
108 | infixl 9 |> | ||
109 | n |> l | ||
110 | | length l' == n = fromList l' | ||
111 | | otherwise = error "list too short for |>" | ||
112 | where | ||
113 | l' = take n l | ||
114 | |||
115 | |||
116 | -- | Create a vector of indexes, useful for matrix extraction using '??' | ||
117 | idxs :: [Int] -> Vector I | ||
118 | idxs js = fromList (map fromIntegral js) :: Vector I | ||
119 | |||
120 | {- | takes a number of consecutive elements from a Vector | ||
121 | |||
122 | >>> subVector 2 3 (fromList [1..10]) | ||
123 | fromList [3.0,4.0,5.0] | ||
124 | |||
125 | -} | ||
126 | subVector :: Storable t => Int -- ^ index of the starting element | ||
127 | -> Int -- ^ number of elements to extract | ||
128 | -> Vector t -- ^ source | ||
129 | -> Vector t -- ^ result | ||
130 | subVector = Vector.slice | ||
131 | |||
132 | |||
133 | |||
134 | |||
135 | {- | Reads a vector position: | ||
136 | |||
137 | >>> fromList [0..9] @> 7 | ||
138 | 7.0 | ||
139 | |||
140 | -} | ||
141 | (@>) :: Storable t => Vector t -> Int -> t | ||
142 | infixl 9 @> | ||
143 | v @> n | ||
144 | | n >= 0 && n < dim v = at' v n | ||
145 | | otherwise = error "vector index out of range" | ||
146 | {-# INLINE (@>) #-} | ||
147 | |||
148 | -- | access to Vector elements without range checking | ||
149 | at' :: Storable a => Vector a -> Int -> a | ||
150 | at' v n = safeRead v $ flip peekElemOff n | ||
151 | {-# INLINE at' #-} | ||
152 | |||
153 | {- | concatenate a list of vectors | ||
154 | |||
155 | >>> vjoin [fromList [1..5::Double], konst 1 3] | ||
156 | fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0] | ||
157 | |||
158 | -} | ||
159 | vjoin :: Storable t => [Vector t] -> Vector t | ||
160 | vjoin [] = fromList [] | ||
161 | vjoin [v] = v | ||
162 | vjoin as = unsafePerformIO $ do | ||
163 | let tot = sum (map dim as) | ||
164 | r <- createVector tot | ||
165 | unsafeWith r $ \ptr -> | ||
166 | joiner as tot ptr | ||
167 | return r | ||
168 | where joiner [] _ _ = return () | ||
169 | joiner (v:cs) _ p = do | ||
170 | let n = dim v | ||
171 | unsafeWith v $ \pb -> copyArray p pb n | ||
172 | joiner cs 0 (advancePtr p n) | ||
173 | |||
174 | |||
175 | {- | Extract consecutive subvectors of the given sizes. | ||
176 | |||
177 | >>> takesV [3,4] (linspace 10 (1,10::Double)) | ||
178 | [fromList [1.0,2.0,3.0],fromList [4.0,5.0,6.0,7.0]] | ||
179 | |||
180 | -} | ||
181 | takesV :: Storable t => [Int] -> Vector t -> [Vector t] | ||
182 | takesV ms w | sum ms > dim w = error $ "takesV " ++ show ms ++ " on dim = " ++ (show $ dim w) | ||
183 | | otherwise = go ms w | ||
184 | where go [] _ = [] | ||
185 | go (n:ns) v = subVector 0 n v | ||
186 | : go ns (subVector n (dim v - n) v) | ||
187 | |||
188 | --------------------------------------------------------------- | ||
189 | |||
190 | -- | transforms a complex vector into a real vector with alternating real and imaginary parts | ||
191 | asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a | ||
192 | asReal v = unsafeFromForeignPtr (castForeignPtr fp) (2*i) (2*n) | ||
193 | where (fp,i,n) = unsafeToForeignPtr v | ||
194 | |||
195 | -- | transforms a real vector into a complex vector with alternating real and imaginary parts | ||
196 | asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a) | ||
197 | asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2) | ||
198 | where (fp,i,n) = unsafeToForeignPtr v | ||
199 | |||
200 | -------------------------------------------------------------------------------- | ||
201 | |||
202 | |||
203 | -- | map on Vectors | ||
204 | mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b | ||
205 | mapVector f v = unsafePerformIO $ do | ||
206 | w <- createVector (dim v) | ||
207 | unsafeWith v $ \p -> | ||
208 | unsafeWith w $ \q -> do | ||
209 | let go (-1) = return () | ||
210 | go !k = do x <- peekElemOff p k | ||
211 | pokeElemOff q k (f x) | ||
212 | go (k-1) | ||
213 | go (dim v -1) | ||
214 | return w | ||
215 | {-# INLINE mapVector #-} | ||
216 | |||
217 | -- | zipWith for Vectors | ||
218 | zipVectorWith :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c | ||
219 | zipVectorWith f u v = unsafePerformIO $ do | ||
220 | let n = min (dim u) (dim v) | ||
221 | w <- createVector n | ||
222 | unsafeWith u $ \pu -> | ||
223 | unsafeWith v $ \pv -> | ||
224 | unsafeWith w $ \pw -> do | ||
225 | let go (-1) = return () | ||
226 | go !k = do x <- peekElemOff pu k | ||
227 | y <- peekElemOff pv k | ||
228 | pokeElemOff pw k (f x y) | ||
229 | go (k-1) | ||
230 | go (n -1) | ||
231 | return w | ||
232 | {-# INLINE zipVectorWith #-} | ||
233 | |||
234 | -- | unzipWith for Vectors | ||
235 | unzipVectorWith :: (Storable (a,b), Storable c, Storable d) | ||
236 | => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d) | ||
237 | unzipVectorWith f u = unsafePerformIO $ do | ||
238 | let n = dim u | ||
239 | v <- createVector n | ||
240 | w <- createVector n | ||
241 | unsafeWith u $ \pu -> | ||
242 | unsafeWith v $ \pv -> | ||
243 | unsafeWith w $ \pw -> do | ||
244 | let go (-1) = return () | ||
245 | go !k = do z <- peekElemOff pu k | ||
246 | let (x,y) = f z | ||
247 | pokeElemOff pv k x | ||
248 | pokeElemOff pw k y | ||
249 | go (k-1) | ||
250 | go (n-1) | ||
251 | return (v,w) | ||
252 | {-# INLINE unzipVectorWith #-} | ||
253 | |||
254 | foldVector :: Storable a => (a -> b -> b) -> b -> Vector a -> b | ||
255 | foldVector f x v = unsafePerformIO $ | ||
256 | unsafeWith v $ \p -> do | ||
257 | let go (-1) s = return s | ||
258 | go !k !s = do y <- peekElemOff p k | ||
259 | go (k-1::Int) (f y s) | ||
260 | go (dim v -1) x | ||
261 | {-# INLINE foldVector #-} | ||
262 | |||
263 | -- the zero-indexed index is passed to the folding function | ||
264 | foldVectorWithIndex :: Storable a => (Int -> a -> b -> b) -> b -> Vector a -> b | ||
265 | foldVectorWithIndex f x v = unsafePerformIO $ | ||
266 | unsafeWith v $ \p -> do | ||
267 | let go (-1) s = return s | ||
268 | go !k !s = do y <- peekElemOff p k | ||
269 | go (k-1::Int) (f k y s) | ||
270 | go (dim v -1) x | ||
271 | {-# INLINE foldVectorWithIndex #-} | ||
272 | |||
273 | foldLoop f s0 d = go (d - 1) s0 | ||
274 | where | ||
275 | go 0 s = f (0::Int) s | ||
276 | go !j !s = go (j - 1) (f j s) | ||
277 | |||
278 | foldVectorG f s0 v = foldLoop g s0 (dim v) | ||
279 | where g !k !s = f k (safeRead v . flip peekElemOff) s | ||
280 | {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479) | ||
281 | {-# INLINE foldVectorG #-} | ||
282 | |||
283 | ------------------------------------------------------------------- | ||
284 | |||
285 | -- | monadic map over Vectors | ||
286 | -- the monad @m@ must be strict | ||
287 | mapVectorM :: (Storable a, Storable b, Monad m) => (a -> m b) -> Vector a -> m (Vector b) | ||
288 | mapVectorM f v = do | ||
289 | w <- return $! unsafePerformIO $! createVector (dim v) | ||
290 | mapVectorM' w 0 (dim v -1) | ||
291 | return w | ||
292 | where mapVectorM' w' !k !t | ||
293 | | k == t = do | ||
294 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
295 | y <- f x | ||
296 | return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y | ||
297 | | otherwise = do | ||
298 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
299 | y <- f x | ||
300 | _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y | ||
301 | mapVectorM' w' (k+1) t | ||
302 | {-# INLINE mapVectorM #-} | ||
303 | |||
304 | -- | monadic map over Vectors | ||
305 | mapVectorM_ :: (Storable a, Monad m) => (a -> m ()) -> Vector a -> m () | ||
306 | mapVectorM_ f v = do | ||
307 | mapVectorM' 0 (dim v -1) | ||
308 | where mapVectorM' !k !t | ||
309 | | k == t = do | ||
310 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
311 | f x | ||
312 | | otherwise = do | ||
313 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
314 | _ <- f x | ||
315 | mapVectorM' (k+1) t | ||
316 | {-# INLINE mapVectorM_ #-} | ||
317 | |||
318 | -- | monadic map over Vectors with the zero-indexed index passed to the mapping function | ||
319 | -- the monad @m@ must be strict | ||
320 | mapVectorWithIndexM :: (Storable a, Storable b, Monad m) => (Int -> a -> m b) -> Vector a -> m (Vector b) | ||
321 | mapVectorWithIndexM f v = do | ||
322 | w <- return $! unsafePerformIO $! createVector (dim v) | ||
323 | mapVectorM' w 0 (dim v -1) | ||
324 | return w | ||
325 | where mapVectorM' w' !k !t | ||
326 | | k == t = do | ||
327 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
328 | y <- f k x | ||
329 | return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y | ||
330 | | otherwise = do | ||
331 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
332 | y <- f k x | ||
333 | _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y | ||
334 | mapVectorM' w' (k+1) t | ||
335 | {-# INLINE mapVectorWithIndexM #-} | ||
336 | |||
337 | -- | monadic map over Vectors with the zero-indexed index passed to the mapping function | ||
338 | mapVectorWithIndexM_ :: (Storable a, Monad m) => (Int -> a -> m ()) -> Vector a -> m () | ||
339 | mapVectorWithIndexM_ f v = do | ||
340 | mapVectorM' 0 (dim v -1) | ||
341 | where mapVectorM' !k !t | ||
342 | | k == t = do | ||
343 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
344 | f k x | ||
345 | | otherwise = do | ||
346 | x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k | ||
347 | _ <- f k x | ||
348 | mapVectorM' (k+1) t | ||
349 | {-# INLINE mapVectorWithIndexM_ #-} | ||
350 | |||
351 | |||
352 | mapVectorWithIndex :: (Storable a, Storable b) => (Int -> a -> b) -> Vector a -> Vector b | ||
353 | --mapVectorWithIndex g = head . mapVectorWithIndexM (\a b -> [g a b]) | ||
354 | mapVectorWithIndex f v = unsafePerformIO $ do | ||
355 | w <- createVector (dim v) | ||
356 | unsafeWith v $ \p -> | ||
357 | unsafeWith w $ \q -> do | ||
358 | let go (-1) = return () | ||
359 | go !k = do x <- peekElemOff p k | ||
360 | pokeElemOff q k (f k x) | ||
361 | go (k-1) | ||
362 | go (dim v -1) | ||
363 | return w | ||
364 | {-# INLINE mapVectorWithIndex #-} | ||
365 | |||
366 | -------------------------------------------------------------------------------- | ||
367 | |||
368 | |||
369 | #ifdef BINARY | ||
370 | |||
371 | -- a 64K cache, with a Double taking 13 bytes in Bytestring, | ||
372 | -- implies a chunk size of 5041 | ||
373 | chunk :: Int | ||
374 | chunk = 5000 | ||
375 | |||
376 | chunks :: Int -> [Int] | ||
377 | chunks d = let c = d `div` chunk | ||
378 | m = d `mod` chunk | ||
379 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) | ||
380 | |||
381 | putVector v = mapM_ put $! toList v | ||
382 | |||
383 | getVector d = do | ||
384 | xs <- replicateM d get | ||
385 | return $! fromList xs | ||
386 | |||
387 | -------------------------------------------------------------------------------- | ||
388 | |||
389 | toByteString :: Storable t => Vector t -> BS.ByteString | ||
390 | toByteString v = BS.PS (castForeignPtr fp) (sz*o) (sz * dim v) | ||
391 | where | ||
392 | (fp,o,_n) = unsafeToForeignPtr v | ||
393 | sz = sizeOf (v@>0) | ||
394 | |||
395 | |||
396 | fromByteString :: Storable t => BS.ByteString -> Vector t | ||
397 | fromByteString (BS.PS fp o n) = r | ||
398 | where | ||
399 | r = unsafeFromForeignPtr (castForeignPtr (updPtr (`plusPtr` o) fp)) 0 n' | ||
400 | n' = n `div` sz | ||
401 | sz = sizeOf (r@>0) | ||
402 | |||
403 | -------------------------------------------------------------------------------- | ||
404 | |||
405 | instance (Binary a, Storable a) => Binary (Vector a) where | ||
406 | |||
407 | put v = do | ||
408 | let d = dim v | ||
409 | put d | ||
410 | mapM_ putVector $! takesV (chunks d) v | ||
411 | |||
412 | -- put = put . v2bs | ||
413 | |||
414 | get = do | ||
415 | d <- get | ||
416 | vs <- mapM getVector $ chunks d | ||
417 | return $! vjoin vs | ||
418 | |||
419 | -- get = fmap bs2v get | ||
420 | |||
421 | #endif | ||
422 | |||
423 | |||
424 | ------------------------------------------------------------------- | ||
425 | |||
426 | {- | creates a Vector of the specified length using the supplied function to | ||
427 | to map the index to the value at that index. | ||
428 | |||
429 | @> buildVector 4 fromIntegral | ||
430 | 4 |> [0.0,1.0,2.0,3.0]@ | ||
431 | |||
432 | -} | ||
433 | buildVector :: Storable a => Int -> (Int -> a) -> Vector a | ||
434 | buildVector len f = | ||
435 | fromList $ map f [0 .. (len - 1)] | ||
436 | |||
437 | |||
438 | -- | zip for Vectors | ||
439 | zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b) | ||
440 | zipVector = zipVectorWith (,) | ||
441 | |||
442 | -- | unzip for Vectors | ||
443 | unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b) | ||
444 | unzipVector = unzipVectorWith id | ||
445 | |||
446 | ------------------------------------------------------------------- | ||
447 | |||