diff options
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r-- | packages/base/src/Data/Packed/Matrix.hs | 16 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Vector.hs | 40 |
2 files changed, 41 insertions, 15 deletions
diff --git a/packages/base/src/Data/Packed/Matrix.hs b/packages/base/src/Data/Packed/Matrix.hs index 2420c94..6445ce4 100644 --- a/packages/base/src/Data/Packed/Matrix.hs +++ b/packages/base/src/Data/Packed/Matrix.hs | |||
@@ -1,6 +1,7 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | 1 | {-# LANGUAGE TypeFamilies #-} |
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE UndecidableInstances #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
5 | {-# LANGUAGE CPP #-} | 6 | {-# LANGUAGE CPP #-} |
6 | 7 | ||
@@ -53,20 +54,15 @@ import Control.Monad(liftM) | |||
53 | #ifdef BINARY | 54 | #ifdef BINARY |
54 | 55 | ||
55 | import Data.Binary | 56 | import Data.Binary |
56 | import Control.Monad(replicateM) | ||
57 | 57 | ||
58 | instance (Binary a, Element a, Storable a) => Binary (Matrix a) where | 58 | instance (Binary (Vector a), Element a) => Binary (Matrix a) where |
59 | put m = do | 59 | put m = do |
60 | let r = rows m | 60 | put (cols m) |
61 | let c = cols m | 61 | put (flatten m) |
62 | put r | ||
63 | put c | ||
64 | mapM_ (\i -> mapM_ (\j -> put $ m @@> (i,j)) [0..(c-1)]) [0..(r-1)] | ||
65 | get = do | 62 | get = do |
66 | r <- get | ||
67 | c <- get | 63 | c <- get |
68 | xs <- replicateM r $ replicateM c get | 64 | v <- get |
69 | return $ fromLists xs | 65 | return (reshape c v) |
70 | 66 | ||
71 | #endif | 67 | #endif |
72 | 68 | ||
diff --git a/packages/base/src/Data/Packed/Vector.hs b/packages/base/src/Data/Packed/Vector.hs index 31dcf47..2104f52 100644 --- a/packages/base/src/Data/Packed/Vector.hs +++ b/packages/base/src/Data/Packed/Vector.hs | |||
@@ -22,7 +22,8 @@ module Data.Packed.Vector ( | |||
22 | subVector, takesV, vjoin, join, | 22 | subVector, takesV, vjoin, join, |
23 | mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith, | 23 | mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith, |
24 | mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, | 24 | mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, |
25 | foldLoop, foldVector, foldVectorG, foldVectorWithIndex | 25 | foldLoop, foldVector, foldVectorG, foldVectorWithIndex, |
26 | toByteString, fromByteString | ||
26 | ) where | 27 | ) where |
27 | 28 | ||
28 | import Data.Packed.Internal.Vector | 29 | import Data.Packed.Internal.Vector |
@@ -35,6 +36,12 @@ import Foreign.Storable | |||
35 | import Data.Binary | 36 | import Data.Binary |
36 | import Control.Monad(replicateM) | 37 | import Control.Monad(replicateM) |
37 | 38 | ||
39 | import Data.ByteString.Internal as BS | ||
40 | import Foreign.ForeignPtr(castForeignPtr) | ||
41 | import Data.Vector.Storable.Internal(updPtr) | ||
42 | import Foreign.Ptr(plusPtr) | ||
43 | |||
44 | |||
38 | -- a 64K cache, with a Double taking 13 bytes in Bytestring, | 45 | -- a 64K cache, with a Double taking 13 bytes in Bytestring, |
39 | -- implies a chunk size of 5041 | 46 | -- implies a chunk size of 5041 |
40 | chunk :: Int | 47 | chunk :: Int |
@@ -43,28 +50,51 @@ chunk = 5000 | |||
43 | chunks :: Int -> [Int] | 50 | chunks :: Int -> [Int] |
44 | chunks d = let c = d `div` chunk | 51 | chunks d = let c = d `div` chunk |
45 | m = d `mod` chunk | 52 | m = d `mod` chunk |
46 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) | 53 | in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) |
47 | 54 | ||
48 | putVector v = do | 55 | putVector v = mapM_ put $! toList v |
49 | let d = dim v | ||
50 | mapM_ (\i -> put $ v @> i) [0..(d-1)] | ||
51 | 56 | ||
52 | getVector d = do | 57 | getVector d = do |
53 | xs <- replicateM d get | 58 | xs <- replicateM d get |
54 | return $! fromList xs | 59 | return $! fromList xs |
55 | 60 | ||
61 | -------------------------------------------------------------------------------- | ||
62 | |||
63 | toByteString :: Storable t => Vector t -> ByteString | ||
64 | toByteString v = BS.PS (castForeignPtr fp) (sz*o) (sz * dim v) | ||
65 | where | ||
66 | (fp,o,_n) = unsafeToForeignPtr v | ||
67 | sz = sizeOf (v@>0) | ||
68 | |||
69 | |||
70 | fromByteString :: Storable t => ByteString -> Vector t | ||
71 | fromByteString (BS.PS fp o n) = r | ||
72 | where | ||
73 | r = unsafeFromForeignPtr (castForeignPtr (updPtr (`plusPtr` o) fp)) 0 n' | ||
74 | n' = n `div` sz | ||
75 | sz = sizeOf (r@>0) | ||
76 | |||
77 | -------------------------------------------------------------------------------- | ||
78 | |||
56 | instance (Binary a, Storable a) => Binary (Vector a) where | 79 | instance (Binary a, Storable a) => Binary (Vector a) where |
80 | |||
57 | put v = do | 81 | put v = do |
58 | let d = dim v | 82 | let d = dim v |
59 | put d | 83 | put d |
60 | mapM_ putVector $! takesV (chunks d) v | 84 | mapM_ putVector $! takesV (chunks d) v |
85 | |||
86 | -- put = put . v2bs | ||
87 | |||
61 | get = do | 88 | get = do |
62 | d <- get | 89 | d <- get |
63 | vs <- mapM getVector $ chunks d | 90 | vs <- mapM getVector $ chunks d |
64 | return $! vjoin vs | 91 | return $! vjoin vs |
65 | 92 | ||
93 | -- get = fmap bs2v get | ||
94 | |||
66 | #endif | 95 | #endif |
67 | 96 | ||
97 | |||
68 | ------------------------------------------------------------------- | 98 | ------------------------------------------------------------------- |
69 | 99 | ||
70 | {- | creates a Vector of the specified length using the supplied function to | 100 | {- | creates a Vector of the specified length using the supplied function to |