summaryrefslogtreecommitdiff
path: root/Mainline.hs
blob: d24b3376fb7602e38282e5ebea4c36b158244a86 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
{-# LANGUAGE DeriveDataTypeable         #-}
{-# LANGUAGE DeriveFoldable             #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveTraversable          #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Mainline where

import Control.Arrow
import Control.Concurrent.STM
import Crypto.Random
import Data.BEncode             as BE
import Data.BEncode.BDict       as BE
import Data.Bool
import Data.ByteArray
import Data.ByteString          (ByteString)
import Data.ByteString          as B
import Data.ByteString.Lazy     (toStrict)
import Data.Data
import Data.IP
import Data.Maybe
import Data.Monoid
import qualified Data.Serialize as S
import Data.Typeable
import Data.Word
import Network.Address          (Address, fromSockAddr, sockAddrPort,
                                 toSockAddr, withPort)
import Network.QueryResponse
import Network.Socket

-- | Identifier of a DHT node, wrapping its raw bytes. The 'BEncode' and
-- 'ByteArrayAccess' instances are inherited from the underlying
-- 'ByteString' through GeneralizedNewtypeDeriving.
newtype NodeId
    = NodeId ByteString
    deriving (Show, Eq, Ord, BEncode, ByteArrayAccess)

-- | A DHT node as seen on the network: its identifier together with the IP
-- address and port that packets for it are sent to (see 'nodeAddr').
data NodeInfo = NodeInfo
    { nodeId   :: NodeId      -- ^ The node's DHT identifier.
    , nodeIP   :: IP          -- ^ IP address of the node.
    , nodePort :: PortNumber  -- ^ Port of the node.
    }

-- | Socket address of a node: its IP address combined with its port.
nodeAddr :: NodeInfo -> SockAddr
nodeAddr ni = withPort (toSockAddr (nodeIP ni)) (nodePort ni)

-- | Build a 'NodeInfo' from a node id and a socket address.
--
-- Fails with a message when either the IP address or the port cannot be
-- extracted from the socket address.
nodeInfo :: NodeId -> SockAddr -> Either String NodeInfo
nodeInfo nid saddr =
    maybe (Left "Address family not supported.") Right $
        NodeInfo nid <$> fromSockAddr saddr <*> sockAddrPort saddr

-- | Types of RPC errors.
data ErrorCode
    -- | Some error that doesn't fit in any other category.
  = GenericError

    -- | The server failed to process the procedure call.
  | ServerError

    -- | Malformed packet, invalid arguments or bad token.
  | ProtocolError

    -- | The client called a method the server doesn't know.
  | MethodUnknown
    deriving (Show, Read, Eq, Ord, Bounded, Typeable, Data)

-- | Wire codes, according to the table at
-- <http://bittorrent.org/beps/bep_0005.html#errors>.
--
-- 'toEnum' is deliberately lenient: any code outside 201-204 maps to
-- 'GenericError' instead of raising an error.
--
-- 'enumFrom' and 'enumFromThen' are bounded explicitly. The class defaults
-- enumerate @[fromEnum x ..]@ with no upper limit. Combined with the lenient
-- 'toEnum', they would make e.g. @[GenericError ..]@ an infinite list.
instance Enum ErrorCode where
  fromEnum GenericError  = 201
  fromEnum ServerError   = 202
  fromEnum ProtocolError = 203
  fromEnum MethodUnknown = 204
  {-# INLINE fromEnum #-}
  toEnum 201 = GenericError
  toEnum 202 = ServerError
  toEnum 203 = ProtocolError
  toEnum 204 = MethodUnknown
  toEnum _   = GenericError
  {-# INLINE toEnum #-}
  enumFrom x = enumFromTo x maxBound
  enumFromThen x y = enumFromThenTo x y limit
    where
      -- Stop at the end of the range we are stepping towards.
      limit
        | fromEnum y >= fromEnum x = maxBound
        | otherwise                = minBound

-- | Error codes are bencoded as integers, using the numeric mapping of the
-- 'Enum' instance.
instance BEncode ErrorCode where
  toBEncode code = toBEncode (fromEnum code)
  {-# INLINE toBEncode #-}
  fromBEncode = fmap toEnum . fromBEncode
  {-# INLINE fromBEncode #-}

-- | An RPC error: a numeric category plus a human-readable message.
data Error = Error
    { errorCode    :: !ErrorCode   -- ^ The type of error.
    , errorMessage :: !ByteString  -- ^ Human-readable text message.
    }
    deriving (Eq, Ord, Read, Show, Typeable, Data)

-- | Opaque transaction identifier, written to the @t@ key of every encoded
-- message (see 'encodeAny').
newtype TransactionId
    = TransactionId ByteString
    deriving (Show, Eq, Ord, BEncode)

-- | Name of a query method, such as @ping@ or @find_node@.
newtype Method
    = Method ByteString
    deriving (Show, Eq, Ord, BEncode)

-- | A KRPC message whose query arguments / response result are of type @a@.
--
-- 'msgOrigin' and 'msgID' exist on both constructors. The other selectors
-- belong to a single constructor, so they are partial.
data Message a
    -- | An outgoing or incoming query.
    = Q { msgOrigin   :: NodeId         -- ^ Id of the sending node.
        , msgID       :: TransactionId  -- ^ Transaction id (@t@).
        , qryPayload  :: a              -- ^ Query arguments (@a@).
        , qryMethod   :: Method         -- ^ Method name (@q@).
        , qryReadOnly :: Bool           -- ^ Emit @ro@ = 1 when 'True'.
        }
    -- | A response carrying either a result or an error.
    | R { msgOrigin      :: NodeId          -- ^ Id of the sending node.
        , msgID          :: TransactionId   -- ^ Transaction id (@t@).
        , rspPayload     :: Either Error a  -- ^ Result (@r@) or error (@e@).
        , rspReflectedIP :: Maybe SockAddr  -- ^ Address reported back as @ip@.
        }

-- | Messages are bencoded with 'encodeMessage'.
--
-- Decoding is not implemented yet. Instead of calling 'error', 'fromBEncode'
-- reports this through its 'Left' result. Callers that already handle
-- decoding failures, such as 'parsePacket' via 'BE.decode', then get an
-- ordinary error value rather than an exception.
instance BE.BEncode (Message BValue) where
    toBEncode = encodeMessage
    fromBEncode _ = Left "fromBEncode: decoding Mainline messages is not implemented"

-- | Encode a message as a top-level bencoded dictionary.
--
-- * Query ('Q'): when the payload is a dictionary, the generic arguments
--   (@id@, and @ro@ for read-only queries) are merged into it, and the
--   result is sent under @a@.
-- * Successful response ('R' with 'Right'): when the result is a
--   dictionary, the sender's @id@ is merged into it, as BEP 5 requires every
--   response to carry one. The result is sent under @r@. The reflected
--   address, when present and encodable, goes under @ip@.
-- * Failed response ('R' with 'Left'): encoded as an @e@ error message.
encodeMessage :: Message BValue -> BValue
encodeMessage (Q origin tid a meth ro)
    = case a of
        BDict args -> encodeQuery tid meth (BDict $ genericArgs origin ro `union` args)
        -- XXX: Not really a valid query: non-dictionary arguments are sent
        -- as-is, without the generic "id"/"ro" entries.
        _          -> encodeQuery tid meth a
encodeMessage (R origin tid v ip)
    = case v of
        Right (BDict vals) -> encodeResponse tid (BDict $ genericArgs origin False `union` vals) rip
        Right vals         -> encodeResponse tid vals rip
        Left  err          -> encodeError tid err
  where
    rip = ip >>= compactIP
    -- 'encodeAddr' yields an empty string for address families it cannot
    -- encode; omit the "ip" entry then instead of sending an empty value.
    compactIP sa
        | B.null bs = Nothing
        | otherwise = Just (BString bs)
      where
        bs = encodeAddr sa

-- | Compact binary form of a socket address: the IP address bytes followed
-- by the port as a 16-bit word. IPv4 gives 6 bytes, IPv6 gives 18 bytes.
-- Any other address family encodes to the empty string.
encodeAddr :: SockAddr -> ByteString
encodeAddr sa = case sa of
    SockAddrInet port host ->
        -- NOTE: relies on 'HostAddress' holding the network-order bytes in
        -- memory (as documented by the network package), so writing it in
        -- host order emits those bytes unchanged.
        thenPort port (S.putWord32host host)
    SockAddrInet6 port _ host _ ->
        thenPort port (S.put host)
    _ -> B.empty
  where
    -- Run the address writer, then append the port. cereal's 'S.put' for
    -- 'Word16' is big-endian.
    thenPort port putHost = S.runPut (putHost >> S.put (fromIntegral port :: Word16))

-- | Arguments merged into every dictionary-shaped query: the sender's @id@,
-- plus @ro@ = 1 when the query is read-only (no @ro@ entry otherwise).
-- NOTE(review): the keys are listed in ascending order; presumably the
-- later 'union' with the caller's arguments relies on that. Confirm.
genericArgs nodeid ro =
       "id" .=! nodeid
    .: "ro" .=? readOnlyFlag
    .: endDict
  where
    readOnlyFlag = if ro then Just (1 :: Int) else Nothing

-- | Error message: @e@ holds the (code, message) pair; nothing is prepended.
encodeError tid (Error ecode emsg) = encodeAny tid "e" (ecode, emsg) id
-- | Response: @r@ holds the result; an optional @ip@ entry is prepended.
encodeResponse tid rvals rip = encodeAny tid "r" rvals (\rest -> ("ip" .=? rip) .: rest)
-- | Query: @q@ holds the method name; the arguments are prepended under @a@.
encodeQuery tid qmeth qargs = encodeAny tid "q" qmeth (\rest -> ("a" .=! qargs) .: rest)

-- | Assemble a top-level message dictionary.
--
-- @key@ is used twice: as the key under which @val@ is stored, and as the
-- value of @y@ (the message type). The entries are the payload key, then
-- @t@, then @y@. @aux@ may put additional entries in front of them.
encodeAny tid key val aux = toDict (aux common)
  where
    common =  key .=! val
           .: "t" .=! tid
           .: "y" .=! key
           .: endDict

-- | Decode a received datagram and pair the message with a 'NodeInfo' built
-- from the message's origin id and the datagram's source address.
parsePacket :: ByteString -> SockAddr -> Either String (Message BValue, NodeInfo)
parsePacket bs addr =
    BE.decode bs >>= \msg -> (,) msg <$> nodeInfo (msgOrigin msg) addr

-- | Serialise a message to strict bytes and address it to the given node.
encodePacket :: Message BValue -> NodeInfo -> (ByteString, SockAddr)
encodePacket msg ni = (payload, nodeAddr ni)
  where
    payload = toStrict (BE.encode msg)

-- | Create a DHT client bound to the given local address.
--
-- The client's own 'NodeInfo' uses a freshly generated random node id plus
-- the IP and port extracted from @addr@ (falling back to @toEnum 0@ and port
-- 0 when they cannot be extracted). Inbound messages are dispatched through
-- 'classify' and 'handlers'. Outgoing transaction ids come from a 'Word16'
-- counter.
newClient ::
      SockAddr -> IO (Client String Method TransactionId NodeInfo (Message BValue))
newClient addr = do
    udp <- udpTransport addr
    -- Tentative node id: 20 random bytes, i.e. a 160-bit id as in BEP 5.
    -- NOTE(review): BEP 42 recommends deriving the id from the node's
    -- external IP once it is known; this random id does not do that yet.
    nid <- NodeId <$> getRandomBytes 20
    self <- atomically $ newTVar
             $ NodeInfo nid (fromMaybe (toEnum 0) $ fromSockAddr addr)
                            (fromMaybe 0 $ sockAddrPort addr)
    let net = layerTransport parsePacket encodePacket udp
        dispatch tbl = DispatchMethods
            { classifyInbound = classify
            , lookupHandler = handlers
            , tableMethods = tbl
            }
        mapT = transactionMethods mapMethods gen
        -- Transaction ids are the cereal encoding of a Word16 counter that
        -- wraps around on overflow.
        gen :: Word16 -> (TransactionId, Word16)
        gen cnt = (TransactionId $ S.encode cnt, cnt+1)
    map_var <- atomically $ newTVar (0, mempty)
    return Client
            { clientNet           = net
            , clientDispatcher    = dispatch mapT
            , clientErrorReporter = ignoreErrors -- TODO
            , clientPending       = map_var
            , clientAddress       = atomically (readTVar self)
            , clientResponseId    = return
            }

-- | Classify an inbound message for the dispatcher. A query is reported
-- with its method and transaction id. Any response, successful or an error,
-- is reported with its transaction id only.
classify :: Message BValue -> MessageClass String Method TransactionId
classify msg = case msg of
    Q {} -> IsQuery (qryMethod msg) (msgID msg)
    R {} -> IsResponse (msgID msg)

-- | Wrap a handler's successful result as a response sent by @self@. The
-- result is bencoded, and @dest@'s socket address is reported back in @ip@.
encodePayload tid self dest b = R { msgOrigin      = nodeId self
                                  , msgID          = tid
                                  , rspPayload     = Right (BE.toBEncode b)
                                  , rspReflectedIP = Just (nodeAddr dest)
                                  }

-- | Wrap an error as a response sent by @self@, reporting @dest@'s socket
-- address back in @ip@.
errorPayload tid self dest e = R { msgOrigin      = nodeId self
                                 , msgID          = tid
                                 , rspPayload     = Left e
                                 , rspReflectedIP = Just (nodeAddr dest)
                                 }

-- | Decode a query's arguments into the handler's argument type.
--
-- Only queries carry arguments. A response yields 'Left' instead of
-- crashing on the partial 'qryPayload' selector.
decodePayload :: BEncode a => Message BValue -> Either String a
decodePayload (Q { qryPayload = args }) = BE.fromBEncode args
decodePayload (R {})                    = Left "decodePayload: expected a query, got a response"

-- | Wrap a handler function as a 'MethodHandler'. Its arguments are decoded
-- with 'decodePayload' and its result is encoded with 'encodePayload'.
-- Not currently referenced by 'handlers'.
handler run = Just (MethodHandler decodePayload encodePayload run)

-- | Look up the handler for a query method.
--
-- @ping@, @find_node@ and @get_peers@ are still stubs: evaluating the lookup
-- result for them raises an error. Every other method gets a handler that
-- answers with a 'MethodUnknown' error naming the method (see 'defaultH').
handlers :: Method -> Maybe (MethodHandler String TransactionId NodeInfo (Message BValue))
handlers (Method meth) = case meth of
    "ping"      -> error "handler pingH"
    "find_node" -> error "find_node"
    "get_peers" -> error "get_peers"
    _           -> Just $ MethodHandler decodePayload errorPayload (defaultH meth)

-- | Placeholder payload for the @ping@ method (see 'pingH'); carries no data.
data Ping
    = Ping

-- | Handler for @ping@ queries. Not implemented yet: evaluating 'pingH'
-- itself raises an error.
pingH :: NodeInfo -> Ping -> IO Ping
pingH = error "pingH"

-- | Fallback handler for unrecognised methods. It ignores the requesting
-- node and the arguments, and replies with a 'MethodUnknown' error whose
-- message names the method.
defaultH :: ByteString -> NodeInfo -> BValue -> IO Error
defaultH meth _ _ = pure (Error MethodUnknown (B.append "Unknown method " meth))