summaryrefslogtreecommitdiffhomepage
path: root/libs/flask_socketio/test_client.py
blob: 0603d77c733de90cc93a30fb688c5e293c74342c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import uuid

from socketio import packet
from socketio.pubsub_manager import PubSubManager
from werkzeug.test import EnvironBuilder


class SocketIOTestClient(object):
    """
    This class is useful for testing a Flask-SocketIO server. It works in a
    similar way to the Flask Test Client, but adapted to the Socket.IO server.

    :param app: The Flask application instance.
    :param socketio: The application's ``SocketIO`` instance.
    :param namespace: The namespace for the client. If not provided, the client
                      connects to the server on the global namespace.
    :param query_string: A string with custom query string arguments.
    :param headers: A dictionary with custom HTTP headers.
    :param auth: Optional authentication data, given as a dictionary.
    :param flask_test_client: The instance of the Flask test client
                              currently in use. Passing the Flask test
                              client is optional, but is necessary if you
                              want the Flask user session and any other
                              cookies set in HTTP routes accessible from
                              Socket.IO events.

    :raises RuntimeError: if the ``SocketIO`` server uses a message queue
                          (a ``PubSubManager``), which the test client does
                          not support.
    """
    # Registry of test clients keyed by Engine.IO session id. It is a class
    # attribute on purpose: the patched server send functions installed by
    # whichever client was created last look up the destination client here.
    clients = {}

    def __init__(self, app, socketio, namespace=None, query_string=None,
                 headers=None, auth=None, flask_test_client=None):
        # Validate before touching any shared state, so a rejected
        # configuration neither registers this client in ``clients`` nor
        # leaves the server's send functions patched.
        if isinstance(socketio.server.manager, PubSubManager):
            raise RuntimeError('Test client cannot be used with a message '
                               'queue. Disable the queue on your test '
                               'configuration.')

        def _mock_send_packet(eio_sid, pkt):
            # Stand-in for the server's packet sender: instead of writing to
            # a transport, record the packet on the target test client.
            # Round-trip through encode/decode so that packets the server
            # could not actually serialize fail here too.
            epkt = pkt.encode()
            if not isinstance(epkt, list):
                pkt = packet.Packet(encoded_packet=epkt)
            else:
                # binary packet: first element is the packet itself, the
                # rest are its attachments
                pkt = packet.Packet(encoded_packet=epkt[0])
                for att in epkt[1:]:
                    pkt.add_attachment(att)
            client = self.clients.get(eio_sid)
            if not client:
                return
            if pkt.packet_type in (packet.EVENT, packet.BINARY_EVENT):
                if pkt.data[0] in ('message', 'json'):
                    # send()-style messages carry a single payload
                    client.queue.append({
                        'name': pkt.data[0],
                        'args': pkt.data[1],
                        'namespace': pkt.namespace or '/'})
                else:
                    client.queue.append({
                        'name': pkt.data[0],
                        'args': pkt.data[1:],
                        'namespace': pkt.namespace or '/'})
            elif pkt.packet_type in (packet.ACK, packet.BINARY_ACK):
                client.acks = {'args': pkt.data,
                               'namespace': pkt.namespace or '/'}
            elif pkt.packet_type in [packet.DISCONNECT, packet.CONNECT_ERROR]:
                client.connected[pkt.namespace or '/'] = False

        # binary packet still waiting for some of its attachments
        _current_packet = None

        def _mock_send_eio_packet(eio_sid, eio_pkt):
            # Stand-in for the server's Engine.IO sender: reassemble binary
            # packets from their attachment messages, then dispatch them
            # through _mock_send_packet once complete.
            nonlocal _current_packet
            if _current_packet is not None:
                _current_packet.add_attachment(eio_pkt.data)
                if _current_packet.attachment_count == \
                        len(_current_packet.attachments):
                    _mock_send_packet(eio_sid, _current_packet)
                    _current_packet = None
            else:
                pkt = packet.Packet(encoded_packet=eio_pkt.data)
                if pkt.attachment_count == 0:
                    _mock_send_packet(eio_sid, pkt)
                else:
                    _current_packet = pkt

        self.app = app
        self.flask_test_client = flask_test_client
        self.eio_sid = uuid.uuid4().hex
        self.clients[self.eio_sid] = self
        self.callback_counter = 0
        self.socketio = socketio
        self.connected = {}
        self.queue = []
        self.acks = None
        socketio.server._send_packet = _mock_send_packet
        socketio.server._send_eio_packet = _mock_send_eio_packet
        socketio.server.environ[self.eio_sid] = {}
        socketio.server.async_handlers = False      # easier to test when
        socketio.server.eio.async_handlers = False  # events are sync
        socketio.server.manager.initialize()
        self.connect(namespace=namespace, query_string=query_string,
                     headers=headers, auth=auth)

    def is_connected(self, namespace=None):
        """Check if a namespace is connected.

        :param namespace: The namespace to check. The global namespace is
                         assumed if this argument is not provided.
        """
        return self.connected.get(namespace or '/', False)

    def connect(self, namespace=None, query_string=None, headers=None,
                auth=None):
        """Connect the client.

        :param namespace: The namespace for the client. If not provided, the
                          client connects to the server on the global
                          namespace.
        :param query_string: A string with custom query string arguments.
        :param headers: A dictionary with custom HTTP headers.
        :param auth: Optional authentication data, given as a dictionary.

        Note that it is usually not necessary to explicitly call this method,
        since a connection is automatically established when an instance of
        this class is created. An example where this method would be useful
        is when the application accepts multiple namespace connections.
        """
        url = '/socket.io'
        namespace = namespace or '/'
        if query_string:
            if query_string[0] != '?':
                query_string = '?' + query_string
            url += query_string
        environ = EnvironBuilder(url, headers=headers).get_environ()
        environ['flask.app'] = self.app
        if self.flask_test_client:
            # inject cookies from Flask
            if hasattr(self.flask_test_client, '_add_cookies_to_wsgi'):
                # flask >= 2.3
                self.flask_test_client._add_cookies_to_wsgi(environ)
            else:  # pragma: no cover
                # flask < 2.3
                self.flask_test_client.cookie_jar.inject_wsgi(environ)
        self.socketio.server._handle_eio_connect(self.eio_sid, environ)
        pkt = packet.Packet(packet.CONNECT, auth, namespace=namespace)
        self.socketio.server._handle_eio_message(self.eio_sid, pkt.encode())
        # the server only assigns a Socket.IO sid if the connection was
        # accepted; a rejection arrives as CONNECT_ERROR via the mock sender
        sid = self.socketio.server.manager.sid_from_eio_sid(self.eio_sid,
                                                            namespace)
        if sid:
            self.connected[namespace] = True

    def disconnect(self, namespace=None):
        """Disconnect the client.

        :param namespace: The namespace to disconnect. The global namespace is
                         assumed if this argument is not provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
        self.socketio.server._handle_eio_message(self.eio_sid, pkt.encode())
        del self.connected[namespace or '/']

    def emit(self, event, *args, **kwargs):
        """Emit an event to the server.

        :param event: The event name.
        :param *args: The event arguments.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        namespace = kwargs.pop('namespace', None)
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        callback = kwargs.pop('callback', False)
        ack_id = None
        if callback:
            # every acknowledgement request gets a fresh, increasing id
            self.callback_counter += 1
            ack_id = self.callback_counter
        pkt = packet.Packet(packet.EVENT, data=[event] + list(args),
                            namespace=namespace, id=ack_id)
        encoded_pkt = pkt.encode()
        # binary packets encode to a list (the packet followed by its
        # attachments); each element is delivered as its own message
        if not isinstance(encoded_pkt, list):
            encoded_pkt = [encoded_pkt]
        for epkt in encoded_pkt:
            self.socketio.server._handle_eio_message(self.eio_sid, epkt)
        # handlers run synchronously (async_handlers is disabled in
        # __init__), so any acknowledgement has already been recorded
        if self.acks is not None:
            ack = self.acks
            self.acks = None
            return ack['args'][0] if len(ack['args']) == 1 \
                else ack['args']
        return None

    def send(self, data, json=False, callback=False, namespace=None):
        """Send a text or JSON message to the server.

        :param data: A string, dictionary or list to send to the server.
        :param json: ``True`` to send a JSON message, ``False`` to send a text
                     message.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        """
        if json:
            msg = 'json'
        else:
            msg = 'message'
        return self.emit(msg, data, callback=callback, namespace=namespace)

    def get_received(self, namespace=None):
        """Return the list of messages received from the server.

        Since this is not a real client, any time the server emits an event,
        the event is simply stored. The test code can invoke this method to
        obtain the list of events that were received since the last call.
        Returned events are removed from the queue; events for other
        namespaces are kept, in their original order.

        :param namespace: The namespace to get events from. The global
                          namespace is assumed if this argument is not
                          provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        namespace = namespace or '/'
        # single pass partition instead of an O(n*m) membership filter
        received = []
        remaining = []
        for message in self.queue:
            if message['namespace'] == namespace:
                received.append(message)
            else:
                remaining.append(message)
        self.queue = remaining
        return received