summaryrefslogtreecommitdiff
path: root/src/Network/RPC.hs
blob: 2333766a826ab8d8e1238aecf778f8d1bdb55f51 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
{-# LANGUAGE ConstraintKinds        #-}
{-# LANGUAGE DeriveDataTypeable         #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveFoldable             #-}
{-# LANGUAGE DeriveTraversable          #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE RankNTypes             #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE StandaloneDeriving     #-}
module Network.RPC where

import Control.Applicative
import qualified Text.ParserCombinators.ReadP as RP
import Data.Digest.CRC32C
import Data.Word
import Data.Monoid
import Data.Hashable
import Data.String
import Data.Bits
import Data.ByteString (ByteString)
import Data.Kind       (Constraint)
import Data.Data
import Data.Default
import Data.List.Split
import Data.Ord
import Data.IP
import Network.Socket
import Text.PrettyPrint as PP hiding ((<>))
import Text.PrettyPrint.HughesPJClass  hiding (($$), (<>))
import Text.Read        (readMaybe)
import Data.Serialize                  as S
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString as BS
import Data.ByteString.Base16          as Base16
import System.Entropy

-- | Host address types that can be converted to and from a 'SockAddr'.
--
-- The superclasses let an address be compared, serialized, hashed,
-- pretty-printed and inspected through 'Typeable'.
class (Eq a, Serialize a, Typeable a, Hashable a, Pretty a)
    => Address a where
  -- | Convert to a socket address.  The instances in this module put a
  -- zero port into the result.
  toSockAddr   :: a        -> SockAddr
  -- | Recover an address from a socket address, or 'Nothing' when the
  -- 'SockAddr' constructor is not one this type accepts.
  fromSockAddr :: SockAddr -> Maybe a

-- | Convert one address type into another by going through 'SockAddr'.
-- The result is whatever the target type's 'fromSockAddr' makes of the
-- intermediate socket address ('Nothing' if it rejects it).
fromAddr :: (Address a, Address b) => a -> Maybe b
fromAddr addr = fromSockAddr (toSockAddr addr)

-- | Note that port is zeroed.  Only 'SockAddrInet' values convert back;
-- every other constructor yields 'Nothing'.
instance Address IPv4 where
  toSockAddr ip = SockAddrInet 0 (toHostAddress ip)
  fromSockAddr sa = case sa of
    SockAddrInet _port host -> Just (fromHostAddress host)
    _                       -> Nothing

-- | Note that port is zeroed (as are flow info and scope id).  Only
-- 'SockAddrInet6' values convert back; everything else yields 'Nothing'.
instance Address IPv6 where
  toSockAddr ip = SockAddrInet6 0 0 (toHostAddress6 ip) 0
  fromSockAddr sa = case sa of
    SockAddrInet6 _port _flow host _scope -> Just (fromHostAddress6 host)
    _                                     -> Nothing

-- | Note that port is zeroed.  Delegates to the 'IPv4' / 'IPv6' instances;
-- when converting back, the IPv4 interpretation is tried first.
instance Address IP where
  toSockAddr ip = case ip of
    IPv4 v4 -> toSockAddr v4
    IPv6 v6 -> toSockAddr v6
  fromSockAddr sa = asV4 <|> asV6
    where
      asV4 = IPv4 <$> fromSockAddr sa
      asV6 = IPv6 <$> fromSockAddr sa




-- | The kind of an RPC message, as reported by 'envelopeClass'.
data MessageClass = Error | Query | Response
 deriving (Eq,Ord,Enum,Bounded,Data,Show,Read)

-- | A message wrapper parameterised over its payload type @a@.  Each instance
-- fixes the transaction-id type and the node-id type for its protocol.
class Envelope envelope where
    -- | Type of the value returned by 'envelopeTransaction'.
    type TransactionID envelope
    -- | Node identifier type for this envelope format (associated data
    -- family, so it is injective in @envelope@).
    data NodeId envelope

    -- | The payload carried by the message.
    envelopePayload     :: envelope a -> a
    -- | The message's transaction identifier.
    envelopeTransaction :: envelope a -> TransactionID envelope
    -- | Whether the message is an error, a query or a response.
    envelopeClass       :: envelope a -> MessageClass

    -- | > buildReply self addr qry response
    --
    --    [ self ]     this node's id.
    --
    --    [ addr ]     SockAddr of query origin.
    --
    --    [ qry ]      received query message.
    --
    --    [ response ] response payload.
    --
    -- Returns: response message envelope
    buildReply :: NodeId envelope -> SockAddr -> envelope a -> b -> envelope b

-- | In Kademlia, the distance metric is XOR and the result is
-- interpreted as an unsigned integer.
--
-- 'Eq' and 'Ord' are derived, so distances compare exactly as the wrapped
-- @nodeid@ values do.  NOTE(review): this matches the unsigned-integer
-- reading only if @nodeid@'s own 'Ord' does; confirm for each node-id type.
newtype NodeDistance nodeid = NodeDistance nodeid
  deriving (Eq, Ord)

-- | distance(A,B) = |A xor B|.  Smaller values are closer.
distance :: Bits nid => nid -> nid -> NodeDistance nid
distance x y = NodeDistance (x `xor` y)

-- | Shows a distance as the Base16 (hex) encoding of its serialized bytes.
instance Serialize nodeid => Show (NodeDistance nodeid) where
  show (NodeDistance d) = Char8.unpack hex
    where
      hex = Base16.encode (S.encode d)

-- | Pretty-prints the same hex rendering that 'show' produces.
instance Serialize nodeid => Pretty (NodeDistance nodeid) where
  pPrint = PP.text . show


-- | 'put' writes the bare IPv4 or IPv6 encoding with no tag or length.
-- Consequently, when 'get'ing an IP the parser must be 'isolate'd to the
-- right number of bytes: the address family is chosen purely from how many
-- bytes remain (4 for IPv4, 16 for IPv6; anything else fails).
instance Serialize IP where
    put ip = case ip of
      IPv4 v4 -> put v4
      IPv6 v6 -> put v6

    get = remaining >>= bySize
      where
        bySize :: Int -> Get IP
        bySize 4  = IPv4 <$> get
        bySize 16 = IPv6 <$> get
        bySize n  = fail (show n ++ " is the wrong number of remaining bytes to parse IP")

-- | Four bytes.  The 'HostAddress' word is written in host byte order.
-- The network package documents that this word holds the address in
-- network byte order in memory, so the octets come out in wire order.
instance Serialize IPv4 where
    put ip = putWord32host (toHostAddress ip)
    get    = fmap fromHostAddress getWord32host

-- | Encoded via the 'Serialize' instance of the 'HostAddress6' tuple of
-- four 'Word32's.
instance Serialize IPv6 where
    put = put . toHostAddress6
    get = fromHostAddress6 `fmap` get

-- | IPv4 addresses pretty-print as their 'Show' text.
instance Pretty IPv4 where
  pPrint ip = PP.text (show ip)
  {-# INLINE pPrint #-}

-- | IPv6 addresses pretty-print as their 'Show' text.
instance Pretty IPv6 where
  pPrint ip = PP.text (show ip)
  {-# INLINE pPrint #-}

-- | Either address family pretty-prints as its 'Show' text.
instance Pretty IP where
  pPrint ip = PP.text (show ip)
  {-# INLINE pPrint #-}

-- | Hashes the underlying 'HostAddress' word.
instance Hashable IPv4 where
  hashWithSalt salt ip = hashWithSalt salt (toHostAddress ip)
  {-# INLINE hashWithSalt #-}

-- | Hashes the underlying 'HostAddress6' tuple.
instance Hashable IPv6 where
  hashWithSalt = hashUsing toHostAddress6

-- | Delegates to the hash of the wrapped 'IPv4' or 'IPv6' value.
instance Hashable IP where
  hashWithSalt salt ip = case ip of
    IPv4 v4 -> hashWithSalt salt v4
    IPv6 v6 -> hashWithSalt salt v6





-- | A host address paired with a port number.  Both fields are strict and
-- the port is unpacked.  The derived 'Ord' compares the host first, then
-- the port.  'Functor' / 'Foldable' / 'Traversable' act on the host.
data NodeAddr a = NodeAddr
  { nodeHost ::                !a          -- ^ host address
  , nodePort :: {-# UNPACK #-} !PortNumber -- ^ port number
  } deriving (Eq, Ord, Typeable, Functor, Foldable, Traversable)

-- | Renders as @host:port@.
--
-- The pieces are chained with function composition ('.'), so the
-- continuation string passed to 'showsPrec' is appended exactly once.
-- Combining 'ShowS' values with '<>' would instead use the pointwise
-- 'Semigroup' instance for functions. That copies the continuation after
-- every piece and corrupts output such as @show [addr1, addr2]@ or a
-- derived record 'Show' that embeds a 'NodeAddr'.
instance Show a => Show (NodeAddr a) where
  showsPrec i (NodeAddr host port)
    = showsPrec i host . showString ":" . showsPrec i port

-- | Reads @host:port@.  The port is read as an 'Int' and must lie in the
-- 'PortNumber' range 0-65535; anything outside it is a parse failure.
-- Converting an out-of-range value with 'toEnum' would either wrap or
-- crash later, when the strict port field is forced.
instance Read (NodeAddr IPv4) where
  readsPrec i = RP.readP_to_S $ do
    ipv4    <- RP.readS_to_P (readsPrec i)
    _       <- RP.char ':'
    portInt <- RP.readS_to_P (readsPrec i) :: RP.ReadP Int
    if portInt < 0 || portInt > fromIntegral (maxBound :: Word16)
      then RP.pfail
      else return $ NodeAddr ipv4 (fromIntegral portInt)

-- | @127.0.0.1:6882@, built with the 'IsString' instance.
instance Default (NodeAddr IPv4) where
  def = fromString "127.0.0.1:6882"

-- | KRPC compatible encoding: the host immediately followed by the port,
-- with no separator or length prefix.
instance Serialize a => Serialize (NodeAddr a) where
  get = liftA2 NodeAddr get get
  {-# INLINE get #-}
  put (NodeAddr host port) = put host *> put port
  {-# INLINE put #-}

-- | Parses @\"host:port\"@.  Example:
--
--   @nodePort \"127.0.0.1:6881\" == 6881@
--
-- Calls 'error' in three cases: the string does not split on @':'@ into
-- exactly two parts, either part fails to parse, or the port is outside
-- the 0-65535 range of a 'PortNumber'.
instance IsString (NodeAddr IPv4) where
  fromString str
    | [hostAddrStr, portStr] <- splitWhen (== ':') str
    , Just hostAddr <- readMaybe hostAddrStr
    , Just portInt  <- readMaybe portStr
    , validPort portInt
                = NodeAddr hostAddr (fromIntegral portInt)
    | otherwise = error $ "fromString: unable to parse (NodeAddr IPv4): " ++ str
    where
      -- Range-check before narrowing to 'PortNumber'.  A plain 'toEnum'
      -- would wrap or fail with an unrelated message when the strict port
      -- field is forced.
      validPort :: Int -> Bool
      validPort p = p >= 0 && p <= fromIntegral (maxBound :: Word16)

-- | Hashes a port through its 'Int' value ('fromEnum').
instance Hashable PortNumber where
  hashWithSalt = hashUsing fromEnum
  {-# INLINE hashWithSalt #-}

-- | Prints a port as a decimal integer.
instance Pretty PortNumber where
  pPrint port = PP.int (fromEnum port)
  {-# INLINE pPrint #-}


-- | Hashes the @(host, port)@ pair.
instance Hashable a => Hashable (NodeAddr a) where
  hashWithSalt salt addr = hashWithSalt salt (nodeHost addr, nodePort addr)
  {-# INLINE hashWithSalt #-}

-- | Renders as @host:port@ with no surrounding spaces.
instance Pretty ip => Pretty (NodeAddr ip) where
  pPrint (NodeAddr host port) = PP.hcat [pPrint host, PP.char ':', pPrint port]


-- | A port is encoded as a big-endian 16-bit word.
instance Serialize PortNumber where
  get = fmap fromIntegral getWord16be
  {-# INLINE get #-}
  put port = putWord16be (fromIntegral port)
  {-# INLINE put #-}




-- | A DHT node: its identifier, its network address and a caller-chosen
-- annotation.
--
-- The derived 'Functor' / 'Foldable' / 'Traversable' instances act on the
-- annotation @u@ (the last type parameter).  To change the address type,
-- use 'mapAddress' / 'traverseAddress'.
data NodeInfo dht addr u = NodeInfo
  { nodeId   :: !(NodeId dht)      -- ^ node identifier (strict)
  , nodeAddr :: !(NodeAddr addr)   -- ^ host and port (strict)
  , nodeAnnotation :: u            -- ^ arbitrary extra data (lazy)
  } deriving (Functor, Foldable, Traversable)

-- | Derived 'Show'.  Written as a standalone instance because the context
-- mentions the data family @NodeId dht@, which a @deriving@ clause cannot
-- express.
deriving instance ( Show (NodeId dht)
                  , Show addr
                  , Show u ) => Show (NodeInfo dht addr u)

-- | Apply a function to a node's host address.  The id, the port and the
-- annotation are kept as they are.
mapAddress :: (addr -> b) -> NodeInfo dht addr u -> NodeInfo dht b u
mapAddress f (NodeInfo nid addr ann) = NodeInfo nid (fmap f addr) ann

-- | Effectful counterpart of 'mapAddress': run @f@ on the host address and
-- rebuild the node around the result.
traverseAddress :: Applicative f => (addr -> f b) -> NodeInfo dht addr u -> f (NodeInfo dht b u)
traverseAddress f ni = withAddr <$> traverse f (nodeAddr ni)
  where
    withAddr a = ni { nodeAddr = a }

-- Warning: Eq and Ord only look at the nodeId field.  The address and the
-- annotation are ignored, so two entries with the same id but different
-- addresses compare equal.
instance Eq (NodeId dht) => Eq (NodeInfo dht a u) where
    x == y = nodeId x == nodeId y

instance Ord (NodeId dht) => Ord (NodeInfo dht a u) where
  compare x y = compare (nodeId x) (nodeId y)


-- TODO WARN is the 'system' random suitable for this?
-- | Generate random NodeID used for the entire session.
--   Distribution of ID's should be as uniform as possible.
--
--   Reads one byte of system entropy per 8 bits of the id ('finiteBitSize'
--   divided by 8) and decodes it; a decode failure is raised with 'error'.
genNodeId :: forall dht.
             ( Serialize (NodeId dht)
             , FiniteBits (NodeId dht)
             ) => IO (NodeId dht)
genNodeId = do
    bytes <- getEntropy byteCount
    pure (either error id (S.decode bytes))
  where
    byteCount = finiteBitSize (undefined :: NodeId dht) `div` 8

-- | Generate a random 'NodeId' within a range suitable for a bucket.  To
-- obtain a sample for bucket number /index/ where /is_last/ indicates if this
-- is for the current deepest bucket in our routing table:
--
-- > sample <- genBucketSample nid (bucketRange index is_last)
--
-- This is 'genBucketSample'' specialised to the system entropy source
-- ('getEntropy').
genBucketSample :: ( FiniteBits (NodeId dht)
                   , Serialize (NodeId dht)
                   ) => NodeId dht -> (Int,Word8,Word8) -> IO (NodeId dht)
genBucketSample = genBucketSample' getEntropy

-- | Generalizion of 'genBucketSample' that accepts a byte generator
-- function to use instead of the system entropy.
--
-- Let @nodeIdSize@ be the id's 'finiteBitSize' divided by 8.  Depending on
-- @q@, the triple @(q, m, b)@ produces:
--
--   * @q <= 0@: a fully random id decoded from @gen nodeIdSize@;
--
--   * @q >= nodeIdSize@: @self@ unchanged (the generator is not called);
--
--   * otherwise: an id assembled from three parts:
--
--       - the first @q - 1@ bytes, copied from @self@;
--       - one boundary byte at index @q - 1@: @self@'s bits outside mask @m@,
--         xor'ed with @b@, OR'ed with random bits inside @m@;
--       - @nodeIdSize - q@ further random bytes.
--
-- Decode failures are raised with 'error'.  NOTE(review): @gen n@ is
-- assumed to return exactly @n@ bytes; this is not checked here.
genBucketSample' :: forall m dht.
                    ( Applicative m
                    , FiniteBits (NodeId dht)
                    , Serialize (NodeId dht)
                    ) =>
    (Int -> m ByteString) -> NodeId dht -> (Int,Word8,Word8) -> m (NodeId dht)
genBucketSample' gen self (q,m,b)
    | q <= 0           =  either error id . S.decode <$> gen nodeIdSize
    | q >= nodeIdSize  =  pure self
    -- One byte more than the random suffix needs: its head is masked into
    -- the boundary byte, its tail becomes the suffix.
    | otherwise        =  either error id . S.decode .  build <$> gen (nodeIdSize - q + 1)
 where
    nodeIdSize = finiteBitSize (undefined :: NodeId dht) `div` 8
    -- (q-1 bytes of self) ++ [boundary byte] ++ (nodeIdSize - q random bytes).
    -- BS.init/last/head/tail are only reached with 1 <= q < nodeIdSize, so
    -- hd and tl are non-empty (assuming S.encode self has nodeIdSize bytes
    -- and gen honours the requested length).
    build tl = BS.init hd <> BS.cons (h .|. t) (BS.tail tl)
     where
        hd = BS.take q $ S.encode self
        -- self's boundary byte with the bits in m cleared, then xor'ed with b.
        h = xor b (complement m .&. BS.last hd)
        -- Random bits restricted to the mask m.
        t = m .&. BS.head tl


-- | Conversion between bytes and envelopes, via an intermediate @raw@
-- payload representation.
--
-- 'decodeHeaders' yields an envelope whose payload is still @raw@, and
-- 'decodePayload' turns such an envelope into one with a typed payload.
-- On the way out, 'encodePayload' produces a @raw@ envelope that
-- 'encodeHeaders' renders to bytes.
class Envelope envelope => WireFormat raw envelope where
    -- | Constraint a payload type must satisfy to convert to/from @raw@.
    type SerializableTo raw :: * -> Constraint
    -- | Extra context passed to 'decodeHeaders' / 'encodeHeaders'.
    -- NOTE(review): the name suggests encryption state; confirm against
    -- instances.
    type CipherContext raw envelope

    -- | Parse bytes into an envelope with a raw payload; 'Left' holds an
    -- error message.
    decodeHeaders :: CipherContext raw envelope -> ByteString -> Either String (envelope raw)
    -- | Convert a raw payload to a typed one; 'Left' holds an error message.
    decodePayload :: SerializableTo raw a => envelope raw -> Either String (envelope a)

    -- | Render an envelope with a raw payload to bytes.
    encodeHeaders :: CipherContext raw envelope -> envelope raw -> ByteString
    -- | Convert a typed payload to its raw representation.
    encodePayload :: SerializableTo raw a => envelope a -> envelope raw