diff --git a/chocolate-master b/chocolate-master index d946be0..54e6b72 100755 --- a/chocolate-master +++ b/chocolate-master @@ -40,6 +40,7 @@ MAX_RESPONSE_LEN = 1400 NET_PACKET_TYPE_QUERY = 13 NET_PACKET_TYPE_QUERY_RESPONSE = 14 +NET_PACKET_TYPE_NAT_HOLE_PUNCH = 16 # Packet types, matches the constants in net_defs.h. @@ -53,6 +54,7 @@ NET_MASTER_PACKET_TYPE_SIGN_START = 6 NET_MASTER_PACKET_TYPE_SIGN_START_RESPONSE = 7 NET_MASTER_PACKET_TYPE_SIGN_END = 8 NET_MASTER_PACKET_TYPE_SIGN_END_RESPONSE = 9 +NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH = 10 def bind_socket_to(sock, config): """ Bind the specified socket to the address/port configuration from @@ -356,6 +358,31 @@ class MasterServer: self.send_message(addr, NET_MASTER_PACKET_TYPE_SIGN_END_RESPONSE, signature) + def send_hole_punch(self, server, client_addr): + """Send a hole punch request to a server on behalf of a client.""" + client_addr_str = "%s:%d" % client_addr + packet = client_addr_str.encode("utf8") + b'\0' + self.send_message(server.addr, NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH, + packet) + self.log_output(client_addr, "Sent hole punch to %s" % server) + + def process_hole_punch(self, data, addr): + """Process a NAT hole punch request from a client.""" + + # Packet just contains the address of the the server. Check it's really + # a server that we have registered. + _, server_addr_str = read_string(data) + self.log_output(addr, "Hole punch request for %r" % server_addr_str) + a, p = server_addr_str.split(":", 1) + server_addr = (a, int(p)) + + if server_addr not in self.servers: + self.log_output(addr, "Unknown server to hole punch") + return + + # Forward hole punch request to the server: + self.send_hole_punch(self.servers[server_addr], addr) + def process_packet(self, data, addr): """ Process a packet received from a server. """ @@ -371,6 +398,8 @@ class MasterServer: self.sign_start_message(addr) elif packet_type == NET_MASTER_PACKET_TYPE_SIGN_END: self.sign_end_message(data[2:], addr) + elif packet_type == NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH: + self.process_hole_punch(data[2:], addr) def is_blocked(self, addr): addr_str = "%s:%i" % addr diff --git a/chocolate-master-test.py b/chocolate-master-test.py index a7b4f21..6a3b23c 100755 --- a/chocolate-master-test.py +++ b/chocolate-master-test.py @@ -28,6 +28,8 @@ import sys import struct import json +NET_PACKET_TYPE_NAT_HOLE_PUNCH = 16 + NET_MASTER_PACKET_TYPE_ADD = 0 NET_MASTER_PACKET_TYPE_ADD_RESPONSE = 1 NET_MASTER_PACKET_TYPE_QUERY = 2 @@ -38,6 +40,7 @@ NET_MASTER_PACKET_TYPE_SIGN_START = 6 NET_MASTER_PACKET_TYPE_SIGN_START_RESPONSE = 7 NET_MASTER_PACKET_TYPE_SIGN_END = 8 NET_MASTER_PACKET_TYPE_SIGN_END_RESPONSE = 9 +NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH = 10 UDP_PORT = 2342 @@ -222,17 +225,39 @@ def sign_end(addr_str): NET_MASTER_PACKET_TYPE_SIGN_END_RESPONSE) print(response) +def hole_punch(master_addr_str, server_addr_str): + """Send a NAT hole punch request to the master server.""" + + master_addr = parse_address(master_addr_str) + server_addr = parse_address(server_addr_str) + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + # Send request + print("Sending hole punch request to master at %s" % str(master_addr)) + packet = server_addr_str.encode("utf8") + b"\x00" + send_message(sock, master_addr, NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH, + packet) + + # Wait for response. + print("Waiting for response...") + response = get_response(sock, server_addr, + NET_PACKET_TYPE_NAT_HOLE_PUNCH) + + print("Got hole punch request from server via master server.") + + commands = [ ("query", query_master), ("add", add_to_master), ("get-metadata", get_metadata), ("sign-start", sign_start), ("sign-end", sign_end), + ("hole-punch", hole_punch), ] for name, callback in commands: if len(sys.argv) > 2 and name == sys.argv[1]: - callback(sys.argv[2]) + callback(*sys.argv[2:]) break else: print("Usage:") @@ -241,4 +266,5 @@ else: print("chocolate-master-test.py get-metadata
") print("chocolate-master-test.py sign-start ") print("chocolate-master-test.py sign-end ") + print("chocolate-master-test.py hole-punch