diff --git a/chocolate-master b/chocolate-master index 930407f..0b627e9 100755 --- a/chocolate-master +++ b/chocolate-master @@ -55,6 +55,7 @@ NET_MASTER_PACKET_TYPE_SIGN_START_RESPONSE = 7 NET_MASTER_PACKET_TYPE_SIGN_END = 8 NET_MASTER_PACKET_TYPE_SIGN_END_RESPONSE = 9 NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH = 10 +NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH_ALL = 11 def bind_socket_to(sock, config): """ Bind the specified socket to the address/port configuration from @@ -288,14 +289,12 @@ class MasterServer: # Generate a list of strings representing servers. Only include # verified servers. - verified_servers = filter(lambda s: s.verified, self.servers.values()) - strings = [ str(server) for server in verified_servers] + verified_servers = [s for s in self.servers.values() if s.verified] + strings = [str(server) for server in verified_servers] # Send response packets. - for packet in self.strings_to_packets(strings): - self.send_message(addr, - NET_MASTER_PACKET_TYPE_QUERY_RESPONSE, + self.send_message(addr, NET_MASTER_PACKET_TYPE_QUERY_RESPONSE, packet) def process_metadata_request(self, addr): @@ -393,6 +392,16 @@ class MasterServer: # Forward hole punch request to the server: self.send_hole_punch(self.servers[server_addr], addr) + def process_hole_punch_all(self, addr): + """Process a hole punch request for all servers.""" + # For NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH_ALL, we send hole punch + # requests on behalf of the client to all servers we have flagged as + # requiring hole punch assistance to contact. + self.log_output(addr, "Mass hole punch request") + for server in self.servers.values(): + if server.needs_hole_punch: + self.send_hole_punch(server, addr) + def process_packet(self, data, addr): """ Process a packet received from a server. """ @@ -410,6 +419,8 @@ class MasterServer: self.sign_end_message(data[2:], addr) elif packet_type == NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH: self.process_hole_punch(data[2:], addr) + elif packet_type == NET_MASTER_PACKET_TYPE_NAT_HOLE_PUNCH_ALL: + self.process_hole_punch_all(addr) def is_blocked(self, addr): addr_str = "%s:%i" % addr