From 7163e8027574d2a02e1f852a84d9252c51ade573 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 18 Jun 2014 07:41:31 +0200 Subject: to/from ByteString --- packages/base/src/Data/Packed/Matrix.hs | 16 ++++------ packages/base/src/Data/Packed/Vector.hs | 40 +++++++++++++++++++++--- packages/base/src/Numeric/LinearAlgebra/Devel.hs | 7 +++-- 3 files changed, 46 insertions(+), 17 deletions(-) (limited to 'packages/base/src') diff --git a/packages/base/src/Data/Packed/Matrix.hs b/packages/base/src/Data/Packed/Matrix.hs index 2420c94..6445ce4 100644 --- a/packages/base/src/Data/Packed/Matrix.hs +++ b/packages/base/src/Data/Packed/Matrix.hs @@ -1,6 +1,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE CPP #-} @@ -53,20 +54,15 @@ import Control.Monad(liftM) #ifdef BINARY import Data.Binary -import Control.Monad(replicateM) -instance (Binary a, Element a, Storable a) => Binary (Matrix a) where +instance (Binary (Vector a), Element a) => Binary (Matrix a) where put m = do - let r = rows m - let c = cols m - put r - put c - mapM_ (\i -> mapM_ (\j -> put $ m @@> (i,j)) [0..(c-1)]) [0..(r-1)] + put (cols m) + put (flatten m) get = do - r <- get c <- get - xs <- replicateM r $ replicateM c get - return $ fromLists xs + v <- get + return (reshape c v) #endif diff --git a/packages/base/src/Data/Packed/Vector.hs b/packages/base/src/Data/Packed/Vector.hs index 31dcf47..2104f52 100644 --- a/packages/base/src/Data/Packed/Vector.hs +++ b/packages/base/src/Data/Packed/Vector.hs @@ -22,7 +22,8 @@ module Data.Packed.Vector ( subVector, takesV, vjoin, join, mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith, mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, - foldLoop, foldVector, foldVectorG, foldVectorWithIndex + foldLoop, foldVector, foldVectorG, foldVectorWithIndex, + toByteString, fromByteString ) where import Data.Packed.Internal.Vector @@ -35,6 +36,12 @@ import Foreign.Storable import Data.Binary import Control.Monad(replicateM) +import Data.ByteString.Internal as BS +import Foreign.ForeignPtr(castForeignPtr) +import Data.Vector.Storable.Internal(updPtr) +import Foreign.Ptr(plusPtr) + + -- a 64K cache, with a Double taking 13 bytes in Bytestring, -- implies a chunk size of 5041 chunk :: Int @@ -43,28 +50,51 @@ chunk = 5000 chunks :: Int -> [Int] chunks d = let c = d `div` chunk m = d `mod` chunk - in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) + in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk) -putVector v = do - let d = dim v - mapM_ (\i -> put $ v @> i) [0..(d-1)] +putVector v = mapM_ put $! toList v getVector d = do xs <- replicateM d get return $! fromList xs +-------------------------------------------------------------------------------- + +toByteString :: Storable t => Vector t -> ByteString +toByteString v = BS.PS (castForeignPtr fp) (sz*o) (sz * dim v) + where + (fp,o,_n) = unsafeToForeignPtr v + sz = sizeOf (v@>0) + + +fromByteString :: Storable t => ByteString -> Vector t +fromByteString (BS.PS fp o n) = r + where + r = unsafeFromForeignPtr (castForeignPtr (updPtr (`plusPtr` o) fp)) 0 n' + n' = n `div` sz + sz = sizeOf (r@>0) + +-------------------------------------------------------------------------------- + instance (Binary a, Storable a) => Binary (Vector a) where + put v = do let d = dim v put d mapM_ putVector $! takesV (chunks d) v + + -- put = put . v2bs + get = do d <- get vs <- mapM getVector $ chunks d return $! vjoin vs + -- get = fmap bs2v get + #endif + ------------------------------------------------------------------- {- | creates a Vector of the specified length using the supplied function to diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index fce8b71..55894e0 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs @@ -49,9 +49,12 @@ module Numeric.LinearAlgebra.Devel( mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, liftMatrix, liftMatrix2, liftMatrix2Auto, - -- * Misc + -- * Sparse representation CSR(..), fromCSR, mkCSR, - GMatrix(..) + GMatrix(..), + + -- * Misc + toByteString, fromByteString ) where -- cgit v1.2.3