import uuid

from socketio import packet
from socketio.pubsub_manager import PubSubManager
from werkzeug.test import EnvironBuilder


class SocketIOTestClient(object):
    """
    This class is useful for testing a Flask-SocketIO server. It works in a
    similar way to the Flask Test Client, but adapted to the Socket.IO server.

    :param app: The Flask application instance.
    :param socketio: The application's ``SocketIO`` instance.
    :param namespace: The namespace for the client. If not provided, the client
                      connects to the server on the global namespace.
    :param query_string: A string with custom query string arguments.
    :param headers: A dictionary with custom HTTP headers.
    :param flask_test_client: The instance of the Flask test client
                              currently in use. Passing the Flask test
                              client is optional, but is necessary if you
                              want the Flask user session and any other
                              cookies set in HTTP routes accessible from
                              Socket.IO events.
    """
    # These are class-level, i.e. shared by every instance. Each new client
    # overwrites the server's single ``_send_packet`` hook with its own
    # closure, and that closure writes into these shared dicts keyed by the
    # destination sid, so packets still reach the queue of any live client.
    queue = {}
    acks = {}

    def __init__(self, app, socketio, namespace=None, query_string=None,
                 headers=None, flask_test_client=None):
        def _mock_send_packet(sid, pkt):
            # Replacement for the server's packet sender: instead of writing
            # to a transport, record what the server sends so tests can
            # inspect it.
            pkt_namespace = pkt.namespace or '/'
            if pkt.packet_type in (packet.EVENT, packet.BINARY_EVENT):
                sid_queue = self.queue.setdefault(sid, [])
                if pkt.data[0] in ('message', 'json'):
                    # ``send()``-style messages carry a single payload value.
                    args = pkt.data[1]
                else:
                    args = pkt.data[1:]
                sid_queue.append({'name': pkt.data[0], 'args': args,
                                  'namespace': pkt_namespace})
            elif pkt.packet_type in (packet.ACK, packet.BINARY_ACK):
                # Only the latest ack per sid is kept; ``emit`` consumes it.
                self.acks[sid] = {'args': pkt.data,
                                  'namespace': pkt_namespace}
            elif pkt.packet_type == packet.DISCONNECT:
                self.connected[pkt_namespace] = False

        self.app = app
        self.flask_test_client = flask_test_client
        self.sid = uuid.uuid4().hex
        self.queue[self.sid] = []
        self.acks[self.sid] = None
        self.callback_counter = 0
        self.socketio = socketio
        self.connected = {}
        socketio.server._send_packet = _mock_send_packet
        socketio.server.environ[self.sid] = {}
        # Run handlers synchronously so an event's effects (emitted packets,
        # acks) are recorded by the time the triggering call returns.
        socketio.server.async_handlers = False
        socketio.server.eio.async_handlers = False
        if isinstance(socketio.server.manager, PubSubManager):
            raise RuntimeError('Test client cannot be used with a message '
                               'queue. Disable the queue on your test '
                               'configuration.')
        socketio.server.manager.initialize()
        self.connect(namespace=namespace, query_string=query_string,
                     headers=headers)

    def is_connected(self, namespace=None):
        """Check if a namespace is connected.

        :param namespace: The namespace to check. The global namespace is
                         assumed if this argument is not provided.
        """
        return self.connected.get(namespace or '/', False)

    def connect(self, namespace=None, query_string=None, headers=None):
        """Connect the client.

        :param namespace: The namespace for the client. If not provided, the
                          client connects to the server on the global
                          namespace.
        :param query_string: A string with custom query string arguments.
        :param headers: A dictionary with custom HTTP headers.

        Note that it is usually not necessary to explicitly call this method,
        since a connection is automatically established when an instance of
        this class is created. An example where this method would be useful
        is when the application accepts multiple namespace connections.
        """
        url = '/socket.io'
        if query_string:
            if query_string[0] != '?':
                query_string = '?' + query_string
            url += query_string
        environ = EnvironBuilder(url, headers=headers).get_environ()
        environ['flask.app'] = self.app
        if self.flask_test_client:
            # inject cookies from Flask
            self.flask_test_client.cookie_jar.inject_wsgi(environ)
        # Mark as connected first so handlers that check the state see it;
        # roll back if the server's connect handler rejects the connection.
        self.connected['/'] = True
        if self.socketio.server._handle_eio_connect(
                self.sid, environ) is False:
            del self.connected['/']
        if namespace is not None and namespace != '/':
            self.connected[namespace] = True
            pkt = packet.Packet(packet.CONNECT, namespace=namespace)
            with self.app.app_context():
                if self.socketio.server._handle_eio_message(
                        self.sid, pkt.encode()) is False:
                    del self.connected[namespace]

    def disconnect(self, namespace=None):
        """Disconnect the client.

        :param namespace: The namespace to disconnect. The global namespace is
                         assumed if this argument is not provided.
        :raises RuntimeError: If the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
        with self.app.app_context():
            self.socketio.server._handle_eio_message(self.sid, pkt.encode())
        del self.connected[namespace or '/']

    def emit(self, event, *args, **kwargs):
        """Emit an event to the server.

        :param event: The event name.
        :param *args: The event arguments.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        :returns: If the server sent an ack, its arguments (unwrapped when
                  there is exactly one); otherwise ``None``.
        :raises RuntimeError: If the namespace is not connected.
        """
        namespace = kwargs.pop('namespace', None)
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        callback = kwargs.pop('callback', False)
        ack_id = None
        if callback:
            self.callback_counter += 1
            ack_id = self.callback_counter
        pkt = packet.Packet(packet.EVENT, data=[event] + list(args),
                            namespace=namespace, id=ack_id)
        with self.app.app_context():
            encoded_pkt = pkt.encode()
            # Binary payloads encode to a list of messages (header plus
            # attachments); each is fed to the server in order.
            if isinstance(encoded_pkt, list):
                for epkt in encoded_pkt:
                    self.socketio.server._handle_eio_message(self.sid, epkt)
            else:
                self.socketio.server._handle_eio_message(self.sid,
                                                         encoded_pkt)
        ack = self.acks.pop(self.sid, None)
        if ack is not None:
            return ack['args'][0] if len(ack['args']) == 1 \
                else ack['args']

    def send(self, data, json=False, callback=False, namespace=None):
        """Send a text or JSON message to the server.

        :param data: A string, dictionary or list to send to the server.
        :param json: ``True`` to send a JSON message, ``False`` to send a text
                     message.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        """
        if json:
            msg = 'json'
        else:
            msg = 'message'
        return self.emit(msg, data, callback=callback, namespace=namespace)

    def get_received(self, namespace=None):
        """Return the list of messages received from the server.

        Since this is not a real client, any time the server emits an event,
        the event is simply stored. The test code can invoke this method to
        obtain the list of events that were received since the last call.
        Returned events are removed from the pending queue; events for other
        namespaces are left in place.

        :param namespace: The namespace to get events from. The global
                          namespace is assumed if this argument is not
                          provided.
        :raises RuntimeError: If the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        namespace = namespace or '/'
        pending = self.queue[self.sid]
        # Partition by namespace in one linear pass each, instead of
        # membership tests against the result list (quadratic, and based on
        # dict equality).
        received = [pkt for pkt in pending if pkt['namespace'] == namespace]
        self.queue[self.sid] = [pkt for pkt in pending
                                if pkt['namespace'] != namespace]
        return received