diff options
Diffstat (limited to 'Data/BitSyntax.hs')
-rw-r--r-- | Data/BitSyntax.hs | 429 |
1 files changed, 429 insertions, 0 deletions
diff --git a/Data/BitSyntax.hs b/Data/BitSyntax.hs new file mode 100644 index 00000000..dbb43f45 --- /dev/null +++ b/Data/BitSyntax.hs | |||
@@ -0,0 +1,429 @@ | |||
1 | {-# LANGUAGE TemplateHaskell #-} | ||
2 | {-# LANGUAGE ForeignFunctionInterface #-} | ||
3 | -- | This module contains functions and templates for building up and breaking | ||
4 | -- down packed bit structures. It's something like Erlang's bit-syntax (or, | ||
5 | -- actually, more like Python's struct module). | ||
6 | -- | ||
7 | -- This code uses Data.ByteString which is included in GHC 6.5 and you can | ||
8 | -- get it for 6.4 at <http://www.cse.unsw.edu.au/~dons/fps.html> | ||
9 | module Data.BitSyntax ( | ||
10 | -- * Building bit structures | ||
11 | -- | The core function here is makeBits, which is a perfectly normal function. | ||
12 | -- Here's an example which makes a SOCKS4a request header: | ||
13 | -- @ | ||
14 | -- makeBits [U8 4, U8 1, U16 80, U32 10, NullTerminated \"username\", | ||
15 | -- NullTerminated \"www.haskell.org\"] | ||
16 | -- @ | ||
17 | BitBlock(..), | ||
18 | makeBits, | ||
19 | -- * Breaking up bit structures | ||
20 | -- | The main function for this is bitSyn, which is a template function and | ||
21 | -- so you'll need to run with @-fth@ to enable template haskell | ||
22 | -- <http://www.haskell.org/th/>. | ||
23 | -- | ||
24 | -- To expand the function you use the splice command: | ||
25 | -- @ | ||
26 | -- $(bitSyn [...]) | ||
27 | -- @ | ||
28 | -- | ||
29 | -- The expanded function has type @ByteString -> (...)@ where the elements of | ||
30 | -- the tuple depend on the argument to bitSyn (that's why it has to be a template | ||
31 | -- function). | ||
32 | -- | ||
33 | -- Here's an example, translated from the Erlang manual, which parses an IP header: | ||
34 | -- | ||
35 | -- @ | ||
36 | -- decodeOptions bs ([_, hlen], _, _, _, _, _, _, _, _, _) | ||
37 | -- | hlen > 5 = return $ BS.splitAt (fromIntegral ((hlen - 5) * 4)) bs | ||
38 | -- | otherwise = return (BS.empty, bs) | ||
39 | -- @ | ||
40 | -- | ||
41 | -- @ | ||
42 | -- ipDecode = $(bitSyn [PackedBits [4, 4], Unsigned 1, Unsigned 2, Unsigned 2, | ||
43 | -- PackedBits [3, 13], Unsigned 1, Unsigned 1, Unsigned 2, | ||
44 | -- Fixed 4, Fixed 4, Context \'decodeOptions, Rest]) | ||
45 | -- @ | ||
46 | -- | ||
47 | -- @ | ||
48 | -- ipPacket = BS.pack [0x45, 0, 0, 0x34, 0xd8, 0xd2, 0x40, 0, 0x40, 0x06, | ||
49 | -- 0xa0, 0xca, 0xac, 0x12, 0x68, 0x4d, 0xac, 0x18, | ||
50 | -- 0x00, 0xaf] | ||
51 | -- @ | ||
52 | -- | ||
53 | -- This function has several weaknesses compared to the Erlang version: The | ||
54 | -- elements of the bit structure are not named in place, instead you have to | ||
55 | -- do a pattern match on the resulting tuple and match up the indexes. The | ||
56 | -- type system helps in this, but it's still not quite as nice. | ||
57 | |||
58 | ReadType(..), bitSyn, | ||
59 | |||
60 | -- I get errors if these aren't exported (Can't find interface-file | ||
61 | -- declaration for Data.BitSyntax.decodeU16) | ||
62 | decodeU8, decodeU16, decodeU32, decodeU16LE, decodeU32LE) where | ||
63 | |||
64 | import Language.Haskell.TH.Lib | ||
65 | import Language.Haskell.TH.Syntax | ||
66 | |||
67 | import qualified Data.ByteString as BS | ||
68 | import Data.Char (ord) | ||
69 | import Control.Monad | ||
70 | -- import Test.QuickCheck (Arbitrary(), arbitrary, Gen()) | ||
71 | |||
72 | import Foreign | ||
73 | |||
74 | foreign import ccall unsafe "htonl" htonl :: Word32 -> Word32 | ||
75 | foreign import ccall unsafe "htons" htons :: Word16 -> Word16 | ||
76 | |||
77 | -- There's no good way to convert to little-endian. The htons functions only | ||
78 | -- convert to big endian and they don't have any little endian friends. So we | ||
79 | -- need to detect which kind of system we are on and act accordingly. We can | ||
80 | -- detect the type of system by seeing if htonl actually does anything (it's | ||
81 | -- the identity function on big-endian systems, of course). If it doesn't we're | ||
82 | -- on a big-endian system and so need to do the byte-swapping in Haskell because | ||
83 | -- the C functions are no-ops | ||
84 | |||
-- | Reverse the byte order of a 32-bit word in pure Haskell, i.e. the same
-- shuffle that the C @htonl@ performs on a little-endian host.
endianSwitch32 :: Word32 -> Word32
endianSwitch32 w = lowest .|. lower .|. upper .|. highest
  where
    lowest  = (w .&. 0xff) `shiftL` 24      -- byte 0 moves to byte 3
    lower   = (w .&. 0xff00) `shiftL` 8     -- byte 1 moves to byte 2
    upper   = (w .&. 0xff0000) `shiftR` 8   -- byte 2 moves to byte 1
    highest = w `shiftR` 24                 -- byte 3 moves to byte 0
92 | |||
-- | Exchange the two bytes of a 16-bit word in pure Haskell, i.e. the same
-- shuffle that the C @htons@ performs on a little-endian host.
endianSwitch16 :: Word16 -> Word16
endianSwitch16 w = lowToHigh .|. highToLow
  where
    lowToHigh = (w .&. 0xff) `shiftL` 8
    highToLow = w `shiftR` 8
98 | |||
-- | Prepare a 32-bit value for little-endian (de)serialisation.
--
-- 'getBytes' and 'fromBytes' are arithmetic: they shift and mask the value,
-- always handling the least significant byte first. That is already
-- little-endian on every host, whatever its native byte order, so no
-- conversion is needed here. The previous implementation swapped the bytes
-- when @htonl 1 == 1@ (a big-endian host), which made little-endian fields
-- come out big-endian on those machines.
littleEndian32 :: Word32 -> Word32
littleEndian32 = id
103 | |||
-- | Prepare a 16-bit value for little-endian (de)serialisation.
--
-- 'getBytes' and 'fromBytes' are arithmetic: they shift and mask the value,
-- always handling the least significant byte first. That is already
-- little-endian on every host, whatever its native byte order, so no
-- conversion is needed here. The previous implementation swapped the bytes
-- when @htonl 1 == 1@ (a big-endian host), which made little-endian fields
-- come out big-endian on those machines.
littleEndian16 :: Word16 -> Word16
littleEndian16 = id
108 | |||
-- | One element of a packed structure to be built by 'makeBits'.
data BitBlock
  = -- | Unsigned 8-bit int
    U8 Int
  | -- | Unsigned 16-bit int
    U16 Int
  | -- | Unsigned 32-bit int
    U32 Int
  | -- | Little-endian, unsigned 16-bit int
    U16LE Int
  | -- | Little-endian, unsigned 32-bit int
    U32LE Int
  | -- | The string followed by a single NUL byte
    NullTerminated String
  | -- | The string on its own, with no terminator
    RawString String
  | -- | A ByteString, appended as it is
    RawByteString BS.ByteString
  | -- | A run of bit fields packed together. Each pair holds the width of
    -- the field in bits followed by its value. The widths within one
    -- PackBits must add up to a multiple of 8.
    PackBits [(Int, Int)]
  deriving (Show)
131 | |||
-- Serialises a fixed-width member of the Bits class into a ByteString, one
-- byte per 8 bits of the type, least significant byte first. Fails with an
-- error if the type's bit size is not a whole number of bytes.
getBytes :: (Integral a, Bounded a, Bits a) => a -> BS.ByteString
getBytes input
  | width `mod` 8 /= 0 = error "Input data bit size must be a multiple of 8"
  | otherwise = BS.pack (map lowByte shifted)
  where
    width = bitSize input
    -- input, input >> 8, input >> 16, ... one entry per output byte
    shifted = take (width `div` 8) (iterate (`shiftR` 8) input)
    lowByte x = fromIntegral (x .&. 0xff)
143 | |||
-- Performs the work behind PackBits: appends one bit field to the partially
-- built output, most significant bit first, spilling into fresh bytes when
-- the field does not fit in what is left of the current byte.
--
-- Only the low @size@ bits of the value are used. Without that mask a value
-- wider than its declared size would be OR-ed over bits that belong to
-- fields packed earlier into the same byte. A field whose size is zero or
-- negative contributes nothing and leaves the state untouched.
packBits :: (Word8, Int, [Word8]) -- ^ The current byte, the number of bits
                                  -- used in that byte and the (reverse)
                                  -- list of produced bytes
         -> (Int, Int) -- ^ The size (in bits) of the value, and the value
         -> (Word8, Int, [Word8]) -- See first argument
packBits (current, used, bytes) (size, value)
  | size <= 0 = (current, used, bytes)
    -- The field continues past this byte: flush it and pack the rest.
  | bitsWritten < size =
      packBits (0, 0, current' : bytes) (size - bitsWritten, value)
    -- The field ended exactly on a byte boundary.
  | used' == 8 = (0, 0, current' : bytes)
  | otherwise = (current', used', bytes)
  where
    -- Keep only the bits that belong to this field. For sizes at or beyond
    -- the width of Int the shift yields 0, so the mask is all ones.
    masked = value .&. ((1 `shiftL` size) - 1)
    top = size - 1
    topOfByte = 7 - used
    -- Line the field's top bit up with the first free bit of the byte; a
    -- negative amount shifts right and drops the bits that spill over.
    aligned = masked `shift` (topOfByte - top)
    newBits = fromIntegral aligned :: Word8
    current' = current .|. newBits
    bitsWritten = min (8 - used) size
    used' = used + bitsWritten
164 | |||
-- | Serialise a single 'BitBlock'.
--
-- 'getBytes' always emits the least significant byte first, whatever the
-- host's byte order. Big-endian (network order) fields are therefore
-- byte-swapped in Haskell before being handed to it, and little-endian
-- fields are passed through unchanged. The C @htons@ / @htonl@ are not
-- used here: they are no-ops on big-endian hosts and so would give the
-- wrong order there.
--
-- Integers are truncated to the width of the field. Each character of a
-- string contributes the low 8 bits of its code point.
--
-- Fails with an error when the sizes of a 'PackBits' do not add up to a
-- multiple of 8.
bits :: BitBlock -> BS.ByteString
bits (U8 v) = BS.pack [fromIntegral v :: Word8]
bits (U16 v) = getBytes (endianSwitch16 (fromIntegral v))
bits (U32 v) = getBytes (endianSwitch32 (fromIntegral v))
bits (U16LE v) = getBytes (fromIntegral v :: Word16)
bits (U32LE v) = getBytes (fromIntegral v :: Word32)
bits (NullTerminated str) = BS.pack $ map (fromIntegral . ord) str ++ [0]
bits (RawString str) = BS.pack $ map (fromIntegral . ord) str
bits (RawByteString bs) = bs
bits (PackBits bitspec)
  | sum (map fst bitspec) `mod` 8 /= 0 =
      error "Sum of sizes of a bit spec must == 0 mod 8"
  | otherwise = BS.pack (reverse packed)
  where
    -- packBits accumulates finished bytes in reverse order
    (_, _, packed) = foldl packBits (0, 0, []) bitspec
178 | |||
-- | Build a binary string by serialising each element in turn and joining
-- the results in order
makeBits :: [BitBlock] -> BS.ByteString
makeBits blocks = BS.concat [bits block | block <- blocks]
182 | |||
-- | One element of a structure description handed to 'bitSyn'.
data ReadType
  = -- | An unsigned number occupying the given number of bytes. Valid
    -- arguments are 1, 2 and 4
    Unsigned Integer
  | -- | An unsigned, little-endian integer occupying the given number of
    -- bytes. Valid arguments are 2 and 4
    UnsignedLE Integer
  | -- | A variable length element decoded by a custom function. The single
    -- argument names that function, which should have type
    -- @Monad m => ByteString -> m (v, ByteString)@
    Variable Name
  | -- | Skip some number of bytes
    Skip Integer
  | -- | A field of fixed size, yielding a ByteString of that length.
    Fixed Integer
  | -- | Decode a value and then discard it (it will not appear in the
    -- returned tuple)
    Ignore ReadType
  | -- | Like 'Variable', except that the decoding function also receives
    -- the whole result tuple built so far. The named function therefore
    -- has type
    -- @Monad m => ByteString -> (...) -> m (v, ByteString)@
    Context Name
  | -- | Uses the most recent element of the result tuple as the length of
    -- this field. Yields a ByteString
    LengthPrefixed
  | -- | Decode a run of bit fields into a list of Integers. Each element
    -- of the argument is the width of one bit field. The widths must add
    -- up to a multiple of 8
    PackedBits [Integer]
  | -- | Yields a ByteString holding the bytes not yet decoded. Usually
    -- placed last to return the trailing body of a structure, but it may
    -- appear anywhere to capture the remainder at that point.
    Rest
221 | |||
-- Combines a list of byte values into a single number, treating the first
-- element as the least significant byte.
fromBytes :: (Num a, Bits a) => [a] -> a
fromBytes = foldr (\byte higher -> (higher `shiftL` 8) .|. byte) 0
228 | |||
-- Decoders used by the code that 'bitSyn' generates.
--
-- 'fromBytes' reads its input arithmetically with the first byte as the
-- least significant, whatever the host's byte order. The big-endian
-- decoders therefore swap the result in Haskell, and the little-endian ones
-- use it as it is. The C @htons@ / @htonl@ are not used here: they are
-- no-ops on big-endian hosts and so would give the wrong order there.

-- | The first byte of the input. Fails on an empty ByteString.
decodeU8 :: BS.ByteString -> Word8
decodeU8 = head . BS.unpack

-- | Big-endian unsigned 16-bit value.
decodeU16 :: BS.ByteString -> Word16
decodeU16 = endianSwitch16 . fromBytes . map fromIntegral . BS.unpack

-- | Big-endian unsigned 32-bit value.
decodeU32 :: BS.ByteString -> Word32
decodeU32 = endianSwitch32 . fromBytes . map fromIntegral . BS.unpack

-- | Little-endian unsigned 16-bit value.
decodeU16LE :: BS.ByteString -> Word16
decodeU16LE = fromBytes . map fromIntegral . BS.unpack

-- | Little-endian unsigned 32-bit value.
decodeU32LE :: BS.ByteString -> Word32
decodeU32LE = fromBytes . map fromIntegral . BS.unpack
239 | |||
-- Splits the leading bytes of a ByteString into bit fields of the given
-- widths and returns the field values in the order the widths were given.
-- Only the first (sum of widths / 8) bytes are examined.
decodeBits :: [Integer] -> BS.ByteString -> [Integer]
decodeBits sizes bs = reverse decoded
  where
    byteCount = sum sizes `shiftR` 3
    payload = BS.unpack (BS.take (fromIntegral byteCount) bs)
    -- unpackBits conses each finished field onto the front of the list
    (decoded, _, _) = foldl unpackBits ([], 0, payload) sizes
247 | |||
-- Extracts one bit field. The state holds the fields decoded so far (most
-- recent first), the number of bits already consumed from the head byte,
-- and the remaining bytes; the head byte is kept shifted left so that its
-- next unread bit is always the top bit.
unpackBits :: ([Integer], Integer, [Word8]) -> Integer -> ([Integer], Integer, [Word8])
unpackBits = unpackBitsInner 0

-- Worker for unpackBits. The first argument carries the bits of the field
-- gathered from earlier bytes when the field straddles a byte boundary.
unpackBitsInner :: Integer ->
                   ([Integer], Integer, [Word8]) ->
                   Integer ->
                   ([Integer], Integer, [Word8])
-- Out of input: return the state unchanged, dropping any partial field.
unpackBitsInner _ (output, used, []) _ = (output, used, [])
unpackBitsInner gathered (output, used, byte : more) wanted
    -- More bits are needed than this byte holds: carry on with the next.
  | stillWanted > 0 = unpackBitsInner field (output, 0, more) stillWanted
    -- Field complete with bits left over in this byte.
  | consumed < 8 = (field : output, consumed, leftover : more)
    -- Field complete and the byte is exhausted.
  | otherwise = (field : output, 0, more)
  where
    taken = min (8 - used) wanted
    -- the top `taken` bits of the byte, moved down to the bottom
    topBits = byte `shiftR` fromIntegral (8 - taken)
    -- the byte with those bits shifted out
    leftover = byte `shiftL` fromIntegral taken
    field = (gathered `shift` fromIntegral taken) .|. fromIntegral topBits
    stillWanted = wanted - taken
    consumed = used + taken
271 | |||
-- | Generate the statements that decode one 'ReadType'.
--
-- The state is threaded through 'foldM' by 'bitSyn' and consists of:
--
--   * the statements generated so far, most recent first;
--
--   * the name bound to the input that has not been consumed yet;
--
--   * the names of the values that will make up the result tuple, most
--     recent first.
--
-- Fails at splice time when 'LengthPrefixed' has no preceding element to
-- take its length from.
readElement :: ([Stmt], Name, [Name]) -> ReadType -> Q ([Stmt], Name, [Name])

readElement (stmts, inputname, tuplenames) (Context funcname) = do
  -- Expands to something like:
  --   (val, rest) <- func input (a, b, ...)
  valname <- newName "val"
  restname <- newName "rest"

  let stmt = BindS (TupP [VarP valname, VarP restname])
                   (AppE (AppE (VarE funcname)
                               (VarE inputname))
                         (TupE $ map VarE $ reverse tuplenames))

  return (stmt : stmts, restname, valname : tuplenames)

readElement (stmts, inputname, tuplenames) (Fixed n) = do
  -- Expands to something like:
  --   (val, rest) = Data.ByteString.splitAt n input
  valname <- newName "val"
  restname <- newName "rest"
  let dec1 = ValD (TupP [VarP valname, VarP restname])
                  (NormalB $ AppE (AppE (VarE 'BS.splitAt)
                                        (LitE (IntegerL n)))
                                  (VarE inputname))
                  []

  return (LetS [dec1] : stmts, restname, valname : tuplenames)

readElement state@(_, _, tuplenames) (Ignore n) = do
  -- Decode the inner element as usual, then put the old list of result
  -- names back so the decoded value is left out of the tuple.
  (a, b, _) <- readElement state n
  return (a, b, tuplenames)

readElement (_, _, []) LengthPrefixed =
  -- Nothing has been decoded yet, so there is no length to use. Report it
  -- properly instead of crashing on the head of an empty list.
  fail "LengthPrefixed must follow an element that supplies the length"

readElement (stmts, inputname, tuplenames@(sourcename : _)) LengthPrefixed = do
  -- Expands to something like:
  --   (val, rest) = Data.ByteString.splitAt (fromIntegral prev) input
  valname <- newName "val"
  restname <- newName "rest"

  let dec = ValD (TupP [VarP valname, VarP restname])
                 (NormalB $ AppE (AppE (VarE 'BS.splitAt)
                                       (AppE (VarE 'fromIntegral)
                                             (VarE sourcename)))
                                 (VarE inputname))
                 []

  return (LetS [dec] : stmts, restname, valname : tuplenames)

readElement (stmts, inputname, tuplenames) (Variable funcname) = do
  -- Expands to something like:
  --   (val, rest) <- func input
  valname <- newName "val"
  restname <- newName "rest"

  let stmt = BindS (TupP [VarP valname, VarP restname])
                   (AppE (VarE funcname) (VarE inputname))

  return (stmt : stmts, restname, valname : tuplenames)

readElement (stmts, inputname, tuplenames) Rest = do
  -- Expands to something like:
  --   rest = input
  -- The input name is passed on unchanged, so nothing is consumed.
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ VarE inputname)
                 []
  return (LetS [dec] : stmts, inputname, restname : tuplenames)

readElement (stmts, inputname, tuplenames) (Skip n) = do
  -- Expands to something like:
  --   rest = Data.ByteString.drop n input
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ AppE (AppE (VarE 'BS.drop)
                                       (LitE (IntegerL n)))
                                 (VarE inputname))
                 []
  return (LetS [dec] : stmts, restname, tuplenames)

readElement state (Unsigned size) = do
  -- Expands to something like:
  --   (aval, arest) = Data.ByteString.splitAt 1 input
  --   a = BitSyntax.decodeU8 aval
  let decodefunc = case size of
                     1 -> 'decodeU8
                     2 -> 'decodeU16
                     _ -> 'decodeU32 -- Default to 32
  decodeHelper state (VarE decodefunc) size

readElement state (UnsignedLE size) = do
  -- Expands to something like:
  --   (aval, arest) = Data.ByteString.splitAt 2 input
  --   a = BitSyntax.decodeU16LE aval
  let decodefunc = case size of
                     2 -> 'decodeU16LE
                     _ -> 'decodeU32LE -- Default to 4
  decodeHelper state (VarE decodefunc) size

readElement state (PackedBits sizes) =
  if sum sizes `mod` 8 /= 0
    then error "Sizes of packed bits must == 0 mod 8"
    else decodeHelper state
                      (AppE (VarE 'decodeBits)
                            (ListE $ map (LitE . IntegerL) sizes))
                      ((sum sizes) `shiftR` 3)
367 | |||
-- Shared code generation for fixed-width elements that are split off the
-- input and then run through a decoding expression. Expands to something
-- like:
--   (val, rest) = Data.ByteString.splitAt size input
--   tup = decodefunc val
-- and adds the decoded name to the result tuple.
decodeHelper :: ([Stmt], Name, [Name]) -> Exp
             -> Integer
             -> Q ([Stmt], Name, [Name])
decodeHelper (stmts, inputname, tuplenames) decodefunc size = do
  rawName <- newName "val"
  restName <- newName "rest"
  decodedName <- newName "tup"
  let splitCall = AppE (AppE (VarE 'BS.splitAt) (LitE (IntegerL size)))
                       (VarE inputname)
      splitDec = ValD (TupP [VarP rawName, VarP restName])
                      (NormalB splitCall)
                      []
      decodeDec = ValD (VarP decodedName)
                       (NormalB (AppE decodefunc (VarE rawName)))
                       []
  return (LetS [splitDec, decodeDec] : stmts, restName, decodedName : tuplenames)
385 | |||
-- Returns the name bound by a declaration of the form @name = ...@. Any
-- other kind of declaration is a programming error and is reported with a
-- message naming this function, rather than a bare 'undefined'.
decGetName :: Dec -> Name
decGetName (ValD (VarP name) _ _) = name
decGetName _ =
  error "Data.BitSyntax.decGetName: expected a simple variable binding"
389 | |||
-- | Template function producing a lambda from the input ByteString to a
-- monadic action. The action runs the statements generated for each
-- element in order and returns the tuple of decoded values.
bitSyn :: [ReadType] -> Q Exp
bitSyn elements = do
  inputname <- newName "input"
  -- the name of the final remaining input is not needed here
  (stmts, _, tuplenames) <- foldM readElement ([], inputname, []) elements
  resultExp <- [| return $(tupE (map varE (reverse tuplenames))) |]
  -- statements were accumulated most recent first
  let body = reverse (NoBindS resultExp : stmts)
  return (LamE [VarP inputname] (DoE body))
396 | |||
397 | |||
398 | -- Tests | ||
-- Round-trip property: packing a list of (size, value) fields with PackBits
-- and decoding the result with decodeBits gives back the values. Inputs
-- with a size below 1 or a negative value are accepted unconditionally. A
-- padding field is put in front when the sizes are not a multiple of 8.
prop_bitPacking :: [(Int, Int)] -> Bool
prop_bitPacking fields =
  roundTrips || any ((< 1) . fst) fields || any ((< 0) . snd) fields
  where
    roundTrips = map snd padded == map fromIntegral decoded
    slack = sum (map fst fields) `mod` 8
    padded
      | slack > 0 = (8 - slack, 1) : fields
      | otherwise = fields
    decoded = decodeBits (map (fromIntegral . fst) padded) (bits (PackBits padded))
412 | |||
413 | {- | ||
414 | instance Arbitrary Word16 where | ||
415 | arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral | ||
416 | instance Arbitrary Word32 where | ||
417 | arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral | ||
418 | -} | ||
419 | |||
-- | The pure Haskell 32-bit byte swap agrees with the foreign @htonl@.
-- This only holds on little-endian machines, where @htonl@ really swaps.
prop_nativeByteShuffle32 :: Word32 -> Bool
prop_nativeByteShuffle32 w = htonl w == endianSwitch32 w

-- | The pure Haskell 16-bit byte swap agrees with the foreign @htons@.
-- This only holds on little-endian machines, where @htons@ really swaps.
prop_nativeByteShuffle16 :: Word16 -> Bool
prop_nativeByteShuffle16 w = htons w == endianSwitch16 w

-- | 'littleEndian16' leaves its argument unchanged.
prop_littleEndian16 :: Word16 -> Bool
prop_littleEndian16 w = w == littleEndian16 w

-- | 'littleEndian32' leaves its argument unchanged.
prop_littleEndian32 :: Word32 -> Bool
prop_littleEndian32 w = w == littleEndian32 w