summaryrefslogtreecommitdiff
path: root/Data/BitSyntax.hs
blob: 6d14d0c150a54732e0135b6b7a4ad15a59addeef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ForeignFunctionInterface #-}
-- | This module contains functions and templates for building up and breaking
--   down packed bit structures. It's something like Erlang's bit-syntax (or,
--   actually, more like Python's struct module).
--
--   This code uses Data.ByteString which is included in GHC 6.5 and you can
--   get it for 6.4 at <http://www.cse.unsw.edu.au/~dons/fps.html>
module Data.BitSyntax (
  -- * Building bit structures
  -- | The core function here is makeBits, which is a perfectly normal function.
  --   Here's an example which makes a SOCKS4a request header:
  -- @
  --   makeBits [U8 4, U8 1, U16 80, U32 10, NullTerminated \"username\",
  --             NullTerminated \"www.haskell.org\"]
  -- @
  BitBlock(..),
  makeBits,
  -- * Breaking up bit structures
  -- | The main function for this is bitSyn, which is a template function and
  --   so you'll need to run with @-fth@ to enable template haskell
  --   <http://www.haskell.org/th/>.
  --
  --   To expand the function you use the splice command:
  -- @
  --   $(bitSyn [...])
  -- @
  --
  -- The expanded function has type @ByteString -> (...)@ where the elements of
  -- the tuple depend on the argument to bitSyn (that's why it has to be a template
  -- function).
  --
  -- Here's an example, translated from the Erlang manual, which parses an IP header:
  --
  -- @
  -- decodeOptions bs ([_, hlen], _, _, _, _, _, _, _, _, _)
  --   | hlen > 5  = return $ BS.splitAt (fromIntegral ((hlen - 5) * 4)) bs
  --   | otherwise = return (BS.empty, bs)
  -- @
  --
  -- @
  -- ipDecode = $(bitSyn [PackedBits [4, 4], Unsigned 1, Unsigned 2, Unsigned 2,
  --                      PackedBits [3, 13], Unsigned 1, Unsigned 1, Unsigned 2,
  --                      Fixed 4, Fixed 4, Context \'decodeOptions, Rest])
  -- @
  --
  -- @
  -- ipPacket = BS.pack [0x45, 0, 0, 0x34, 0xd8, 0xd2, 0x40, 0, 0x40, 0x06,
  --                     0xa0, 0xca, 0xac, 0x12, 0x68, 0x4d, 0xac, 0x18,
  --                     0x00, 0xaf]
  -- @
  --
  -- This function has several weaknesses compared to the Erlang version: The
  -- elements of the bit structure are not named in place, instead you have to
  -- do a pattern match on the resulting tuple and match up the indexes. The
  -- type system helps in this, but it's still not quite as nice.

  ReadType(..), bitSyn,

  -- I get errors if these aren't exported (Can't find interface-file
  -- declaration for Data.BitSyntax.decodeU16)
  decodeU8, decodeU16, decodeU32, decodeU16LE, decodeU32LE) where

import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax

import qualified Data.ByteString as BS
import Data.Char (ord)
import Control.Monad
-- import Test.QuickCheck (Arbitrary(), arbitrary, Gen())

import Foreign

-- Bindings to the C functions htonl and htons (32-bit and 16-bit variants
-- respectively), imported as pure functions through an unsafe ccall.
foreign import ccall unsafe "htonl" htonl :: Word32 -> Word32
foreign import ccall unsafe "htons" htons :: Word16 -> Word16

-- There's no good way to convert to little-endian. The htons functions only
-- convert to big-endian and they don't have any little-endian friends. So we
-- need to detect which kind of system we are on and act accordingly. We can
-- detect the type of system by seeing if htonl actually does anything (it's
-- the identity function on big-endian systems, of course). If it doesn't,
-- we're on a big-endian system and so need to do the byte-swapping in Haskell
-- because the C functions are no-ops.

-- | Pure Haskell byte swap for 32-bit words: the four bytes of the argument
--   are returned in reverse order. Used where the C htonl cannot be relied
--   on to do the swap.
endianSwitch32 :: Word32 -> Word32
endianSwitch32 w = fromLowest .|. fromSecond .|. fromThird .|. fromHighest
  where
    -- Each piece is one source byte moved to its mirrored position.
    fromLowest  = (w .&. 0x000000ff) `shiftL` 24
    fromSecond  = (w .&. 0x0000ff00) `shiftL` 8
    fromThird   = (w .&. 0x00ff0000) `shiftR` 8
    fromHighest = w `shiftR` 24

-- | Pure Haskell byte swap for 16-bit words: the two bytes of the argument
--   change places. Used where the C htons cannot be relied on to do the swap.
endianSwitch16 :: Word16 -> Word16
endianSwitch16 w = lowToHigh .|. highToLow
  where
    -- Shifting a Word16 left by 8 discards the old high byte, so no mask is
    -- needed to isolate the low byte first.
    lowToHigh = w `shiftL` 8
    highToLow = w `shiftR` 8

-- | Byte-swaps the argument with 'endianSwitch32' when @htonl 1 == 1@ (that
--   is, when htonl leaves its argument untouched); otherwise returns the
--   argument unchanged.
littleEndian32 :: Word32 -> Word32
littleEndian32 a
  | htonl 1 == 1 = endianSwitch32 a
  | otherwise    = a

-- | Byte-swaps the argument with 'endianSwitch16' when @htonl 1 == 1@ (that
--   is, when htonl leaves its argument untouched); otherwise returns the
--   argument unchanged.
littleEndian16 :: Word16 -> Word16
littleEndian16 a
  | htonl 1 == 1 = endianSwitch16 a
  | otherwise    = a

-- | One element of a packed structure to be built by 'makeBits'.
data BitBlock
  = -- | Unsigned 8-bit int
    U8 Int
  | -- | Unsigned 16-bit int
    U16 Int
  | -- | Unsigned 32-bit int
    U32 Int
  | -- | Little-endian, unsigned 16-bit int
    U16LE Int
  | -- | Little-endian, unsigned 32-bit int
    U32LE Int
  | -- | Appends the string followed by a trailing NUL byte
    NullTerminated String
  | -- | Appends the string with no terminator
    RawString String
  | -- | Appends a ByteString as it is
    RawByteString BS.ByteString
  | -- | Packs a series of bit fields together. The argument is a list of
    --   pairs: the first element is the field size in bits, the second is
    --   the field value. The sizes within one PackBits must add up to a
    --   multiple of 8.
    PackBits [(Int, Int)]
  deriving (Show)

-- Serialises a fixed-width integral value into a ByteString. Bytes are taken
-- from the value by masking and shifting, so the least significant byte comes
-- first in the output. Calls error if the type's bit width is not a multiple
-- of 8.
getBytes :: (Integral a, Bounded a, FiniteBits a) => a -> BS.ByteString
getBytes input
  | width `mod` 8 /= 0 = error "Input data bit size must be a multiple of 8"
  | otherwise =
      BS.pack [lowByte (input `shiftR` (8 * i)) | i <- [0 .. byteCount - 1]]
  where
    width = finiteBitSize input
    byteCount = width `div` 8
    -- Lowest eight bits of a value, as a Word8.
    lowByte x = fromIntegral (x .&. 0xff)

-- Performs the work behind PackBits: appends one bit field to the packing
-- state. Fields are written from the most significant bit of each byte
-- downwards, and a field that does not fit in the current byte continues in
-- the following byte(s).
--
-- Only the low @size@ bits of the value are written. Without this masking a
-- value wider than its declared size (or a negative value) would spill into
-- bits of the current byte that belong to previously packed fields.
packBits :: (Word8, Int, [Word8])  -- ^ The current byte, the number of bits
                                   --   used in that byte and the (reverse)
                                   --   list of produced bytes
         -> (Int, Int)  -- ^ The size (in bits) of the value, and the value
         -> (Word8, Int, [Word8])  -- See first argument
packBits (current, used, bytes) (size, value)
  -- The field continues past this byte: flush it and carry on with the
  -- remaining low bits of the value.
  | bitsWritten < size = packBits (0, 0, current' : bytes) (size - bitsWritten, value)
  -- The field ends exactly on a byte boundary: flush the completed byte.
  | used' == 8 = (0, 0, current' : bytes)
  -- The field ends inside the byte: keep accumulating into it.
  | otherwise = (current', used', bytes)
  where
    -- Keep only the bits that belong to this field. Sizes outside
    -- [0, bit width of Int) are left unmasked, as no mask can be built.
    masked
      | size >= 0 && size < finiteBitSize value = value .&. (bit size - 1)
      | otherwise = value
    top = size - 1
    topOfByte = 7 - used
    -- Line the field's top bit up with the first free bit of the byte; a
    -- negative shift amount moves the value right.
    aligned = masked `shift` (topOfByte - top)
    newBits = (fromIntegral aligned) :: Word8
    current' = current .|. newBits
    bitsWritten = min (8 - used) size
    used' = used + bitsWritten

-- | Encode a single 'BitBlock' as a ByteString.
--
--   'getBytes' extracts bytes by masking and shifting, so it always emits the
--   least significant byte first, whatever the byte order of the host. The
--   big-endian fields (U16, U32) therefore need an unconditional byte swap,
--   and the little-endian fields (U16LE, U32LE) need none. Relying on
--   htons/htonl here would only be correct on little-endian hosts, because
--   those functions do nothing on big-endian ones.
--
--   Integer fields are truncated to their width by 'fromIntegral'; characters
--   of the string fields are truncated to their low 8 bits in the same way.
--   A PackBits whose sizes do not sum to a multiple of 8 calls 'error'.
bits :: BitBlock -> BS.ByteString
bits (U8 v) = BS.pack [((fromIntegral v) :: Word8)]
bits (U16 v) = getBytes (endianSwitch16 (fromIntegral v))
bits (U32 v) = getBytes (endianSwitch32 (fromIntegral v))
bits (U16LE v) = getBytes ((fromIntegral v) :: Word16)
bits (U32LE v) = getBytes ((fromIntegral v) :: Word32)
bits (NullTerminated str) = BS.pack $ (map (fromIntegral . ord) str) ++ [0]
bits (RawString str) = BS.pack $ map (fromIntegral . ord) str
bits (RawByteString bs) = bs
bits (PackBits bitspec) =
  if (sum $ map fst bitspec) `mod` 8 /= 0
    then error "Sum of sizes of a bit spec must == 0 mod 8"
    -- packBits accumulates completed bytes in reverse order.
    else (\(_, _, a) -> BS.pack $ reverse a) $ foldl packBits (0, 0, []) bitspec

-- | Build a binary string by encoding every element of the list in order and
--   concatenating the results.
makeBits :: [BitBlock] -> BS.ByteString
makeBits blocks = BS.concat [bits block | block <- blocks]

-- | One element of a structure description given to 'bitSyn'.
data ReadType
  = -- | An unsigned number of some number of bytes. Valid arguments are
    --   1, 2 and 4.
    Unsigned Integer
  | -- | An unsigned, little-endian integer of some number of bytes. Valid
    --   arguments are 2 and 4.
    UnsignedLE Integer
  | -- | A variable length element decoded by a custom function. The
    --   function's name is the single argument and it should have type
    --   @Monad m => ByteString -> m (v, ByteString)@
    Variable Name
  | -- | Skip some number of bytes
    Skip Integer
  | -- | A fixed size field; the result is a ByteString of that length.
    Fixed Integer
  | -- | Decode a value and ignore it (the result will not be part of the
    --   returned tuple)
    Ignore ReadType
  | -- | Like 'Variable', but the decoding function is also passed the
    --   entire result tuple so far. The named function therefore has type
    --   @Monad m => ByteString -> (...) -> m (v, ByteString)@
    Context Name
  | -- | Takes the most recent element of the result tuple and interprets
    --   it as the length of this field. Results in a ByteString.
    LengthPrefixed
  | -- | Decode a series of bit fields; results in a list of Integers. Each
    --   element of the argument is the length of one bit field. The sum of
    --   the lengths must be a multiple of 8.
    PackedBits [Integer]
  | -- | Results in a ByteString containing the bytes not yet decoded.
    --   Generally used at the end to return the trailing body of a
    --   structure, but it can be used at any point in the decoding to
    --   return the trailing part at that point.
    Rest

-- Combines a list of byte-sized values into one number. The first element of
-- the list ends up in the least significant position, each later element
-- eight bits higher. An empty list gives 0.
fromBytes :: (Num a, Bits a) => [a] -> a
fromBytes = foldr prepend 0
  where
    -- Everything gathered from the tail moves up one byte to make room.
    prepend byte higher = (higher `shiftL` 8) .|. byte

-- | Decode the first byte of the input as an unsigned 8-bit value. The input
--   must not be empty.
decodeU8 :: BS.ByteString -> Word8
decodeU8 input = fromIntegral (head (BS.unpack input))
-- | Decode a big-endian unsigned 16-bit value.
--
--   'fromBytes' puts the first byte in the least significant position
--   whatever the host byte order, so the swap to big-endian interpretation
--   must always happen. htons would skip it on big-endian hosts.
decodeU16 :: BS.ByteString -> Word16
decodeU16 = endianSwitch16 . fromBytes . map fromIntegral . BS.unpack
-- | Decode a big-endian unsigned 32-bit value.
--
--   'fromBytes' puts the first byte in the least significant position
--   whatever the host byte order, so the swap to big-endian interpretation
--   must always happen. htonl would skip it on big-endian hosts.
decodeU32 :: BS.ByteString -> Word32
decodeU32 = endianSwitch32 . fromBytes . map fromIntegral . BS.unpack
-- | Decode a little-endian unsigned 16-bit value.
--
--   'fromBytes' already treats the first byte as the least significant one,
--   independent of the host byte order, so no further conversion is applied.
--   A host-dependent swap here would corrupt the value on big-endian hosts.
decodeU16LE :: BS.ByteString -> Word16
decodeU16LE = fromBytes . map fromIntegral . BS.unpack
-- | Decode a little-endian unsigned 32-bit value.
--
--   'fromBytes' already treats the first byte as the least significant one,
--   independent of the host byte order, so no further conversion is applied.
--   A host-dependent swap here would corrupt the value on big-endian hosts.
decodeU32LE :: BS.ByteString -> Word32
decodeU32LE = fromBytes . map fromIntegral . BS.unpack

-- Decodes a run of bit fields from the front of a ByteString. The first
-- argument lists the field widths in bits; the result holds one value per
-- decoded field, in the same order. Only the first (sum of widths / 8) bytes
-- of the input are looked at.
decodeBits :: [Integer] -> BS.ByteString -> [Integer]
decodeBits sizes bs = reverse values
  where
    byteCount = fromIntegral (sum sizes `shiftR` 3)
    payload = BS.unpack (BS.take byteCount bs)
    -- unpackBits conses each decoded value onto the front, hence the reverse.
    (values, _, _) = foldl unpackBits ([], 0, payload) sizes

-- Extracts one bit field of the given width from the decoding state, starting
-- from an accumulator of zero. See unpackBitsInner for the state layout.
unpackBits :: ([Integer], Integer, [Word8]) -> Integer -> ([Integer], Integer, [Word8])
unpackBits = unpackBitsInner 0

-- Worker for unpackBits. The first argument is the part of the field value
-- gathered so far. The state triple holds the values decoded so far (most
-- recent first), the number of bits already consumed from the head byte, and
-- the remaining input bytes. The head byte is kept shifted left so that its
-- next unread bit is always the top bit. If the input runs out, the state is
-- returned without adding a value.
unpackBitsInner :: Integer ->
                   ([Integer], Integer, [Word8]) ->
                   Integer ->
                   ([Integer], Integer, [Word8])
unpackBitsInner _ (output, used, []) _ = (output, used, [])
unpackBitsInner val (output, used, current : input) bitsToGet
  -- The field continues into the next byte.
  | stillNeeded > 0 = unpackBitsInner gathered (output, 0, input) stillNeeded
  -- The field ends inside this byte; keep the unread remainder at the head.
  | consumed < 8 = (gathered : output, consumed, leftover : input)
  -- The field ends exactly at the byte boundary.
  | otherwise = (gathered : output, 0, input)
  where
    taken = min (8 - used) bitsToGet
    topBits = current `shiftR` (fromIntegral (8 - taken))
    leftover = current `shiftL` (fromIntegral taken)
    gathered = (val `shift` (fromIntegral taken)) .|. (fromIntegral topBits)
    stillNeeded = bitsToGet - taken
    consumed = used + taken

-- Generates the do-block statements that decode a single ReadType element.
--
-- The state triple is:
--   * the statements generated so far, most recent first;
--   * the name of the variable holding the input not yet consumed;
--   * the names of the variables that make up the result tuple, most recent
--     first.
-- The returned triple is the state after the element has been handled.
readElement :: ([Stmt], Name, [Name]) -> ReadType -> Q ([Stmt], Name, [Name])

-- Expands to something like:
--   (val, rest) <- funcname input (tuple, so, far)
readElement (stmts, inputname, tuplenames) (Context funcname) = do
  valname <- newName "val"
  restname <- newName "rest"

  let stmt = BindS (TupP [VarP valname, VarP restname])
                   (AppE (AppE (VarE funcname)
                               (VarE inputname))
                         (TupE $ map VarE $ reverse tuplenames))

  return (stmt : stmts, restname, valname : tuplenames)

-- Expands to something like:
--   let (val, rest) = Data.ByteString.splitAt n input
readElement (stmts, inputname, tuplenames) (Fixed n) = do
  valname <- newName "val"
  restname <- newName "rest"
  let dec1 = ValD (TupP [VarP valname, VarP restname])
                  (NormalB $ AppE (AppE (VarE 'BS.splitAt)
                                        (LitE (IntegerL n)))
                                  (VarE inputname))
                  []

  return (LetS [dec1] : stmts, restname, valname : tuplenames)

-- Decode the wrapped element as usual, but restore the previous tuple names
-- so that its value is left out of the result.
readElement state@(_, _, tuplenames) (Ignore n) = do
  (a, b, _) <- readElement state n
  return (a, b, tuplenames)

-- LengthPrefixed takes its length from the most recent tuple element, so
-- there must be one. Report its absence as a proper splice failure instead of
-- an opaque "empty list" exception from head.
readElement (_, _, []) LengthPrefixed =
  fail "Data.BitSyntax.bitSyn: LengthPrefixed must follow an element that provides the length"

-- Expands to something like:
--   let (val, rest) = Data.ByteString.splitAt (fromIntegral previous) input
readElement (stmts, inputname, tuplenames@(sourcename : _)) LengthPrefixed = do
  valname <- newName "val"
  restname <- newName "rest"

  let dec = ValD (TupP [VarP valname, VarP restname])
                 (NormalB $ AppE (AppE (VarE 'BS.splitAt)
                                       (AppE (VarE 'fromIntegral)
                                             (VarE sourcename)))
                                 (VarE inputname))
                 []

  return (LetS [dec] : stmts, restname, valname : tuplenames)

-- Expands to something like:
--   (val, rest) <- funcname input
readElement (stmts, inputname, tuplenames) (Variable funcname) = do
  valname <- newName "val"
  restname <- newName "rest"

  let stmt = BindS (TupP [VarP valname, VarP restname])
                   (AppE (VarE funcname) (VarE inputname))

  return (stmt : stmts, restname, valname : tuplenames)

-- Expands to something like:
--   let rest = input
-- The input name is passed on unchanged, so nothing is consumed.
readElement (stmts, inputname, tuplenames) Rest = do
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ VarE inputname)
                 []
  return (LetS [dec] : stmts, inputname, restname : tuplenames)

readElement (stmts, inputname, tuplenames) (Skip n) = do
  -- Expands to something like:
  --   rest = Data.ByteString.drop n input
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ AppE (AppE (VarE 'BS.drop)
                                       (LitE (IntegerL n)))
                                 (VarE inputname))
                 []
  return (LetS [dec] : stmts, restname, tuplenames)

readElement state (Unsigned size) = do
  -- Expands to something like:
  --    (aval, arest) = Data.ByteString.splitAt 1 input
  --    a = BitSyntax.decodeU8 aval
  -- Any size other than 1 or 2 is decoded with decodeU32.
  let decodefunc = case size of
                     1 -> 'decodeU8
                     2 -> 'decodeU16
                     _ -> 'decodeU32 -- Default to 32
  decodeHelper state (VarE decodefunc) size

readElement state (UnsignedLE size) = do
  -- Expands to something like:
  --    (aval, arest) = Data.ByteString.splitAt 2 input
  --    a = BitSyntax.decodeU16LE aval
  -- Any size other than 2 is decoded with decodeU32LE.
  let decodefunc = case size of
                     2 -> 'decodeU16LE
                     _ -> 'decodeU32LE -- Default to 4
  decodeHelper state (VarE decodefunc) size

-- Expands to something like:
--    (aval, arest) = Data.ByteString.splitAt bytecount input
--    a = BitSyntax.decodeBits [sizes] aval
readElement state (PackedBits sizes) =
  if sum sizes `mod` 8 /= 0
    then error "Sizes of packed bits must == 0 mod 8"
    else decodeHelper state
                      (AppE (VarE 'decodeBits)
                            (ListE $ map (LitE . IntegerL) sizes))
                      ((sum sizes) `shiftR` 3)

-- Shared code generation for the fixed-width decoders. Emits one let
-- statement that splits the given number of bytes off the input and applies
-- the decoding expression to them, roughly:
--    (val, rest) = Data.ByteString.splitAt size input
--    tup = decodefunc val
-- The decoded variable is added to the result tuple names and the remainder
-- becomes the new input.
decodeHelper :: ([Stmt], Name, [Name])      -> Exp
                                            -> Integer
                                            -> Q ([Stmt], Name, [Name])
decodeHelper (stmts, inputname, tuplenames) decodefunc size = do
  chunk <- newName "val"
  remaining <- newName "rest"
  decoded <- newName "tup"
  let splitCall = AppE (AppE (VarE 'BS.splitAt) (LitE (IntegerL size)))
                       (VarE inputname)
      splitDec = ValD (TupP [VarP chunk, VarP remaining]) (NormalB splitCall) []
      decodeDec = ValD (VarP decoded)
                       (NormalB (AppE decodefunc (VarE chunk)))
                       []
  return (LetS [splitDec, decodeDec] : stmts, remaining, decoded : tuplenames)

-- Returns the name bound by a simple variable declaration (@name = ...@).
-- Any other kind of declaration is a programming error and is reported with
-- a descriptive message rather than a bare undefined.
decGetName :: Dec -> Name
decGetName (ValD (VarP name) _ _) = name
decGetName _ =
  error "Data.BitSyntax.decGetName: expected a declaration binding a single variable"

-- | Template function that builds a decoder for the described structure.
--   The generated expression is a lambda over the input whose body is a do
--   block: the statements for every element, in order, followed by a
--   @return@ of the tuple of decoded values.
bitSyn :: [ReadType] -> Q Exp
bitSyn elements = do
    inputname <- newName "input"
    (stmts, _, tuplenames) <- foldM readElement ([], inputname, []) elements
    -- Both accumulators are built most-recent-first, so both are reversed.
    let resultTuple = tupE (map varE (reverse tuplenames))
    returnS <- fmap NoBindS [| return $(resultTuple) |]
    return (LamE [VarP inputname] (DoE (reverse (returnS : stmts))))


-- Tests
-- Round-trip property: packing a list of (size, value) bit fields and then
-- decoding them gives the values back. Inputs containing a size below 1 or a
-- negative value are accepted without requiring the round trip. When the
-- sizes do not add up to a whole number of bytes, a padding field holding 1
-- is put in front.
prop_bitPacking :: [(Int, Int)] -> Bool
prop_bitPacking fields =
  roundTrips ||
  any ((< 1) . fst) fields ||
  any ((< 0) . snd) fields
  where
    roundTrips = map snd padded == map fromIntegral decoded
    padded = case sum (map fst fields) `mod` 8 of
               0 -> fields
               leftover -> (8 - leftover, 1) : fields
    encoded = bits (PackBits padded)
    decoded = decodeBits (map (fromIntegral . fst) padded) encoded

{-
instance Arbitrary Word16 where
  arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral
instance Arbitrary Word32 where
  arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral
-}

-- | Checks that the foreign htonl agrees with the native 'endianSwitch32'.
--   This only holds on little-endian machines, where htonl really swaps the
--   bytes.
prop_nativeByteShuffle32 :: Word32 -> Bool
prop_nativeByteShuffle32 x = htonl x == endianSwitch32 x
-- Checks that the foreign htons agrees with the native endianSwitch16.
prop_nativeByteShuffle16 :: Word16 -> Bool
prop_nativeByteShuffle16 x = htons x == endianSwitch16 x
-- Checks that littleEndian16 returns its argument unchanged.
prop_littleEndian16 :: Word16 -> Bool
prop_littleEndian16 x = x == littleEndian16 x
-- Checks that littleEndian32 returns its argument unchanged.
prop_littleEndian32 :: Word32 -> Bool
prop_littleEndian32 x = x == littleEndian32 x