Send hole punch requests when adding servers.

If a server is behind a NAT gateway, we might not get a response when
we send it a query. However, a hole punch request may make it accessible.
After two seconds with no response, try a hole punch request and if we get
a reply, retry the request.

We keep state on which servers require hole punching and only allow hole
punch requests from clients if they are flagged as such.
This commit is contained in:
Simon Howard 2019-02-10 00:06:35 -05:00
parent 214d01ff4e
commit 9dca3ddca1
1 changed files with 39 additions and 15 deletions

View File

@ -92,6 +92,7 @@ class Server:
self.addr = addr
self.add_time = time()
self.verified = False
self.needs_hole_punch = False
self.metadata = {}
self.refresh()
@ -138,6 +139,7 @@ class MasterServer:
self.sock = self.open_socket(server_address)
self.query_sock = self.open_socket(query_address)
self.query_address = query_address
self.block_patterns = block_patterns
if secure_demo.available and SIGNING_KEY:
@ -177,28 +179,32 @@ class MasterServer:
query that we sent to a server. """
# Unknown?
if addr not in self.servers:
return
server = self.servers[addr]
# Check packet type
packet_type, = struct.unpack(">h", data[0:2])
# If we have requested a hole punch from the server and received it,
# try sending another query request; it may succeed now.
if packet_type == NET_PACKET_TYPE_NAT_HOLE_PUNCH:
if not server.verified and server.needs_hole_punch:
self.log_output(server.addr, "Got hole punch; resending query")
self.send_query(server)
return
if packet_type != NET_PACKET_TYPE_QUERY_RESPONSE:
return
# Read metadata from query and store it for future use.
metadata = self.parse_query_data(data[2:])
metadata["address"], metadata["port"] = addr
server.set_metadata(metadata)
# Server responded to our query, so it is verified.
# We can send a positive response to its add request.
if not server.verified:
self.log_output(server.addr, "Server responded to query, added")
server.verified = True
@ -360,11 +366,15 @@ class MasterServer:
def send_hole_punch(self, server, client_addr):
"""Send a hole punch request to a server on behalf of a client."""
# Don't send hole punch requests to servers we added without needing
# hole punching.
if not server.needs_hole_punch:
return
client_addr_str = "%s:%d" % client_addr
packet = client_addr_str.encode("utf8") + b'\0'
self.send_message(server.addr, NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH,
packet)
self.log_output(client_addr, "Sent hole punch to %s" % server)
self.log_output(client_addr, "Requested hole punch from %s" % server)
def process_hole_punch(self, data, addr):
"""Process a NAT hole punch request from a client."""
@ -425,6 +435,27 @@ class MasterServer:
except Exception as e:
print("error on query socket packet from %s: %s" % (addr, e))
def check_unverified_server(self, server):
"""Check the given server that has not yet been verified."""
now = time()
# After 2 seconds, send a hole punch request to the server for the
# query address. Our queries have gone unanswered but if the server
# responds to hole punch requests we may be able to try again and
# get a response.
if (not server.needs_hole_punch and self.query_address
and now - server.refresh_time > 2):
server.needs_hole_punch = True
self.send_hole_punch(server, self.query_address)
# After 5 seconds, if we get no response at all then the add request
# is rejected.
if now - server.refresh_time > 5:
self.log_output(server.addr,
"No response to query, add rejected")
self.send_add_response(server, 0)
del self.servers[server.addr]
def age_servers(self):
""" Check server timestamps and flush out stale servers. """
servers = list(self.servers.values())
@ -434,15 +465,8 @@ class MasterServer:
"Timed out: no heartbeat in %i secs" %
(time() - server.refresh_time))
del self.servers[server.addr]
# Expect a response to queries quickly, otherwise add
# requests are rejected.
elif not server.verified and time() - server.refresh_time > 5:
self.log_output(server.addr,
"No response to query, add rejected")
self.send_add_response(server, 0)
del self.servers[server.addr]
elif not server.verified:
self.check_unverified_server(server)
def open_socket(self, address):
""" Open a server socket and bind to the specified address. """
@ -458,7 +482,7 @@ class MasterServer:
self.log_output(None, "Server started.")
while True:
r, w, x = select([self.sock, self.query_sock], [], [], 5)
r, w, x = select([self.sock, self.query_sock], [], [], 1)
self.age_servers()