summaryrefslogtreecommitdiff
path: root/dht/src/Hans/Checksum.hs
blob: 7afc93c7c29e822779f2e9438871a2609f13f916 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE BangPatterns #-}

-- BANNERSTART
-- - Copyright 2006-2008, Galois, Inc.
-- - This software is distributed under a standard, three-clause BSD license.
-- - Please see the file LICENSE, distributed with this software, for specific
-- - terms and conditions.
-- Author: Adam Wick <awick@galois.com>
-- BANNEREND
-- |A module providing checksum computations to other parts of Hans. The
-- checksum here is the standard Internet 16-bit checksum (the one's 
-- complement of the one's complement sum of the data).

module Hans.Checksum(
    -- * Checksums
    computeChecksum,
    Checksum(..),
    PartialChecksum(),
    emptyPartialChecksum,
    finalizeChecksum,
    stepChecksum,

    Pair8(..),
  ) where

import           Data.Bits (Bits(shiftL,shiftR,complement,clearBit,(.&.)))
import           Data.List (foldl')
import           Data.Word (Word8,Word16,Word32)
import qualified Data.ByteString        as S
import qualified Data.ByteString.Lazy   as L
import qualified Data.ByteString.Short  as Sh
import qualified Data.ByteString.Unsafe as S


-- | Intermediate state of an incremental checksum computation, threaded
-- through 'extendChecksum' and consumed by 'finalizeChecksum'.
data PartialChecksum = PartialChecksum
  { pcAccum :: {-# UNPACK #-} !Word32
    -- ^ Running sum of the (high byte, low byte) 16-bit words seen so far,
    -- not yet folded down to 16 bits.
  , pcCarry :: !(Maybe Word8)
    -- ^ A byte left over from an odd-length input, waiting to be paired
    -- with the next byte (or padded with zero when finalizing).
  } deriving (Eq, Show)

-- | The initial state: a zero sum and no pending odd byte.
emptyPartialChecksum :: PartialChecksum
emptyPartialChecksum = PartialChecksum 0 Nothing

-- | Produce the final 16-bit checksum from an accumulated state.
--
-- A pending odd byte is first paired with a zero low byte. The 32-bit sum is
-- then folded to 16 bits twice (the second fold absorbs any carry produced by
-- the first), and the one's complement of the result is returned.
finalizeChecksum :: PartialChecksum -> Word16
finalizeChecksum (PartialChecksum acc pending) =
  complement (fromIntegral (foldCarry (foldCarry padded)))
  where
  padded = maybe acc (\b -> stepChecksum acc b 0) pending

  foldCarry :: Word32 -> Word32
  foldCarry w = (w `shiftR` 16) + (w .&. 0xFFFF)
{-# INLINE finalizeChecksum #-}


-- | One-shot checksum: extend the empty state with the whole value, then
-- finalize.
computeChecksum :: Checksum a => a -> Word16
computeChecksum = finalizeChecksum . (`extendChecksum` emptyPartialChecksum)
{-# INLINE computeChecksum #-}

-- | Types whose bytes can be fed into an incremental checksum computation.
class Checksum a where
  extendChecksum
    :: a               -- ^ value whose bytes are added to the sum
    -> PartialChecksum -- ^ state before this value
    -> PartialChecksum -- ^ state after this value


-- | Two bytes to be checksummed in order, high byte first.
data Pair8 = Pair8
  !Word8 -- ^ first (high-order) byte
  !Word8 -- ^ second (low-order) byte

instance Checksum Pair8 where
  -- With no pending byte, hi and lo form one word directly. With a pending
  -- byte, it pairs with hi and lo becomes the new pending byte.
  extendChecksum (Pair8 hi lo) = \pc ->
    let acc = pcAccum pc
    in  case pcCarry pc of
          Just pending -> PartialChecksum (stepChecksum acc pending hi) (Just lo)
          Nothing      -> PartialChecksum (stepChecksum acc hi lo) Nothing
  {-# INLINE extendChecksum #-}

instance Checksum Word16 where
  -- Split into bytes, high byte first, and feed them as a pair.
  extendChecksum w =
    extendChecksum (Pair8 (fromIntegral (w `shiftR` 8)) (fromIntegral w))
  {-# INLINE extendChecksum #-}

instance Checksum Word32 where
  -- Feed the upper 16 bits first, then the lower 16 bits.
  extendChecksum w = extendChecksum lo16 . extendChecksum hi16
    where
    hi16 = fromIntegral (w `shiftR` 16) :: Word16
    lo16 = fromIntegral w               :: Word16
  {-# INLINE extendChecksum #-}
  
instance Checksum a => Checksum [a] where
  -- Elements are consumed left to right with a strict accumulator.
  extendChecksum xs = \start -> foldl' (\acc x -> extendChecksum x acc) start xs
  {-# INLINE extendChecksum #-}

instance Checksum L.ByteString where
  -- Checksum the strict chunks in order; a byte left over at an odd chunk
  -- boundary is carried across chunks by the 'pcCarry' field.
  extendChecksum lbs = extendChecksum (L.toChunks lbs)
  {-# INLINE extendChecksum #-}

-- XXX: this could be faster if it walked the ShortByteString directly,
-- mirroring the strict S.ByteString instance, instead of copying it first.
instance Checksum Sh.ShortByteString where
  -- Copy into a strict ByteString and delegate to that instance.
  extendChecksum = extendChecksum . Sh.fromShort


instance Checksum S.ByteString where
  -- Bytes are summed as (high, low) 16-bit words. A byte left pending by an
  -- earlier extension is first paired with this string's leading byte; if an
  -- odd byte remains at the end, it is stored in 'pcCarry' for the next
  -- extension (or for 'finalizeChecksum' to pad).
  extendChecksum bytes pc
    | S.null bytes = pc
    | Just pending <- pcCarry pc =
        let paired = stepChecksum (pcAccum pc) pending (S.unsafeIndex bytes 0)
        in  extendChecksum (S.tail bytes)
              PartialChecksum { pcAccum = paired, pcCarry = Nothing }
    | otherwise =
        PartialChecksum { pcAccum = sumWords (pcAccum pc) 0
                        , pcCarry = leftover }
    where
    len     = S.length bytes
    -- length of the largest even-sized prefix; sumWords never reads past it
    evenLen = 2 * (len `quot` 2)

    leftover
      | even len  = Nothing
      | otherwise = Just $! S.unsafeIndex bytes evenLen

    sumWords !acc i
      | i >= evenLen = acc
      | otherwise    =
          sumWords (stepChecksum acc (S.unsafeIndex bytes i)
                                     (S.unsafeIndex bytes (i + 1)))
                   (i + 2)

-- | Add the 16-bit word formed by @hi@ (high byte) and @lo@ (low byte) to a
-- 32-bit running sum.
--
-- A carry out of bit 31 is added back into bit 0 (end-around carry). Since
-- 2^32 is congruent to 1 modulo 0xFFFF, this keeps the accumulator congruent
-- to the true one's complement sum. Folding it to 16 bits therefore stays
-- correct even for inputs large enough (roughly 128 KiB or more) to overflow
-- a plain 'Word32' sum. When no overflow occurs, the result is exactly
-- @acc + (hi * 256 + lo)@.
stepChecksum :: Word32 -> Word8 -> Word8 -> Word32
stepChecksum acc hi lo
  | s < acc   = s + 1  -- wrapped past 2^32: fold the lost carry back in
  | otherwise = s
  where
  s = acc + (fromIntegral hi `shiftL` 8) + fromIntegral lo
{-# INLINE stepChecksum #-}