diff options
Diffstat (limited to 'Data/BitSyntax.hs')
-rw-r--r-- | Data/BitSyntax.hs | 428 |
1 files changed, 428 insertions, 0 deletions
diff --git a/Data/BitSyntax.hs b/Data/BitSyntax.hs new file mode 100644 index 00000000..aebdd522 --- /dev/null +++ b/Data/BitSyntax.hs | |||
@@ -0,0 +1,428 @@ | |||
1 | {-# OPTIONS_GHC -fth #-} | ||
2 | -- | This module contains fuctions and templates for building up and breaking | ||
3 | -- down packed bit structures. It's something like Erlang's bit-syntax (or, | ||
4 | -- actually, more like Python's struct module). | ||
5 | -- | ||
6 | -- This code uses Data.ByteString which is included in GHC 6.5 and you can | ||
7 | -- get it for 6.4 at <http://www.cse.unsw.edu.au/~dons/fps.html> | ||
8 | module Data.BitSyntax ( | ||
9 | -- * Building bit structures | ||
10 | -- | The core function here is makeBits, which is a perfectly normal function. | ||
11 | -- Here's an example which makes a SOCKS4a request header: | ||
12 | -- @ | ||
13 | -- makeBits [U8 4, U8 1, U16 80, U32 10, NullTerminated \"username\", | ||
14 | -- NullTerminated \"www.haskell.org\"] | ||
15 | -- @ | ||
16 | BitBlock(..), | ||
17 | makeBits, | ||
18 | -- * Breaking up bit structures | ||
19 | -- | The main function for this is bitSyn, which is a template function and | ||
20 | -- so you'll need to run with @-fth@ to enable template haskell | ||
21 | -- <http://www.haskell.org/th/>. | ||
22 | -- | ||
23 | -- To expand the function you use the splice command: | ||
24 | -- @ | ||
25 | -- $(bitSyn [...]) | ||
26 | -- @ | ||
27 | -- | ||
28 | -- The expanded function has type @ByteString -> (...)@ where the elements of | ||
29 | -- the tuple depend of the argument to bitSyn (that's why it has to be a template | ||
30 | -- function). | ||
31 | -- | ||
32 | -- Heres an example, translated from the Erlang manual, which parses an IP header: | ||
33 | -- | ||
34 | -- @ | ||
35 | -- decodeOptions bs ([_, hlen], _, _, _, _, _, _, _, _, _) | ||
36 | -- | hlen > 5 = return $ BS.splitAt (fromIntegral ((hlen - 5) * 4)) bs | ||
37 | -- | otherwise = return (BS.empty, bs) | ||
38 | -- @ | ||
39 | -- | ||
40 | -- @ | ||
41 | -- ipDecode = $(bitSyn [PackedBits [4, 4], Unsigned 1, Unsigned 2, Unsigned 2, | ||
42 | -- PackedBits [3, 13], Unsigned 1, Unsigned 1, Unsigned 2, | ||
43 | -- Fixed 4, Fixed 4, Context \'decodeOptions, Rest]) | ||
44 | -- @ | ||
45 | -- | ||
46 | -- @ | ||
47 | -- ipPacket = BS.pack [0x45, 0, 0, 0x34, 0xd8, 0xd2, 0x40, 0, 0x40, 0x06, | ||
48 | -- 0xa0, 0xca, 0xac, 0x12, 0x68, 0x4d, 0xac, 0x18, | ||
49 | -- 0x00, 0xaf] | ||
50 | -- @ | ||
51 | -- | ||
52 | -- This function has several weaknesses compared to the Erlang version: The | ||
53 | -- elements of the bit structure are not named in place, instead you have to | ||
54 | -- do a pattern match on the resulting tuple and match up the indexes. The | ||
55 | -- type system helps in this, but it's still not quite as nice. | ||
56 | |||
57 | ReadType(..), bitSyn, | ||
58 | |||
59 | -- I get errors if these aren't exported (Can't find interface-file | ||
60 | -- declaration for Data.BitSyntax.decodeU16) | ||
61 | decodeU8, decodeU16, decodeU32, decodeU16LE, decodeU32LE) where | ||
62 | |||
63 | import Language.Haskell.TH.Lib | ||
64 | import Language.Haskell.TH.Syntax | ||
65 | |||
66 | import qualified Data.ByteString as BS | ||
67 | import Data.Char (ord) | ||
68 | import Control.Monad | ||
69 | import Test.QuickCheck (Arbitrary(), arbitrary, Gen()) | ||
70 | |||
71 | import Foreign | ||
72 | |||
73 | foreign import ccall unsafe "htonl" htonl :: Word32 -> Word32 | ||
74 | foreign import ccall unsafe "htons" htons :: Word16 -> Word16 | ||
75 | |||
76 | -- There's no good way to convert to little-endian. The htons functions only | ||
77 | -- convert to big endian and they don't have any little endian friends. So we | ||
78 | -- need to detect which kind of system we are on and act accordingly. We can | ||
79 | -- detect the type of system by seeing if htonl actaully doesn't anything (it's | ||
80 | -- the identity function on big-endian systems, of course). If it doesn't we're | ||
81 | -- on a big-endian system and so need to do the byte-swapping in Haskell because | ||
82 | -- the C functions are no-ops | ||
83 | |||
84 | -- | A native Haskell version of htonl for the case where we need to convert | ||
85 | -- to little-endian on a big-endian system | ||
86 | endianSwitch32 :: Word32 -> Word32 | ||
87 | endianSwitch32 a = ((a .&. 0xff) `shiftL` 24) .|. | ||
88 | ((a .&. 0xff00) `shiftL` 8) .|. | ||
89 | ((a .&. 0xff0000) `shiftR` 8) .|. | ||
90 | (a `shiftR` 24) | ||
91 | |||
92 | -- | A native Haskell version of htons for the case where we need to convert | ||
93 | -- to little-endian on a big-endian system | ||
94 | endianSwitch16 :: Word16 -> Word16 | ||
95 | endianSwitch16 a = ((a .&. 0xff) `shiftL` 8) .|. | ||
96 | (a `shiftR` 8) | ||
97 | |||
98 | littleEndian32 :: Word32 -> Word32 | ||
99 | littleEndian32 a = if htonl 1 == 1 | ||
100 | then endianSwitch32 a | ||
101 | else a | ||
102 | |||
103 | littleEndian16 :: Word16 -> Word16 | ||
104 | littleEndian16 a = if htonl 1 == 1 | ||
105 | then endianSwitch16 a | ||
106 | else a | ||
107 | |||
108 | data BitBlock = -- | Unsigned 8-bit int | ||
109 | U8 Int | | ||
110 | -- | Unsigned 16-bit int | ||
111 | U16 Int | | ||
112 | -- | Unsigned 32-bit int | ||
113 | U32 Int | | ||
114 | -- | Little-endian, unsigned 16-bit int | ||
115 | U16LE Int | | ||
116 | -- | Little-endian, unsigned 32-bit int | ||
117 | U32LE Int | | ||
118 | -- | Appends the string with a trailing NUL byte | ||
119 | NullTerminated String | | ||
120 | -- | Appends the string without any terminator | ||
121 | RawString String | | ||
122 | -- | Appends a ByteString | ||
123 | RawByteString BS.ByteString | | ||
124 | -- | Packs a series of bit fields together. The argument is | ||
125 | -- a list of pairs where the first element is the size | ||
126 | -- (in bits) and the second is the value. The sum of the | ||
127 | -- sizes for a given PackBits must be a multiple of 8 | ||
128 | PackBits [(Int, Int)] | ||
129 | deriving (Show) | ||
130 | |||
131 | -- Encodes a member of the Bits class as a series of bytes and returns the | ||
132 | -- ByteString of those bytes. | ||
133 | getBytes :: (Integral a, Bounded a, Bits a) => a -> BS.ByteString | ||
134 | getBytes input = | ||
135 | let getByte _ 0 = [] | ||
136 | getByte x remaining = (fromIntegral $ (x .&. 0xff)) : | ||
137 | getByte (shiftR x 8) (remaining - 1) | ||
138 | in | ||
139 | if (bitSize input `mod` 8) /= 0 | ||
140 | then error "Input data bit size must be a multiple of 8" | ||
141 | else BS.pack $ getByte input (bitSize input `div` 8) | ||
142 | |||
143 | -- Performs the work behind PackBits | ||
144 | packBits :: (Word8, Int, [Word8]) -- ^ The current byte, the number of bits | ||
145 | -- used in that byte and the (reverse) | ||
146 | -- list of produced bytes | ||
147 | -> (Int, Int) -- ^ The size (in bits) of the value, and the value | ||
148 | -> (Word8, Int, [Word8]) -- See first argument | ||
149 | packBits (current, used, bytes) (size, value) = | ||
150 | if bitsWritten < size | ||
151 | then packBits (0, 0, current' : bytes) (size - bitsWritten, value) | ||
152 | else if used' == 8 | ||
153 | then (0, 0, current' : bytes) | ||
154 | else (current', used', bytes) | ||
155 | where | ||
156 | top = size - 1 | ||
157 | topOfByte = 7 - used | ||
158 | aligned = value `shift` (topOfByte - top) | ||
159 | newBits = (fromIntegral aligned) :: Word8 | ||
160 | current' = current .|. newBits | ||
161 | bitsWritten = min (8 - used) size | ||
162 | used' = used + bitsWritten | ||
163 | |||
164 | bits :: BitBlock -> BS.ByteString | ||
165 | bits (U8 v) = BS.pack [((fromIntegral v) :: Word8)] | ||
166 | bits (U16 v) = getBytes ((htons $ fromIntegral v) :: Word16) | ||
167 | bits (U32 v) = getBytes ((htonl $ fromIntegral v) :: Word32) | ||
168 | bits (U16LE v) = getBytes (littleEndian16 $ fromIntegral v) | ||
169 | bits (U32LE v) = getBytes (littleEndian32 $ fromIntegral v) | ||
170 | bits (NullTerminated str) = BS.pack $ (map (fromIntegral . ord) str) ++ [0] | ||
171 | bits (RawString str) = BS.pack $ map (fromIntegral . ord) str | ||
172 | bits (RawByteString bs) = bs | ||
173 | bits (PackBits bitspec) = | ||
174 | if (sum $ map fst bitspec) `mod` 8 /= 0 | ||
175 | then error "Sum of sizes of a bit spec must == 0 mod 8" | ||
176 | else (\(_, _, a) -> BS.pack $ reverse a) $ foldl packBits (0, 0, []) bitspec | ||
177 | |||
178 | -- | Make a binary string from the list of elements given | ||
179 | makeBits :: [BitBlock] -> BS.ByteString | ||
180 | makeBits = BS.concat . (map bits) | ||
181 | |||
182 | data ReadType = -- | An unsigned number of some number of bytes. Valid | ||
183 | -- arguments are 1, 2 and 4 | ||
184 | Unsigned Integer | | ||
185 | -- | An unsigned, little-endian integer of some number of | ||
186 | -- bytes. Valid arguments are 2 and 4 | ||
187 | UnsignedLE Integer | | ||
188 | -- | A variable length element to be decoded by a custom | ||
189 | -- function. The function's name is given as the single | ||
190 | -- argument and should have type | ||
191 | -- @Monad m => ByteString -> m (v, ByteString)@ | ||
192 | Variable Name | | ||
193 | -- | Skip some number of bytes | ||
194 | Skip Integer | | ||
195 | -- | A fixed size field, the result of which is a ByteString | ||
196 | -- of that length. | ||
197 | Fixed Integer | | ||
198 | -- | Decode a value and ignore it (the result will not be part | ||
199 | -- of the returned tuple) | ||
200 | Ignore ReadType | | ||
201 | -- | Like variable, but the decoding function is passed the | ||
202 | -- entire result tuple so far. Thus the function whose name | ||
203 | -- passed has type | ||
204 | -- @Monad m => ByteString -> (...) -> m (v, ByteString)@ | ||
205 | Context Name | | ||
206 | -- | Takes the most recent element of the result tuple and | ||
207 | -- interprets it as the length of this field. Results in | ||
208 | -- a ByteString | ||
209 | LengthPrefixed | | ||
210 | -- | Decode a series of bit fields, results in a list of | ||
211 | -- Integers. Each element of the argument is the length of | ||
212 | -- the bit field. The sums of the lengths must be a multiple | ||
213 | -- of 8 | ||
214 | PackedBits [Integer] | | ||
215 | -- | Results in a ByteString containing the undecoded bytes so | ||
216 | -- far. Generally used at the end to return the trailing body | ||
217 | -- of a structure, it can actually be used at any point in the | ||
218 | -- decoding to return the trailing part at that point. | ||
219 | Rest | ||
220 | |||
221 | fromBytes :: (Bits a) => [a] -> a | ||
222 | fromBytes input = | ||
223 | let dofb accum [] = accum | ||
224 | dofb accum (x:xs) = dofb ((shiftL accum 8) .|. x) xs | ||
225 | in | ||
226 | dofb 0 $ reverse input | ||
227 | |||
228 | decodeU8 :: BS.ByteString -> Word8 | ||
229 | decodeU8 = fromIntegral . head . BS.unpack | ||
230 | decodeU16 :: BS.ByteString -> Word16 | ||
231 | decodeU16 = htons . fromBytes . map fromIntegral . BS.unpack | ||
232 | decodeU32 :: BS.ByteString -> Word32 | ||
233 | decodeU32 = htonl . fromBytes . map fromIntegral . BS.unpack | ||
234 | decodeU16LE :: BS.ByteString -> Word16 | ||
235 | decodeU16LE = littleEndian16 . fromBytes . map fromIntegral . BS.unpack | ||
236 | decodeU32LE :: BS.ByteString -> Word32 | ||
237 | decodeU32LE = littleEndian32 . fromBytes . map fromIntegral . BS.unpack | ||
238 | |||
239 | decodeBits :: [Integer] -> BS.ByteString -> [Integer] | ||
240 | decodeBits sizes bs = | ||
241 | reverse values | ||
242 | where | ||
243 | (values, _, _) = foldl unpackBits ([], 0, BS.unpack bitdata) sizes | ||
244 | bytesize = (sum sizes) `shiftR` 3 | ||
245 | (bitdata, _) = BS.splitAt (fromIntegral bytesize) bs | ||
246 | |||
247 | unpackBits :: ([Integer], Integer, [Word8]) -> Integer -> ([Integer], Integer, [Word8]) | ||
248 | unpackBits state size = unpackBitsInner 0 state size | ||
249 | |||
250 | unpackBitsInner :: Integer -> | ||
251 | ([Integer], Integer, [Word8]) -> | ||
252 | Integer -> | ||
253 | ([Integer], Integer, [Word8]) | ||
254 | unpackBitsInner _ (output, used, []) _ = (output, used, []) | ||
255 | unpackBitsInner val (output, used, current : input) bitsToGet = | ||
256 | if bitsToGet' > 0 | ||
257 | then unpackBitsInner val'' (output, 0, input) bitsToGet' | ||
258 | else if used' < 8 | ||
259 | then (val'' : output, used', current'' : input) | ||
260 | else (val'' : output, 0, input) | ||
261 | where | ||
262 | bitsAv = 8 - used | ||
263 | bitsTaken = min bitsAv bitsToGet | ||
264 | val' = val `shift` (fromIntegral bitsTaken) | ||
265 | current' = current `shiftR` (fromIntegral (8 - bitsTaken)) | ||
266 | current'' = current `shiftL` (fromIntegral bitsTaken) | ||
267 | val'' = val' .|. (fromIntegral current') | ||
268 | bitsToGet' = bitsToGet - bitsTaken | ||
269 | used' = used + bitsTaken | ||
270 | |||
271 | readElement :: ([Stmt], Name, [Name]) -> ReadType -> Q ([Stmt], Name, [Name]) | ||
272 | |||
273 | readElement (stmts, inputname, tuplenames) (Context funcname) = do | ||
274 | valname <- newName "val" | ||
275 | restname <- newName "rest" | ||
276 | |||
277 | let stmt = BindS (TupP [VarP valname, VarP restname]) | ||
278 | (AppE (AppE (VarE funcname) | ||
279 | (VarE inputname)) | ||
280 | (TupE $ map VarE $ reverse tuplenames)) | ||
281 | |||
282 | return (stmt : stmts, restname, valname : tuplenames) | ||
283 | |||
284 | readElement (stmts, inputname, tuplenames) (Fixed n) = do | ||
285 | valname <- newName "val" | ||
286 | restname <- newName "rest" | ||
287 | let dec1 = ValD (TupP [VarP valname, VarP restname]) | ||
288 | (NormalB $ AppE (AppE (VarE 'BS.splitAt) | ||
289 | (LitE (IntegerL n))) | ||
290 | (VarE inputname)) | ||
291 | [] | ||
292 | |||
293 | return (LetS [dec1] : stmts, restname, valname : tuplenames) | ||
294 | |||
295 | readElement state@(_, _, tuplenames) (Ignore n) = do | ||
296 | (a, b, _) <- readElement state n | ||
297 | return (a, b, tuplenames) | ||
298 | |||
299 | readElement (stmts, inputname, tuplenames) LengthPrefixed = do | ||
300 | valname <- newName "val" | ||
301 | restname <- newName "rest" | ||
302 | |||
303 | let sourcename = head tuplenames | ||
304 | dec = ValD (TupP [VarP valname, VarP restname]) | ||
305 | (NormalB $ AppE (AppE (VarE 'BS.splitAt) | ||
306 | (AppE (VarE 'fromIntegral) | ||
307 | (VarE sourcename))) | ||
308 | (VarE inputname)) | ||
309 | [] | ||
310 | |||
311 | return (LetS [dec] : stmts, restname, valname : tuplenames) | ||
312 | |||
313 | readElement (stmts, inputname, tuplenames) (Variable funcname) = do | ||
314 | valname <- newName "val" | ||
315 | restname <- newName "rest" | ||
316 | |||
317 | let stmt = BindS (TupP [VarP valname, VarP restname]) | ||
318 | (AppE (VarE funcname) (VarE inputname)) | ||
319 | |||
320 | return (stmt : stmts, restname, valname : tuplenames) | ||
321 | |||
322 | readElement (stmts, inputname, tuplenames) Rest = do | ||
323 | restname <- newName "rest" | ||
324 | let dec = ValD (VarP restname) | ||
325 | (NormalB $ VarE inputname) | ||
326 | [] | ||
327 | return (LetS [dec] : stmts, inputname, restname : tuplenames) | ||
328 | |||
329 | readElement (stmts, inputname, tuplenames) (Skip n) = do | ||
330 | -- Expands to something like: | ||
331 | -- rest = Data.ByteString.drop n input | ||
332 | restname <- newName "rest" | ||
333 | let dec = ValD (VarP restname) | ||
334 | (NormalB $ AppE (AppE (VarE 'BS.drop) | ||
335 | (LitE (IntegerL n))) | ||
336 | (VarE inputname)) | ||
337 | [] | ||
338 | return (LetS [dec] : stmts, restname, tuplenames) | ||
339 | |||
340 | readElement state (Unsigned size) = do | ||
341 | -- Expands to something like: | ||
342 | -- (aval, arest) = Data.ByteString.splitAt 1 input | ||
343 | -- a = BitSyntax.decodeU8 aval | ||
344 | let decodefunc = case size of | ||
345 | 1 -> 'decodeU8 | ||
346 | 2 -> 'decodeU16 | ||
347 | _ -> 'decodeU32 -- Default to 32 | ||
348 | decodeHelper state (VarE decodefunc) size | ||
349 | |||
350 | readElement state (UnsignedLE size) = do | ||
351 | -- Expands to something like: | ||
352 | -- (aval, arest) = Data.ByteString.splitAt 1 input | ||
353 | -- a = BitSyntax.decodeU8LE aval | ||
354 | let decodefunc = case size of | ||
355 | 2 -> 'decodeU16LE | ||
356 | _ -> 'decodeU32LE -- Default to 4 | ||
357 | decodeHelper state (VarE decodefunc) size | ||
358 | |||
359 | readElement state (PackedBits sizes) = | ||
360 | if sum sizes `mod` 8 /= 0 | ||
361 | then error "Sizes of packed bits must == 0 mod 8" | ||
362 | else decodeHelper state | ||
363 | (AppE (VarE 'decodeBits) | ||
364 | (ListE $ map (LitE . IntegerL) sizes)) | ||
365 | ((sum sizes) `shiftR` 3) | ||
366 | |||
367 | decodeHelper :: ([Stmt], Name, [Name]) -> Exp | ||
368 | -> Integer | ||
369 | -> Q ([Stmt], Name, [Name]) | ||
370 | decodeHelper (stmts, inputname, tuplenames) decodefunc size = do | ||
371 | valname <- newName "val" | ||
372 | restname <- newName "rest" | ||
373 | tuplename <- newName "tup" | ||
374 | let dec1 = ValD (TupP [VarP valname, VarP restname]) | ||
375 | (NormalB $ AppE (AppE (VarE 'BS.splitAt) | ||
376 | (LitE (IntegerL size))) | ||
377 | (VarE inputname)) | ||
378 | [] | ||
379 | let dec2 = ValD (VarP tuplename) | ||
380 | (NormalB $ AppE decodefunc (VarE valname)) | ||
381 | [] | ||
382 | |||
383 | return (LetS [dec1, dec2] : stmts, restname, tuplename : tuplenames) | ||
384 | |||
385 | decGetName :: Dec -> Name | ||
386 | decGetName (ValD (VarP name) _ _) = name | ||
387 | decGetName _ = undefined -- Error! | ||
388 | |||
389 | bitSyn :: [ReadType] -> Q Exp | ||
390 | bitSyn elements = do | ||
391 | inputname <- newName "input" | ||
392 | (stmts, restname, tuplenames) <- foldM readElement ([], inputname, []) elements | ||
393 | returnS <- NoBindS `liftM` [| return $(tupE . map varE $ reverse tuplenames) |] | ||
394 | return $ LamE [VarP inputname] (DoE . reverse $ returnS : stmts) | ||
395 | |||
396 | |||
397 | -- Tests | ||
398 | prop_bitPacking :: [(Int, Int)] -> Bool | ||
399 | prop_bitPacking fields = | ||
400 | prevalues == (map fromIntegral postvalues) || | ||
401 | any (< 1) (map fst fields) || | ||
402 | any (< 0) (map snd fields) | ||
403 | where | ||
404 | undershoot = sum (map fst fields) `mod` 8 | ||
405 | fields' = if undershoot > 0 | ||
406 | then (8 - undershoot, 1) : fields | ||
407 | else fields | ||
408 | prevalues = map snd fields' | ||
409 | packed = bits $ PackBits fields' | ||
410 | postvalues = decodeBits (map (fromIntegral . fst) fields') packed | ||
411 | |||
412 | {- | ||
413 | instance Arbitrary Word16 where | ||
414 | arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral | ||
415 | instance Arbitrary Word32 where | ||
416 | arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral | ||
417 | -} | ||
418 | |||
419 | -- | This only works on little-endian machines as it checks that the foreign | ||
420 | -- functions (htonl and htons) match the native ones | ||
421 | prop_nativeByteShuffle32 :: Word32 -> Bool | ||
422 | prop_nativeByteShuffle32 x = endianSwitch32 x == htonl x | ||
423 | prop_nativeByteShuffle16 :: Word16 -> Bool | ||
424 | prop_nativeByteShuffle16 x = endianSwitch16 x == htons x | ||
425 | prop_littleEndian16 :: Word16 -> Bool | ||
426 | prop_littleEndian16 x = littleEndian16 x == x | ||
427 | prop_littleEndian32 :: Word32 -> Bool | ||
428 | prop_littleEndian32 x = littleEndian32 x == x | ||