diff --git a/chocolate-master b/chocolate-master index 9d6646e..50e01a3 100755 --- a/chocolate-master +++ b/chocolate-master @@ -39,6 +39,11 @@ SERVER_TIMEOUT = 2 * 60 * 60 # 2 hours MAX_RESPONSE_LEN = 1000 +# Normal packet types. + +NET_PACKET_TYPE_QUERY = 13 +NET_PACKET_TYPE_QUERY_RESPONSE = 14 + # Packet types, matches the constants in net_defs.h. NET_MASTER_PACKET_TYPE_ADD = 0 @@ -56,6 +61,7 @@ class Server: def __init__(self, addr): self.addr = addr + self.verified = False self.refresh() def refresh(self): @@ -90,6 +96,42 @@ class MasterServer: self.server_addr = (address, port) self.open_socket() + def send_query(self, server): + """ Send a query to the specified server. """ + + packet = struct.pack(">h", NET_PACKET_TYPE_QUERY) + + self.query_sock.sendto(packet, server.addr) + + def parse_query_response(self, data, addr): + """ Parse a packet received (presumably) in response to a + query that we sent to a server. """ + + # Unknown? + + if addr not in self.servers: + return + + server = self.servers[addr] + + # Check packet type + + packet_type, = struct.unpack(">h", data[0:2]) + + if packet_type != NET_PACKET_TYPE_QUERY_RESPONSE: + return + + # TODO: Process rest of details so that we can maintain + # some information about list of servers? + + # Server responded to our query, so it is verified. + # We can send a positive response to its add request. + + if not server.verified: + self.log_output(server.addr, "Server responded to query, added") + server.verified = True + self.send_add_response(server, 1) + def send_message(self, addr, message_type, payload): """ Send a message of the specified type to the specified remote address. """ @@ -108,6 +150,11 @@ class MasterServer: for server in self.servers.values(): encoded_addr = server.encode_addr() + # Only include verified servers. + + if not server.verified: + continue + # Start a new packet? if len(packets[-1]) + len(encoded_addr) > MAX_RESPONSE_LEN: @@ -117,6 +164,13 @@ class MasterServer: return packets + def send_add_response(self, server, success): + """ Send a response to a server's add request. """ + + self.send_message(server.addr, + NET_MASTER_PACKET_TYPE_ADD_RESPONSE, + struct.pack(">h", success)) + def process_add_to_master(self, addr): """ Process an "add to master" request received from a server. """ @@ -125,15 +179,21 @@ class MasterServer: server = self.servers[addr] server.refresh() else: - self.log_output(addr, "Add to master") server = Server(addr) self.servers[addr] = server - # Send a reply indicating successful + # If the server has already been verified, we can send a + # reply immediately. Otherwise, query the server via a + # different socket first to verify it. + # Why is this needed? The server might be behind a NAT + # gateway. In this case, the master might be able to + # communicate with it, but other machines might not. - self.send_message(addr, - NET_MASTER_PACKET_TYPE_ADD_RESPONSE, - struct.pack(">h", 1)) + if server.verified: + self.send_add_response(1) + else: + self.log_output(addr, "Add request, sending query to confirm") + self.send_query(server) def process_query(self, addr): """ Process a query message received from a client. """ @@ -166,6 +226,16 @@ class MasterServer: except Exception, e: print e + def rx_packet_query_sock(self): + """ Invoked when a packet is received on the query socket. """ + + data, addr = self.query_sock.recvfrom(1024) + + try: + self.parse_query_response(data, addr) + except Exception, e: + print e + def age_servers(self): """ Check server timestamps and flush out stale servers. """ @@ -176,25 +246,42 @@ class MasterServer: (time() - server.add_time)) del self.servers[server.addr] + # Expect a response to queries quickly, otherwise add + # requests are rejected. + + if not server.verified and time() - server.add_time > 5: + self.log_output(server.addr, + "No response to query, add rejected") + self.send_add_response(server, 0) + del self.servers[server.addr] + def open_socket(self): """ Open the server socket and bind to the listening address. """ self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.sock.bind(self.server_addr) + # Query socket, used to send queries to servers to check that + # they are actually accessible. + + self.query_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + def run(self): """ Run the server main loop, listening for packets. """ self.log_output(self.server_addr, "Server started.") while True: - r, w, x = select([self.sock], [], [], 5) + r, w, x = select([self.sock, self.query_sock], [], [], 5) self.age_servers() if self.sock in r: self.rx_packet() + if self.query_sock in r: + self.rx_packet_query_sock() + if __name__ == "__main__": server = MasterServer(UDP_ADDRESS, UDP_PORT) server.run()