summaryrefslogtreecommitdiffhomepage
path: root/libs/engineio/socket.py
blob: de8fd354d04bd85f52cb301686a3cdfca9b4544a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import sys
import time

from . import base_socket
from . import exceptions
from . import packet
from . import payload


class Socket(base_socket.BaseSocket):
    """An Engine.IO socket."""
    def poll(self):
        """Collect the packets that are waiting to be sent to the client.

        Blocks until at least one item is available in the send queue and
        then drains, without blocking, anything else already queued.

        :return: a list with the queued packets, or an empty list when the
                 first item retrieved is the ``None`` marker.
        :raises exceptions.QueueEmpty: if nothing was queued within
                 ``ping_interval + ping_timeout`` seconds.
        """
        empty_exc = self.server.get_queue_empty_exception()
        try:
            first = self.queue.get(
                timeout=self.server.ping_interval + self.server.ping_timeout)
        except empty_exc:
            raise exceptions.QueueEmpty()
        else:
            self.queue.task_done()

        collected = [first]
        if collected == [None]:
            return []

        # grab whatever else is immediately available
        while True:
            try:
                item = self.queue.get(block=False)
            except empty_exc:
                break
            else:
                self.queue.task_done()
            if item is None:
                # put the marker back so that the next poll() call sees it
                self.queue.put(None)
                break
            collected.append(item)
        return collected

    def receive(self, pkt):
        """Process a packet that was received from the client.

        :param pkt: the decoded packet; its ``packet_type`` and ``data``
                    attributes are read.
        :raises exceptions.UnknownPacketError: if the packet type is not one
                 of PONG, MESSAGE, UPGRADE or CLOSE.
        """
        kind = pkt.packet_type
        if kind < len(packet.packet_names):
            type_name = packet.packet_names[kind]
        else:
            type_name = 'UNKNOWN'
        # binary payloads are not written to the log
        if isinstance(pkt.data, bytes):
            shown_data = '<binary>'
        else:
            shown_data = pkt.data
        self.server.logger.info('%s: Received packet %s data %s',
                                self.sid, type_name, shown_data)

        if kind == packet.PONG:
            self.schedule_ping()
        elif kind == packet.MESSAGE:
            self.server._trigger_event('message', self.sid, pkt.data,
                                       run_async=self.server.async_handlers)
        elif kind == packet.UPGRADE:
            self.send(packet.Packet(packet.NOOP))
        elif kind == packet.CLOSE:
            self.close(wait=False, abort=True)
        else:
            raise exceptions.UnknownPacketError()

    def check_ping_timeout(self):
        """Make sure the client is still responding to pings.

        :return: ``True`` if no ping is outstanding or the outstanding ping
                 is still within ``ping_timeout``; ``False`` if the ping
                 timed out, in which case the socket has been closed.
        :raises exceptions.SocketIsClosedError: if the socket is already
                 closed.
        """
        if self.closed:
            raise exceptions.SocketIsClosedError()
        # Take a single snapshot of the timestamp. _send_ping() runs as a
        # background task and resets it to None, so reading the attribute a
        # second time could yield None and make the subtraction fail.
        last_ping = self.last_ping
        if last_ping and \
                time.time() - last_ping > self.server.ping_timeout:
            self.server.logger.info('%s: Client is gone, closing socket',
                                    self.sid)
            # Passing abort=False here will cause close() to write a
            # CLOSE packet. This has the effect of updating half-open sockets
            # to their correct state of disconnected
            self.close(wait=False, abort=False)
            return False
        return True

    def send(self, pkt):
        """Queue a packet for delivery to the client.

        The packet is silently dropped when the ping timeout check reports
        that the client is gone.

        :param pkt: the packet to send.
        """
        if not self.check_ping_timeout():
            return
        self.queue.put(pkt)
        # binary payloads are not written to the log
        if isinstance(pkt.data, bytes):
            shown_data = '<binary>'
        else:
            shown_data = pkt.data
        self.server.logger.info('%s: Sending packet %s data %s',
                                self.sid, packet.packet_names[pkt.packet_type],
                                shown_data)

    def handle_get_request(self, environ, start_response):
        """Serve a long-polling GET request, or start a transport upgrade.

        :param environ: the WSGI environment of the request.
        :param start_response: the WSGI ``start_response`` callable, passed
                               on to the upgrade handler when upgrading.
        :return: the upgrade handler's return value for upgrade requests,
                 otherwise a list of packets for the client.
        :raises exceptions.QueueEmpty: if polling timed out; the socket is
                 closed before the exception is re-raised.
        """
        connection_header = environ.get('HTTP_CONNECTION', '').lower()
        connection_tokens = [
            token.strip() for token in connection_header.split(',')]
        transport = environ.get('HTTP_UPGRADE', '').lower()

        if 'upgrade' in connection_tokens and \
                transport in self.upgrade_protocols:
            self.server.logger.info('%s: Received request to upgrade to %s',
                                    self.sid, transport)
            upgrade_handler = getattr(self, '_upgrade_' + transport)
            return upgrade_handler(environ, start_response)

        if self.upgrading or self.upgraded:
            # we are upgrading to WebSocket, do not return any more packets
            # through the polling endpoint
            return [packet.Packet(packet.NOOP)]

        try:
            return self.poll()
        except exceptions.QueueEmpty as timeout_error:
            # nothing to send within the allowed time: drop the socket and
            # let the caller see the original error
            self.close(wait=False)
            raise timeout_error

    def handle_post_request(self, environ):
        """Handle a long-polling POST request from the client.

        Reads the request body, decodes it as a payload and passes each
        packet in it to :meth:`receive`.

        :param environ: the WSGI environment of the request.
        :raises exceptions.ContentTooLongError: if the declared content
                 length is negative or larger than the server's
                 ``max_http_buffer_size``.
        """
        # CONTENT_LENGTH may be absent *or* present but empty (PEP 3333);
        # both mean that there is no body to read.
        length = int(environ.get('CONTENT_LENGTH') or 0)
        # A negative length would turn read() below into "read everything",
        # bypassing the size limit, so it is rejected together with
        # oversized bodies.
        if length < 0 or length > self.server.max_http_buffer_size:
            raise exceptions.ContentTooLongError()
        body = environ['wsgi.input'].read(length).decode('utf-8')
        p = payload.Payload(encoded_payload=body)
        for pkt in p.packets:
            self.receive(pkt)

    def close(self, wait=True, abort=False):
        """Shut the socket down; does nothing if already closed or closing.

        :param wait: when true, block until the send queue has been fully
                     processed.
        :param abort: when false, a CLOSE packet is queued for the client
                      before the socket is marked as closed.
        """
        if self.closed or self.closing:
            return
        # set first so that nested close() calls made while shutting down
        # return immediately
        self.closing = True
        self.server._trigger_event('disconnect', self.sid, run_async=False)
        if not abort:
            self.send(packet.Packet(packet.CLOSE))
        self.closed = True
        self.queue.put(None)  # marker that tells poll() there is nothing more
        if wait:
            self.queue.join()

    def schedule_ping(self):
        """Start a background task that sends the next PING packet."""
        ping_task = self._send_ping
        self.server.start_background_task(ping_task)

    def _send_ping(self):
        """Sleep for one ping interval and then send a PING packet.

        The ping timestamp is cleared while sleeping and set again right
        before the PING is sent. Nothing is sent if the socket is closing or
        closed once the sleep is over.
        """
        self.last_ping = None
        self.server.sleep(self.server.ping_interval)
        if self.closing or self.closed:
            return
        self.last_ping = time.time()
        self.send(packet.Packet(packet.PING))

    def _upgrade_websocket(self, environ, start_response):
        """Switch this connection from polling to the websocket transport.

        :param environ: the WSGI environment of the upgrade request.
        :param start_response: the WSGI ``start_response`` callable.
        :return: the result of invoking the websocket wrapper, or the
                 server's bad request response when no websocket
                 implementation is configured.
        :raises IOError: if the socket has already been upgraded.
        """
        if self.upgraded:
            raise IOError('Socket has been upgraded already')
        websocket_class = self.server._async['websocket']
        if websocket_class is None:
            # the selected async mode does not support websocket
            return self.server._bad_request()
        ws_app = websocket_class(self._websocket_handler, self.server)
        return ws_app(environ, start_response)

    def _websocket_handler(self, ws):
        """Engine.IO handler for websocket transport.

        When the socket is already connected this performs the upgrade
        handshake first: a PING packet with ``'probe'`` data is expected and
        answered with a matching PONG, then an UPGRADE packet is expected.
        Otherwise the socket is simply marked as connected and upgraded.
        After that, a background writer task forwards queued packets to the
        websocket while this method reads incoming packets until the
        connection ends.

        :param ws: the websocket object; its ``wait()``, ``send()`` and
                   ``close()`` methods are used.
        :return: an empty list.
        """
        def websocket_wait():
            # read one message, enforcing the configured size limit
            data = ws.wait()
            if data and len(data) > self.server.max_http_buffer_size:
                raise ValueError('packet is too large')
            return data

        # try to set a socket timeout matching the configured ping interval
        # and timeout
        for attr in ['_sock', 'socket']:  # pragma: no cover
            if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'):
                getattr(ws, attr).settimeout(
                    self.server.ping_interval + self.server.ping_timeout)

        if self.connected:
            # the socket was already connected, so this is an upgrade
            self.upgrading = True  # hold packet sends during the upgrade
            try:
                pkt = websocket_wait()
                decoded_pkt = packet.Packet(encoded_packet=pkt)
                if decoded_pkt.packet_type != packet.PING or \
                        decoded_pkt.data != 'probe':
                    self.server.logger.info(
                        '%s: Failed websocket upgrade, no PING packet',
                        self.sid)
                    return []
                ws.send(packet.Packet(packet.PONG, data='probe').encode())
                self.queue.put(packet.Packet(packet.NOOP))  # end poll

                pkt = websocket_wait()
                decoded_pkt = packet.Packet(encoded_packet=pkt)
                if decoded_pkt.packet_type != packet.UPGRADE:
                    self.upgraded = False
                    self.server.logger.info(
                        ('%s: Failed websocket upgrade, expected UPGRADE '
                         'packet, received %s instead.'),
                        self.sid, pkt)
                    return []
                self.upgraded = True
            finally:
                # Always release the hold, including when the handshake
                # raises (timeout, oversized or undecodable packet).
                # Otherwise the polling endpoint would keep answering with
                # NOOP packets for the rest of the socket's life.
                self.upgrading = False
        else:
            self.connected = True
            self.upgraded = True

        # start separate writer thread
        def writer():
            while True:
                packets = None
                try:
                    packets = self.poll()
                except exceptions.QueueEmpty:
                    break
                if not packets:
                    # empty packet list returned -> connection closed
                    break
                try:
                    for pkt in packets:
                        ws.send(pkt.encode())
                except:
                    # any failure to write ends the writer; the catch-all is
                    # kept on purpose so that ws.close() below always runs
                    break
            ws.close()

        writer_task = self.server.start_background_task(writer)

        self.server.logger.info(
            '%s: Upgrade to websocket successful', self.sid)

        while True:
            p = None
            try:
                p = websocket_wait()
            except Exception as e:
                # if the socket is already closed, we can assume this is a
                # downstream error of that
                if not self.closed:  # pragma: no cover
                    self.server.logger.info(
                        '%s: Unexpected error "%s", closing connection',
                        self.sid, str(e))
                break
            if p is None:
                # connection closed by client
                break
            pkt = packet.Packet(encoded_packet=p)
            try:
                self.receive(pkt)
            except exceptions.UnknownPacketError:  # pragma: no cover
                pass
            except exceptions.SocketIsClosedError:  # pragma: no cover
                self.server.logger.info('Receive error -- socket is closed')
                break
            except:  # pragma: no cover
                # if we get an unexpected exception we log the error and exit
                # the connection properly
                self.server.logger.exception('Unknown receive error')
                break

        self.queue.put(None)  # unlock the writer task so that it can exit
        writer_task.join()
        self.close(wait=False, abort=True)

        return []