diff options
author | morpheus65535 <[email protected]> | 2021-05-08 10:25:29 -0400 |
---|---|---|
committer | GitHub <[email protected]> | 2021-05-08 10:25:29 -0400 |
commit | 72b6ab3c6a11e1c12d86563989d88d73e4e64377 (patch) | |
tree | 3739d75ba8814a226f241828afb888826977ef78 /libs/engineio | |
parent | 09a31cf9a42f4b08ef9953fc9dc3af47bbc39217 (diff) | |
download | bazarr-72b6ab3c6a11e1c12d86563989d88d73e4e64377.tar.gz bazarr-72b6ab3c6a11e1c12d86563989d88d73e4e64377.zip |
Added live update of UI using websocket. Make sure your reverse proxy upgrade the connection!
Diffstat (limited to 'libs/engineio')
-rw-r--r-- | libs/engineio/__init__.py | 2 | ||||
-rw-r--r-- | libs/engineio/async_drivers/aiohttp.py | 5 | ||||
-rw-r--r-- | libs/engineio/async_drivers/asgi.py | 50 | ||||
-rw-r--r-- | libs/engineio/async_drivers/gevent_uwsgi.py | 8 | ||||
-rw-r--r-- | libs/engineio/async_drivers/sanic.py | 19 | ||||
-rw-r--r-- | libs/engineio/async_drivers/tornado.py | 8 | ||||
-rw-r--r-- | libs/engineio/asyncio_client.py | 174 | ||||
-rw-r--r-- | libs/engineio/asyncio_server.py | 167 | ||||
-rw-r--r-- | libs/engineio/asyncio_socket.py | 75 | ||||
-rw-r--r-- | libs/engineio/client.py | 217 | ||||
-rw-r--r-- | libs/engineio/packet.py | 78 | ||||
-rw-r--r-- | libs/engineio/payload.py | 71 | ||||
-rw-r--r-- | libs/engineio/server.py | 222 | ||||
-rw-r--r-- | libs/engineio/socket.py | 66 |
14 files changed, 653 insertions, 509 deletions
diff --git a/libs/engineio/__init__.py b/libs/engineio/__init__.py index f2c5b774c..b897468d2 100644 --- a/libs/engineio/__init__.py +++ b/libs/engineio/__init__.py @@ -17,7 +17,7 @@ else: # pragma: no cover get_tornado_handler = None ASGIApp = None -__version__ = '3.11.2' +__version__ = '4.0.2dev' __all__ = ['__version__', 'Server', 'WSGIApp', 'Middleware', 'Client'] if AsyncServer is not None: # pragma: no cover diff --git a/libs/engineio/async_drivers/aiohttp.py b/libs/engineio/async_drivers/aiohttp.py index ad6987649..a59199588 100644 --- a/libs/engineio/async_drivers/aiohttp.py +++ b/libs/engineio/async_drivers/aiohttp.py @@ -3,7 +3,6 @@ import sys from urllib.parse import urlsplit from aiohttp.web import Response, WebSocketResponse -import six def create_route(app, engineio_server, engineio_endpoint): @@ -113,8 +112,8 @@ class WebSocket(object): # pragma: no cover async def wait(self): msg = await self._sock.receive() - if not isinstance(msg.data, six.binary_type) and \ - not isinstance(msg.data, six.text_type): + if not isinstance(msg.data, bytes) and \ + not isinstance(msg.data, str): raise IOError() return msg.data diff --git a/libs/engineio/async_drivers/asgi.py b/libs/engineio/async_drivers/asgi.py index 9f14ef05f..eb3139b5e 100644 --- a/libs/engineio/async_drivers/asgi.py +++ b/libs/engineio/async_drivers/asgi.py @@ -1,5 +1,6 @@ import os import sys +import asyncio from engineio.static_files import get_static_file @@ -19,6 +20,10 @@ class ASGIApp: :param engineio_path: The endpoint where the Engine.IO application should be installed. The default value is appropriate for most cases. + :param on_startup: function to be called on application startup; can be + coroutine + :param on_shutdown: function to be called on application shutdown; can be + coroutine Example usage:: @@ -34,11 +39,14 @@ class ASGIApp: uvicorn.run(app, '127.0.0.1', 5000) """ def __init__(self, engineio_server, other_asgi_app=None, - static_files=None, engineio_path='engine.io'): + static_files=None, engineio_path='engine.io', + on_startup=None, on_shutdown=None): self.engineio_server = engineio_server self.other_asgi_app = other_asgi_app self.engineio_path = engineio_path.strip('/') self.static_files = static_files or {} + self.on_startup = on_startup + self.on_shutdown = on_shutdown async def __call__(self, scope, receive, send): if scope['type'] in ['http', 'websocket'] and \ @@ -73,11 +81,29 @@ class ASGIApp: await self.not_found(receive, send) async def lifespan(self, receive, send): - event = await receive() - if event['type'] == 'lifespan.startup': - await send({'type': 'lifespan.startup.complete'}) - elif event['type'] == 'lifespan.shutdown': - await send({'type': 'lifespan.shutdown.complete'}) + while True: + event = await receive() + if event['type'] == 'lifespan.startup': + if self.on_startup: + try: + await self.on_startup() \ + if asyncio.iscoroutinefunction(self.on_startup) \ + else self.on_startup() + except: + await send({'type': 'lifespan.startup.failed'}) + return + await send({'type': 'lifespan.startup.complete'}) + elif event['type'] == 'lifespan.shutdown': + if self.on_shutdown: + try: + await self.on_shutdown() \ + if asyncio.iscoroutinefunction(self.on_shutdown) \ + else self.on_shutdown() + except: + await send({'type': 'lifespan.shutdown.failed'}) + return + await send({'type': 'lifespan.shutdown.complete'}) + return async def not_found(self, receive, send): """Return a 404 Not Found error to the client.""" @@ -111,7 +137,7 @@ async def translate_request(scope, receive, send): if event['type'] == 'http.request': payload += event.get('body') or b'' elif event['type'] == 'websocket.connect': - await send({'type': 'websocket.accept'}) + pass else: return {} @@ -139,6 +165,7 @@ async def translate_request(scope, receive, send): 'SERVER_PORT': '0', 'asgi.receive': receive, 'asgi.send': send, + 'asgi.scope': scope, } for hdr_name, hdr_value in scope['headers']: @@ -163,6 +190,14 @@ async def translate_request(scope, receive, send): async def make_response(status, headers, payload, environ): headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers] + if environ['asgi.scope']['type'] == 'websocket': + if status.startswith('200 '): + await environ['asgi.send']({'type': 'websocket.accept', + 'headers': headers}) + else: + await environ['asgi.send']({'type': 'websocket.close'}) + return + await environ['asgi.send']({'type': 'http.response.start', 'status': int(status.split(' ')[0]), 'headers': headers}) @@ -183,6 +218,7 @@ class WebSocket(object): # pragma: no cover async def __call__(self, environ): self.asgi_receive = environ['asgi.receive'] self.asgi_send = environ['asgi.send'] + await self.asgi_send({'type': 'websocket.accept'}) await self.handler(self) async def close(self): diff --git a/libs/engineio/async_drivers/gevent_uwsgi.py b/libs/engineio/async_drivers/gevent_uwsgi.py index 07fa2a79d..bdee812de 100644 --- a/libs/engineio/async_drivers/gevent_uwsgi.py +++ b/libs/engineio/async_drivers/gevent_uwsgi.py @@ -1,7 +1,5 @@ from __future__ import absolute_import -import six - import gevent from gevent import queue from gevent.event import Event @@ -75,7 +73,7 @@ class uWSGIWebSocket(object): # pragma: no cover def _send(self, msg): """Transmits message either in binary or UTF-8 text mode, depending on its type.""" - if isinstance(msg, six.binary_type): + if isinstance(msg, bytes): method = uwsgi.websocket_send_binary else: method = uwsgi.websocket_send @@ -86,11 +84,11 @@ class uWSGIWebSocket(object): # pragma: no cover def _decode_received(self, msg): """Returns either bytes or str, depending on message type.""" - if not isinstance(msg, six.binary_type): + if not isinstance(msg, bytes): # already decoded - do nothing return msg # only decode from utf-8 if message is not binary data - type = six.byte2int(msg[0:1]) + type = ord(msg[0:1]) if type >= 48: # no binary return msg.decode('utf-8') # binary message, don't try to decode diff --git a/libs/engineio/async_drivers/sanic.py b/libs/engineio/async_drivers/sanic.py index 6929654b9..e9555f310 100644 --- a/libs/engineio/async_drivers/sanic.py +++ b/libs/engineio/async_drivers/sanic.py @@ -1,16 +1,15 @@ import sys from urllib.parse import urlsplit -from sanic.response import HTTPResponse -try: +try: # pragma: no cover + from sanic.response import HTTPResponse from sanic.websocket import WebSocketProtocol except ImportError: - # the installed version of sanic does not have websocket support + HTTPResponse = None WebSocketProtocol = None -import six -def create_route(app, engineio_server, engineio_endpoint): +def create_route(app, engineio_server, engineio_endpoint): # pragma: no cover """This function sets up the engine.io endpoint as a route for the application. @@ -26,7 +25,7 @@ def create_route(app, engineio_server, engineio_endpoint): pass -def translate_request(request): +def translate_request(request): # pragma: no cover """This function takes the arguments passed to the request handler and uses them to generate a WSGI compatible environ dictionary. """ @@ -89,7 +88,7 @@ def translate_request(request): return environ -def make_response(status, headers, payload, environ): +def make_response(status, headers, payload, environ): # pragma: no cover """This function generates an appropriate response object for this async mode. """ @@ -100,7 +99,7 @@ def make_response(status, headers, payload, environ): content_type = h[1] else: headers_dict[h[0]] = h[1] - return HTTPResponse(body_bytes=payload, content_type=content_type, + return HTTPResponse(body=payload, content_type=content_type, status=int(status.split()[0]), headers=headers_dict) @@ -129,8 +128,8 @@ class WebSocket(object): # pragma: no cover async def wait(self): data = await self._sock.recv() - if not isinstance(data, six.binary_type) and \ - not isinstance(data, six.text_type): + if not isinstance(data, bytes) and \ + not isinstance(data, str): raise IOError() return data diff --git a/libs/engineio/async_drivers/tornado.py b/libs/engineio/async_drivers/tornado.py index adfe18f5a..eb1c4de8a 100644 --- a/libs/engineio/async_drivers/tornado.py +++ b/libs/engineio/async_drivers/tornado.py @@ -5,15 +5,13 @@ from .. import exceptions import tornado.web import tornado.websocket -import six def get_tornado_handler(engineio_server): class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if isinstance(engineio_server.cors_allowed_origins, - six.string_types): + if isinstance(engineio_server.cors_allowed_origins, str): if engineio_server.cors_allowed_origins == '*': self.allowed_origins = None else: @@ -170,8 +168,8 @@ class WebSocket(object): # pragma: no cover async def wait(self): msg = await self.tornado_handler.get_next_message() - if not isinstance(msg, six.binary_type) and \ - not isinstance(msg, six.text_type): + if not isinstance(msg, bytes) and \ + not isinstance(msg, str): raise IOError() return msg diff --git a/libs/engineio/asyncio_client.py b/libs/engineio/asyncio_client.py index 049b4bd95..4a11eb3b2 100644 --- a/libs/engineio/asyncio_client.py +++ b/libs/engineio/asyncio_client.py @@ -1,17 +1,36 @@ import asyncio +import signal import ssl +import threading try: import aiohttp except ImportError: # pragma: no cover aiohttp = None -import six from . import client from . import exceptions from . import packet from . import payload +async_signal_handler_set = False + + +def async_signal_handler(): + """SIGINT handler. + + Disconnect all active async clients. + """ + async def _handler(): + asyncio.get_event_loop().stop() + for c in client.connected_clients[:]: + if c.is_asyncio_based(): + await c.disconnect() + else: # pragma: no cover + pass + + asyncio.ensure_future(_handler()) + class AsyncClient(client.Client): """An Engine.IO client for asyncio. @@ -22,13 +41,18 @@ class AsyncClient(client.Client): :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library versions. :param request_timeout: A timeout in seconds for requests. The default is 5 seconds. + :param http_session: an initialized ``aiohttp.ClientSession`` object to be + used when sending requests to the server. Use it if + you need to add special client options such as proxy + servers, SSL certificates, etc. :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to skip SSL certificate verification, allowing connections to servers with self signed certificates. @@ -37,7 +61,7 @@ class AsyncClient(client.Client): def is_asyncio_based(self): return True - async def connect(self, url, headers={}, transports=None, + async def connect(self, url, headers=None, transports=None, engineio_path='engine.io'): """Connect to an Engine.IO server. @@ -60,11 +84,22 @@ class AsyncClient(client.Client): eio = engineio.Client() await eio.connect('http://localhost:5000') """ + global async_signal_handler_set + if not async_signal_handler_set and \ + threading.current_thread() == threading.main_thread(): + + try: + asyncio.get_event_loop().add_signal_handler( + signal.SIGINT, async_signal_handler) + async_signal_handler_set = True + except NotImplementedError: # pragma: no cover + self.logger.warning('Signal handler is unsupported') + if self.state != 'disconnected': raise ValueError('Client is not in a disconnected state') valid_transports = ['polling', 'websocket'] if transports is not None: - if isinstance(transports, six.text_type): + if isinstance(transports, str): transports = [transports] transports = [transport for transport in transports if transport in valid_transports] @@ -73,7 +108,7 @@ class AsyncClient(client.Client): self.transports = transports or valid_transports self.queue = self.create_queue() return await getattr(self, '_connect_' + self.transports[0])( - url, headers, engineio_path) + url, headers or {}, engineio_path) async def wait(self): """Wait until the connection with the server ends. @@ -86,21 +121,16 @@ class AsyncClient(client.Client): if self.read_loop_task: await self.read_loop_task - async def send(self, data, binary=None): + async def send(self, data): """Send a message to a client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. Note: this method is a coroutine. """ - await self._send_packet(packet.Packet(packet.MESSAGE, data=data, - binary=binary)) + await self._send_packet(packet.Packet(packet.MESSAGE, data=data)) async def disconnect(self, abort=False): """Disconnect from the server. @@ -182,14 +212,20 @@ class AsyncClient(client.Client): raise exceptions.ConnectionError( 'Connection refused by the server') if r.status < 200 or r.status >= 300: + self._reset() + try: + arg = await r.json() + except aiohttp.ClientError: + arg = None raise exceptions.ConnectionError( 'Unexpected status code {} in server response'.format( - r.status)) + r.status), arg) try: - p = payload.Payload(encoded_payload=await r.read()) + p = payload.Payload(encoded_payload=(await r.read()).decode( + 'utf-8')) except ValueError: - six.raise_from(exceptions.ConnectionError( - 'Unexpected response from server'), None) + raise exceptions.ConnectionError( + 'Unexpected response from server') from None open_packet = p.packets[0] if open_packet.packet_type != packet.OPEN: raise exceptions.ConnectionError( @@ -198,8 +234,8 @@ class AsyncClient(client.Client): 'Polling connection accepted with ' + str(open_packet.data)) self.sid = open_packet.data['sid'] self.upgrades = open_packet.data['upgrades'] - self.ping_interval = open_packet.data['pingInterval'] / 1000.0 - self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0 + self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0 self.current_transport = 'polling' self.base_url += '&sid=' + self.sid @@ -216,7 +252,6 @@ class AsyncClient(client.Client): # upgrade to websocket succeeded, we're done here return - self.ping_loop_task = self.start_background_task(self._ping_loop) self.write_loop_task = self.start_background_task(self._write_loop) self.read_loop_task = self.start_background_task( self._read_loop_polling) @@ -242,6 +277,17 @@ class AsyncClient(client.Client): if self.http is None or self.http.closed: # pragma: no cover self.http = aiohttp.ClientSession() + # extract any new cookies passed in a header so that they can also be + # sent the the WebSocket route + cookies = {} + for header, value in headers.items(): + if header.lower() == 'cookie': + cookies = dict( + [cookie.split('=', 1) for cookie in value.split('; ')]) + del headers[header] + break + self.http.cookie_jar.update_cookies(cookies) + try: if not self.ssl_verify: ssl_context = ssl.create_default_context() @@ -255,7 +301,8 @@ class AsyncClient(client.Client): websocket_url + self._get_url_timestamp(), headers=headers) except (aiohttp.client_exceptions.WSServerHandshakeError, - aiohttp.client_exceptions.ServerConnectionError): + aiohttp.client_exceptions.ServerConnectionError, + aiohttp.client_exceptions.ClientConnectionError): if upgrade: self.logger.warning( 'WebSocket upgrade failed: connection error') @@ -263,8 +310,7 @@ class AsyncClient(client.Client): else: raise exceptions.ConnectionError('Connection error') if upgrade: - p = packet.Packet(packet.PING, data='probe').encode( - always_bytes=False) + p = packet.Packet(packet.PING, data='probe').encode() try: await ws.send_str(p) except Exception as e: # pragma: no cover @@ -284,7 +330,7 @@ class AsyncClient(client.Client): self.logger.warning( 'WebSocket upgrade failed: no PONG packet') return False - p = packet.Packet(packet.UPGRADE).encode(always_bytes=False) + p = packet.Packet(packet.UPGRADE).encode() try: await ws.send_str(p) except Exception as e: # pragma: no cover @@ -307,8 +353,8 @@ class AsyncClient(client.Client): 'WebSocket connection accepted with ' + str(open_packet.data)) self.sid = open_packet.data['sid'] self.upgrades = open_packet.data['upgrades'] - self.ping_interval = open_packet.data['pingInterval'] / 1000.0 - self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0 + self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0 self.current_transport = 'websocket' self.state = 'connected' @@ -316,7 +362,6 @@ class AsyncClient(client.Client): await self._trigger_event('connect', run_async=False) self.ws = ws - self.ping_loop_task = self.start_background_task(self._ping_loop) self.write_loop_task = self.start_background_task(self._write_loop) self.read_loop_task = self.start_background_task( self._read_loop_websocket) @@ -331,8 +376,8 @@ class AsyncClient(client.Client): pkt.data if not isinstance(pkt.data, bytes) else '<binary>') if pkt.packet_type == packet.MESSAGE: await self._trigger_event('message', pkt.data, run_async=True) - elif pkt.packet_type == packet.PONG: - self.pong_received = True + elif pkt.packet_type == packet.PING: + await self._send_packet(packet.Packet(packet.PONG, pkt.data)) elif pkt.packet_type == packet.CLOSE: await self.disconnect(abort=True) elif pkt.packet_type == packet.NOOP: @@ -409,33 +454,6 @@ class AsyncClient(client.Client): return False return ret - async def _ping_loop(self): - """This background task sends a PING to the server at the requested - interval. - """ - self.pong_received = True - if self.ping_loop_event is None: - self.ping_loop_event = self.create_event() - else: - self.ping_loop_event.clear() - while self.state == 'connected': - if not self.pong_received: - self.logger.info( - 'PONG response has not been received, aborting') - if self.ws: - await self.ws.close() - await self.queue.put(None) - break - self.pong_received = False - await self._send_packet(packet.Packet(packet.PING)) - try: - await asyncio.wait_for(self.ping_loop_event.wait(), - self.ping_interval) - except (asyncio.TimeoutError, - asyncio.CancelledError): # pragma: no cover - pass - self.logger.info('Exiting ping task') - async def _read_loop_polling(self): """Read packets by polling the Engine.IO server.""" while self.state == 'connected': @@ -455,7 +473,8 @@ class AsyncClient(client.Client): await self.queue.put(None) break try: - p = payload.Payload(encoded_payload=await r.read()) + p = payload.Payload(encoded_payload=(await r.read()).decode( + 'utf-8')) except ValueError: self.logger.warning( 'Unexpected packet from server, aborting') @@ -466,10 +485,6 @@ class AsyncClient(client.Client): self.logger.info('Waiting for write loop task to end') await self.write_loop_task - self.logger.info('Waiting for ping loop task to end') - if self.ping_loop_event: # pragma: no cover - self.ping_loop_event.set() - await self.ping_loop_task if self.state == 'connected': await self._trigger_event('disconnect', run_async=False) try: @@ -484,9 +499,18 @@ class AsyncClient(client.Client): while self.state == 'connected': p = None try: - p = (await self.ws.receive()).data + p = await asyncio.wait_for( + self.ws.receive(), + timeout=self.ping_interval + self.ping_timeout) + p = p.data if p is None: # pragma: no cover - raise RuntimeError('WebSocket read returned None') + await self.queue.put(None) + break # the connection is broken + except asyncio.TimeoutError: + self.logger.warning( + 'Server has stopped communicating, aborting') + await self.queue.put(None) + break except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.info( 'Read loop: WebSocket connection was closed, aborting') @@ -494,20 +518,21 @@ class AsyncClient(client.Client): break except Exception as e: self.logger.info( - 'Unexpected error "%s", aborting', str(e)) + 'Unexpected error receiving packet: "%s", aborting', + str(e)) + await self.queue.put(None) + break + try: + pkt = packet.Packet(encoded_packet=p) + except Exception as e: # pragma: no cover + self.logger.info( + 'Unexpected error decoding packet: "%s", aborting', str(e)) await self.queue.put(None) break - if isinstance(p, six.text_type): # pragma: no cover - p = p.encode('utf-8') - pkt = packet.Packet(encoded_packet=p) await self._receive_packet(pkt) self.logger.info('Waiting for write loop task to end') await self.write_loop_task - self.logger.info('Waiting for ping loop task to end') - if self.ping_loop_event: # pragma: no cover - self.ping_loop_event.set() - await self.ping_loop_task if self.state == 'connected': await self._trigger_event('disconnect', run_async=False) try: @@ -571,13 +596,12 @@ class AsyncClient(client.Client): try: for pkt in packets: if pkt.binary: - await self.ws.send_bytes(pkt.encode( - always_bytes=False)) + await self.ws.send_bytes(pkt.encode()) else: - await self.ws.send_str(pkt.encode( - always_bytes=False)) + await self.ws.send_str(pkt.encode()) self.queue.task_done() - except aiohttp.client_exceptions.ServerDisconnectedError: + except (aiohttp.client_exceptions.ServerDisconnectedError, + BrokenPipeError, OSError): self.logger.info( 'Write loop: WebSocket connection was closed, ' 'aborting') diff --git a/libs/engineio/asyncio_server.py b/libs/engineio/asyncio_server.py index d52b556db..6639f26bf 100644 --- a/libs/engineio/asyncio_server.py +++ b/libs/engineio/asyncio_server.py @@ -1,7 +1,5 @@ import asyncio - -import six -from six.moves import urllib +import urllib from . import exceptions from . import packet @@ -24,23 +22,30 @@ class AsyncServer(server.Server): "tornado", and finally "asgi". The first async mode that has all its dependencies installed is the one that is chosen. - :param ping_timeout: The time in seconds that the client waits for the - server to respond before disconnecting. - :param ping_interval: The interval in seconds at which the client pings - the server. The default is 25 seconds. For advanced + :param ping_interval: The interval in seconds at which the server pings + the client. The default is 25 seconds. For advanced control, a two element tuple can be given, where the first number is the ping interval and the second - is a grace period added by the server. The default - grace period is 5 seconds. + is a grace period added by the server. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. The default + is 5 seconds. :param max_http_buffer_size: The maximum size of a message when using the - polling transport. + polling transport. The default is 1,000,000 + bytes. :param allow_upgrades: Whether to allow transport upgrades or not. :param http_compression: Whether to compress packages when using the polling transport. :param compression_threshold: Only compress messages when their byte size is greater than this value. - :param cookie: Name of the HTTP cookie that contains the client session - id. If set to ``None``, a cookie is not sent to the client. + :param cookie: If set to a string, it is the name of the HTTP cookie the + server sends back tot he client containing the client + session id. If set to a dictionary, the ``'name'`` key + contains the cookie name and other keys define cookie + attributes, where the value of each attribute can be a + string, a callable with no arguments, or a boolean. If set + to ``None`` (the default), a cookie is not sent to the + client. :param cors_allowed_origins: Origin or list of origins that are allowed to connect to this server. Only the same origin is allowed by default. Set this argument to @@ -49,7 +54,8 @@ class AsyncServer(server.Server): :param cors_credentials: Whether credentials (cookies, authentication) are allowed in requests to this server. :param logger: To enable logging set to ``True`` or pass a logger object to - use. To disable logging set to ``False``. + use. To disable logging set to ``False``. Note that fatal + errors are logged even when ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -71,17 +77,13 @@ class AsyncServer(server.Server): engineio_path = engineio_path.strip('/') self._async['create_route'](app, self, '/{}/'.format(engineio_path)) - async def send(self, sid, data, binary=None): + async def send(self, sid, data): """Send a message to a client. :param sid: The session id of the recipient client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. Note: this method is a coroutine. """ @@ -91,8 +93,7 @@ class AsyncServer(server.Server): # the socket is not available self.logger.warning('Cannot send to sid %s', sid) return - await socket.send(packet.Packet(packet.MESSAGE, data=data, - binary=binary)) + await socket.send(packet.Packet(packet.MESSAGE, data=data)) async def get_session(self, sid): """Return the user session for a client. @@ -172,7 +173,7 @@ class AsyncServer(server.Server): del self.sockets[sid] else: await asyncio.wait([client.close() - for client in six.itervalues(self.sockets)]) + for client in self.sockets.values()]) self.sockets = {} async def handle_request(self, *args, **kwargs): @@ -198,28 +199,32 @@ class AsyncServer(server.Server): allowed_origins = self._cors_allowed_origins(environ) if allowed_origins is not None and origin not in \ allowed_origins: - self.logger.info(origin + ' is not an accepted origin.') - r = self._bad_request() - make_response = self._async['make_response'] - if asyncio.iscoroutinefunction(make_response): - response = await make_response( - r['status'], r['headers'], r['response'], environ) - else: - response = make_response(r['status'], r['headers'], - r['response'], environ) - return response + self._log_error_once( + origin + ' is not an accepted origin.', 'bad-origin') + return await self._make_response( + self._bad_request( + origin + ' is not an accepted origin.'), + environ) method = environ['REQUEST_METHOD'] query = urllib.parse.parse_qs(environ.get('QUERY_STRING', '')) sid = query['sid'][0] if 'sid' in query else None - b64 = False jsonp = False jsonp_index = None - if 'b64' in query: - if query['b64'][0] == "1" or query['b64'][0].lower() == "true": - b64 = True + # make sure the client speaks a compatible Engine.IO version + sid = query['sid'][0] if 'sid' in query else None + if sid is None and query.get('EIO') != ['4']: + self._log_error_once( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols', 'bad-version' + ) + return await self._make_response(self._bad_request( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols' + ), environ) + if 'j' in query: jsonp = True try: @@ -229,28 +234,34 @@ class AsyncServer(server.Server): pass if jsonp and jsonp_index is None: - self.logger.warning('Invalid JSONP index number') - r = self._bad_request() + self._log_error_once('Invalid JSONP index number', + 'bad-jsonp-index') + r = self._bad_request('Invalid JSONP index number') elif method == 'GET': if sid is None: transport = query.get('transport', ['polling'])[0] - if transport != 'polling' and transport != 'websocket': - self.logger.warning('Invalid transport %s', transport) - r = self._bad_request() - else: + # transport must be one of 'polling' or 'websocket'. + # if 'websocket', the HTTP_UPGRADE header must match. + upgrade_header = environ.get('HTTP_UPGRADE').lower() \ + if 'HTTP_UPGRADE' in environ else None + if transport == 'polling' \ + or transport == upgrade_header == 'websocket': r = await self._handle_connect(environ, transport, - b64, jsonp_index) + jsonp_index) + else: + self._log_error_once('Invalid transport ' + transport, + 'bad-transport') + r = self._bad_request('Invalid transport ' + transport) else: if sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once('Invalid session ' + sid, 'bad-sid') + r = self._bad_request('Invalid session ' + sid) else: socket = self._get_socket(sid) try: packets = await socket.handle_get_request(environ) if isinstance(packets, list): - r = self._ok(packets, b64=b64, - jsonp_index=jsonp_index) + r = self._ok(packets, jsonp_index=jsonp_index) else: r = packets except exceptions.EngineIOError: @@ -261,8 +272,8 @@ class AsyncServer(server.Server): del self.sockets[sid] elif method == 'POST': if sid is None or sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once('Invalid session ' + sid, 'bad-sid') + r = self._bad_request('Invalid session ' + sid) else: socket = self._get_socket(sid) try: @@ -294,16 +305,7 @@ class AsyncServer(server.Server): getattr(self, '_' + encoding)(r['response']) r['headers'] += [('Content-Encoding', encoding)] break - cors_headers = self._cors_headers(environ) - make_response = self._async['make_response'] - if asyncio.iscoroutinefunction(make_response): - response = await make_response(r['status'], - r['headers'] + cors_headers, - r['response'], environ) - else: - response = make_response(r['status'], r['headers'] + cors_headers, - r['response'], environ) - return response + return await self._make_response(r, environ) def start_background_task(self, target, *args, **kwargs): """Start a background task using the appropriate async model. @@ -362,15 +364,29 @@ class AsyncServer(server.Server): """ return asyncio.Event(*args, **kwargs) - async def _handle_connect(self, environ, transport, b64=False, - jsonp_index=None): + async def _make_response(self, response_dict, environ): + cors_headers = self._cors_headers(environ) + make_response = self._async['make_response'] + if asyncio.iscoroutinefunction(make_response): + response = await make_response( + response_dict['status'], + response_dict['headers'] + cors_headers, + response_dict['response'], environ) + else: + response = make_response( + response_dict['status'], + response_dict['headers'] + cors_headers, + response_dict['response'], environ) + return response + + async def _handle_connect(self, environ, transport, jsonp_index=None): """Handle a client connection request.""" if self.start_service_task: # start the service task to monitor connected clients self.start_service_task = False self.start_background_task(self._service_task) - sid = self._generate_id() + sid = self.generate_id() s = asyncio_socket.AsyncSocket(self, sid) self.sockets[sid] = s @@ -380,17 +396,18 @@ class AsyncServer(server.Server): 'pingTimeout': int(self.ping_timeout * 1000), 'pingInterval': int(self.ping_interval * 1000)}) await s.send(pkt) + s.schedule_ping() ret = await self._trigger_event('connect', sid, environ, run_async=False) - if ret is False: + if ret is not None and ret is not True: del self.sockets[sid] self.logger.warning('Application rejected connection') - return self._unauthorized() + return self._unauthorized(ret or None) if transport == 'websocket': ret = await s.handle_get_request(environ) - if s.closed: + if s.closed and sid in self.sockets: # websocket connection ended, so we are done del self.sockets[sid] return ret @@ -398,9 +415,20 @@ class AsyncServer(server.Server): s.connected = True headers = None if self.cookie: - headers = [('Set-Cookie', self.cookie + '=' + sid)] + if isinstance(self.cookie, dict): + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, self.cookie) + )] + else: + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, { + 'name': self.cookie, 'path': '/', 'SameSite': 'Lax' + }) + )] try: - return self._ok(await s.poll(), headers=headers, b64=b64, + return self._ok(await s.poll(), headers=headers, jsonp_index=jsonp_index) except exceptions.QueueEmpty: return self._bad_request() @@ -459,7 +487,12 @@ class AsyncServer(server.Server): if not socket.closing and not socket.closed: await socket.check_ping_timeout() await self.sleep(sleep_interval) - except (SystemExit, KeyboardInterrupt, asyncio.CancelledError): + except ( + SystemExit, + KeyboardInterrupt, + asyncio.CancelledError, + GeneratorExit, + ): self.logger.info('service task canceled') break except: diff --git a/libs/engineio/asyncio_socket.py b/libs/engineio/asyncio_socket.py index 7057a6cc3..508ee3ca2 100644 --- a/libs/engineio/asyncio_socket.py +++ b/libs/engineio/asyncio_socket.py @@ -1,5 +1,4 @@ import asyncio -import six import sys import time @@ -13,18 +12,24 @@ class AsyncSocket(socket.Socket): async def poll(self): """Wait for packets to send to the client.""" try: - packets = [await asyncio.wait_for(self.queue.get(), - self.server.ping_timeout)] + packets = [await asyncio.wait_for( + self.queue.get(), + self.server.ping_interval + self.server.ping_timeout)] self.queue.task_done() except (asyncio.TimeoutError, asyncio.CancelledError): raise exceptions.QueueEmpty() if packets == [None]: return [] - try: - packets.append(self.queue.get_nowait()) - self.queue.task_done() - except asyncio.QueueEmpty: - pass + while True: + try: + pkt = self.queue.get_nowait() + self.queue.task_done() + if pkt is None: + self.queue.put_nowait(None) + break + packets.append(pkt) + except asyncio.QueueEmpty: + break return packets async def receive(self, pkt): @@ -33,9 +38,8 @@ class AsyncSocket(socket.Socket): self.sid, packet.packet_names[pkt.packet_type], pkt.data if not isinstance(pkt.data, bytes) else '<binary>') - if pkt.packet_type == packet.PING: - self.last_ping = time.time() - await self.send(packet.Packet(packet.PONG, pkt.data)) + if pkt.packet_type == packet.PONG: + self.schedule_ping() elif pkt.packet_type == packet.MESSAGE: await self.server._trigger_event( 'message', self.sid, pkt.data, @@ -48,14 +52,11 @@ class AsyncSocket(socket.Socket): raise exceptions.UnknownPacketError() async def check_ping_timeout(self): - """Make sure the client is still sending pings. - - This helps detect disconnections for long-polling clients. - """ + """Make sure the client is still sending pings.""" if self.closed: raise exceptions.SocketIsClosedError() - if time.time() - self.last_ping > self.server.ping_interval + \ - self.server.ping_interval_grace_period: + if self.last_ping and \ + time.time() - self.last_ping > self.server.ping_timeout: self.server.logger.info('%s: Client is gone, closing socket', self.sid) # Passing abort=False here will cause close() to write a @@ -69,8 +70,6 @@ class AsyncSocket(socket.Socket): """Send a packet to the client.""" if not await self.check_ping_timeout(): return - if self.upgrading: - self.packet_backlog.append(pkt) else: await self.queue.put(pkt) self.server.logger.info('%s: Sending packet %s data %s', @@ -88,12 +87,16 @@ class AsyncSocket(socket.Socket): self.server.logger.info('%s: Received request to upgrade to %s', self.sid, transport) return await getattr(self, '_upgrade_' + transport)(environ) + if self.upgrading or self.upgraded: + # we are upgrading to WebSocket, do not return any more packets + # through the polling endpoint + return [packet.Packet(packet.NOOP)] try: packets = await self.poll() except exceptions.QueueEmpty: exc = sys.exc_info() await self.close(wait=False) - six.reraise(*exc) + raise exc[1].with_traceback(exc[2]) return packets async def handle_post_request(self, environ): @@ -102,7 +105,7 @@ class AsyncSocket(socket.Socket): if length > self.server.max_http_buffer_size: raise exceptions.ContentTooLongError() else: - body = await environ['wsgi.input'].read(length) + body = (await environ['wsgi.input'].read(length)).decode('utf-8') p = payload.Payload(encoded_payload=body) for pkt in p.packets: await self.receive(pkt) @@ -118,6 +121,16 @@ class AsyncSocket(socket.Socket): if wait: await self.queue.join() + def schedule_ping(self): + async def send_ping(): + self.last_ping = None + await asyncio.sleep(self.server.ping_interval) + if not self.closing and not self.closed: + self.last_ping = time.time() + await self.send(packet.Packet(packet.PING)) + + self.server.start_background_task(send_ping) + async def _upgrade_websocket(self, environ): """Upgrade the connection from polling to websocket.""" if self.upgraded: @@ -143,15 +156,15 @@ class AsyncSocket(socket.Socket): decoded_pkt.data != 'probe': self.server.logger.info( '%s: Failed websocket upgrade, no PING packet', self.sid) + self.upgrading = False return - await ws.send(packet.Packet( - packet.PONG, - data=six.text_type('probe')).encode(always_bytes=False)) + await ws.send(packet.Packet(packet.PONG, data='probe').encode()) await self.queue.put(packet.Packet(packet.NOOP)) # end poll try: pkt = await ws.wait() except IOError: # pragma: no cover + self.upgrading = False return decoded_pkt = packet.Packet(encoded_packet=pkt) if decoded_pkt.packet_type != packet.UPGRADE: @@ -160,13 +173,9 @@ class AsyncSocket(socket.Socket): ('%s: Failed websocket upgrade, expected UPGRADE packet, ' 'received %s instead.'), self.sid, pkt) + self.upgrading = False return self.upgraded = True - - # flush any packets that were sent during the upgrade - for pkt in self.packet_backlog: - await self.queue.put(pkt) - self.packet_backlog = [] self.upgrading = False else: self.connected = True @@ -185,7 +194,7 @@ class AsyncSocket(socket.Socket): break try: for pkt in packets: - await ws.send(pkt.encode(always_bytes=False)) + await ws.send(pkt.encode()) except: break writer_task = asyncio.ensure_future(writer()) @@ -197,7 +206,9 @@ class AsyncSocket(socket.Socket): p = None wait_task = asyncio.ensure_future(ws.wait()) try: - p = await asyncio.wait_for(wait_task, self.server.ping_timeout) + p = await asyncio.wait_for( + wait_task, + self.server.ping_interval + self.server.ping_timeout) except asyncio.CancelledError: # pragma: no cover # there is a bug (https://bugs.python.org/issue30508) in # asyncio that causes a "Task exception never retrieved" error @@ -216,8 +227,6 @@ class AsyncSocket(socket.Socket): if p is None: # connection closed by client break - if isinstance(p, six.text_type): # pragma: no cover - p = p.encode('utf-8') pkt = packet.Packet(encoded_packet=p) try: await self.receive(pkt) diff --git a/libs/engineio/client.py b/libs/engineio/client.py index b5ab50377..d307a5d62 100644 --- a/libs/engineio/client.py +++ b/libs/engineio/client.py @@ -1,3 +1,5 @@ +from base64 import b64encode +from json import JSONDecodeError import logging try: import queue @@ -7,9 +9,8 @@ import signal import ssl import threading import time +import urllib -import six -from six.moves import urllib try: import requests except ImportError: # pragma: no cover @@ -25,9 +26,6 @@ from . import payload default_logger = logging.getLogger('engineio.client') connected_clients = [] -if six.PY2: # pragma: no cover - ConnectionError = OSError - def signal_handler(sig, frame): """SIGINT handler. @@ -35,10 +33,8 @@ def signal_handler(sig, frame): Disconnect all active clients and then invoke the original signal handler. """ for client in connected_clients[:]: - if client.is_asyncio_based(): - client.start_background_task(client.disconnect, abort=True) - else: - client.disconnect(abort=True) + if not client.is_asyncio_based(): + client.disconnect() if callable(original_signal_handler): return original_signal_handler(sig, frame) else: # pragma: no cover @@ -57,13 +53,18 @@ class Client(object): :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library versions. :param request_timeout: A timeout in seconds for requests. The default is 5 seconds. + :param http_session: an initialized ``requests.Session`` object to be used + when sending requests to the server. Use it if you + need to add special client options such as proxy + servers, SSL certificates, custom CA bundle, etc. :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to skip SSL certificate verification, allowing connections to servers with self signed certificates. @@ -75,9 +76,11 @@ class Client(object): logger=False, json=None, request_timeout=5, + http_session=None, ssl_verify=True): global original_signal_handler - if original_signal_handler is None: + if original_signal_handler is None and \ + threading.current_thread() == threading.main_thread(): original_signal_handler = signal.signal(signal.SIGINT, signal_handler) self.handlers = {} @@ -88,13 +91,10 @@ class Client(object): self.upgrades = None self.ping_interval = None self.ping_timeout = None - self.pong_received = True - self.http = None + self.http = http_session self.ws = None self.read_loop_task = None self.write_loop_task = None - self.ping_loop_task = None - self.ping_loop_event = None self.queue = None self.state = 'disconnected' self.ssl_verify = ssl_verify @@ -105,8 +105,7 @@ class Client(object): self.logger = logger else: self.logger = default_logger - if not logging.root.handlers and \ - self.logger.level == logging.NOTSET: + if self.logger.level == logging.NOTSET: if logger: self.logger.setLevel(logging.INFO) else: @@ -151,7 +150,7 @@ class Client(object): return set_handler set_handler(handler) - def connect(self, url, headers={}, transports=None, + def connect(self, url, headers=None, transports=None, engineio_path='engine.io'): """Connect to an Engine.IO server. @@ -176,7 +175,7 @@ class Client(object): raise ValueError('Client is not in a disconnected state') valid_transports = ['polling', 'websocket'] if transports is not None: - if isinstance(transports, six.string_types): + if isinstance(transports, str): transports = [transports] transports = [transport for transport in transports if transport in valid_transports] @@ -185,7 +184,7 @@ class Client(object): self.transports = transports or valid_transports self.queue = self.create_queue() return getattr(self, '_connect_' + self.transports[0])( - url, headers, engineio_path) + url, headers or {}, engineio_path) def wait(self): """Wait until the connection with the server ends. @@ -196,19 +195,14 @@ class Client(object): if self.read_loop_task: self.read_loop_task.join() - def send(self, data, binary=None): + def send(self, data): """Send a message to a client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. """ - self._send_packet(packet.Packet(packet.MESSAGE, data=data, - binary=binary)) + self._send_packet(packet.Packet(packet.MESSAGE, data=data)) def disconnect(self, abort=False): """Disconnect from the server. @@ -293,14 +287,19 @@ class Client(object): raise exceptions.ConnectionError( 'Connection refused by the server') if r.status_code < 200 or r.status_code >= 300: + self._reset() + try: + arg = r.json() + except JSONDecodeError: + arg = None raise exceptions.ConnectionError( 'Unexpected status code {} in server response'.format( - r.status_code)) + r.status_code), arg) try: - p = payload.Payload(encoded_payload=r.content) + p = payload.Payload(encoded_payload=r.content.decode('utf-8')) except ValueError: - six.raise_from(exceptions.ConnectionError( - 'Unexpected response from server'), None) + raise exceptions.ConnectionError( + 'Unexpected response from server') from None open_packet = p.packets[0] if open_packet.packet_type != packet.OPEN: raise exceptions.ConnectionError( @@ -309,8 +308,8 @@ class Client(object): 'Polling connection accepted with ' + str(open_packet.data)) self.sid = open_packet.data['sid'] self.upgrades = open_packet.data['upgrades'] - self.ping_interval = open_packet.data['pingInterval'] / 1000.0 - self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0 + self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0 self.current_transport = 'polling' self.base_url += '&sid=' + self.sid @@ -328,7 +327,6 @@ class Client(object): return # start background tasks associated with this client - self.ping_loop_task = self.start_background_task(self._ping_loop) self.write_loop_task = self.start_background_task(self._write_loop) self.read_loop_task = self.start_background_task( self._read_loop_polling) @@ -337,8 +335,8 @@ class Client(object): """Establish or upgrade to a WebSocket connection with the server.""" if websocket is None: # pragma: no cover # not installed - self.logger.warning('websocket-client package not installed, only ' - 'polling transport is available') + self.logger.error('websocket-client package not installed, only ' + 'polling transport is available') return False websocket_url = self._get_engineio_url(url, engineio_path, 'websocket') if self.sid: @@ -352,22 +350,75 @@ class Client(object): self.logger.info( 'Attempting WebSocket connection to ' + websocket_url) - # get the cookies from the long-polling connection so that they can - # also be sent the the WebSocket route + # get cookies and other settings from the long-polling connection + # so that they are preserved when connecting to the WebSocket route cookies = None + extra_options = {} if self.http: + # cookies cookies = '; '.join(["{}={}".format(cookie.name, cookie.value) for cookie in self.http.cookies]) + for header, value in headers.items(): + if header.lower() == 'cookie': + if cookies: + cookies += '; ' + cookies += value + del headers[header] + break + # auth + if 'Authorization' not in headers and self.http.auth is not None: + if not isinstance(self.http.auth, tuple): # pragma: no cover + raise ValueError('Only basic authentication is supported') + basic_auth = '{}:{}'.format( + self.http.auth[0], self.http.auth[1]).encode('utf-8') + basic_auth = b64encode(basic_auth).decode('utf-8') + headers['Authorization'] = 'Basic ' + basic_auth + + # cert + # this can be given as ('certfile', 'keyfile') or just 'certfile' + if isinstance(self.http.cert, tuple): + extra_options['sslopt'] = { + 'certfile': self.http.cert[0], + 'keyfile': self.http.cert[1]} + elif self.http.cert: + extra_options['sslopt'] = {'certfile': self.http.cert} + + # proxies + if self.http.proxies: + proxy_url = None + if websocket_url.startswith('ws://'): + proxy_url = self.http.proxies.get( + 'ws', self.http.proxies.get('http')) + else: # wss:// + proxy_url = self.http.proxies.get( + 'wss', self.http.proxies.get('https')) + if proxy_url: + parsed_url = urllib.parse.urlparse( + proxy_url if '://' in proxy_url + else 'scheme://' + proxy_url) + extra_options['http_proxy_host'] = parsed_url.hostname + extra_options['http_proxy_port'] = parsed_url.port + extra_options['http_proxy_auth'] = ( + (parsed_url.username, parsed_url.password) + if parsed_url.username or parsed_url.password + else None) + + # verify + if isinstance(self.http.verify, str): + if 'sslopt' in extra_options: + extra_options['sslopt']['ca_certs'] = self.http.verify + else: + extra_options['sslopt'] = {'ca_certs': self.http.verify} + elif not self.http.verify: + self.ssl_verify = False + + if not self.ssl_verify: + extra_options['sslopt'] = {"cert_reqs": ssl.CERT_NONE} try: - if not self.ssl_verify: - ws = websocket.create_connection( - websocket_url + self._get_url_timestamp(), header=headers, - cookie=cookies, sslopt={"cert_reqs": ssl.CERT_NONE}) - else: - ws = websocket.create_connection( - websocket_url + self._get_url_timestamp(), header=headers, - cookie=cookies) + ws = websocket.create_connection( + websocket_url + self._get_url_timestamp(), header=headers, + cookie=cookies, enable_multithread=True, **extra_options) except (ConnectionError, IOError, websocket.WebSocketException): if upgrade: self.logger.warning( @@ -376,8 +427,7 @@ class Client(object): else: raise exceptions.ConnectionError('Connection error') if upgrade: - p = packet.Packet(packet.PING, - data=six.text_type('probe')).encode() + p = packet.Packet(packet.PING, data='probe').encode() try: ws.send(p) except Exception as e: # pragma: no cover @@ -420,17 +470,17 @@ class Client(object): 'WebSocket connection accepted with ' + str(open_packet.data)) self.sid = open_packet.data['sid'] self.upgrades = open_packet.data['upgrades'] - self.ping_interval = open_packet.data['pingInterval'] / 1000.0 - self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0 + self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0 self.current_transport = 'websocket' self.state = 'connected' connected_clients.append(self) self._trigger_event('connect', run_async=False) self.ws = ws + self.ws.settimeout(self.ping_interval + self.ping_timeout) # start background tasks associated with this client - self.ping_loop_task = self.start_background_task(self._ping_loop) self.write_loop_task = self.start_background_task(self._write_loop) self.read_loop_task = self.start_background_task( self._read_loop_websocket) @@ -445,8 +495,8 @@ class Client(object): pkt.data if not isinstance(pkt.data, bytes) else '<binary>') if pkt.packet_type == packet.MESSAGE: self._trigger_event('message', pkt.data, run_async=True) - elif pkt.packet_type == packet.PONG: - self.pong_received = True + elif pkt.packet_type == packet.PING: + self._send_packet(packet.Packet(packet.PONG, pkt.data)) elif pkt.packet_type == packet.CLOSE: self.disconnect(abort=True) elif pkt.packet_type == packet.NOOP: @@ -470,9 +520,11 @@ class Client(object): timeout=None): # pragma: no cover if self.http is None: self.http = requests.Session() + if not self.ssl_verify: + self.http.verify = False try: return self.http.request(method, url, headers=headers, data=body, - timeout=timeout, verify=self.ssl_verify) + timeout=timeout) except requests.exceptions.RequestException as exc: self.logger.info('HTTP %s request to %s failed with error %s.', method, url, exc) @@ -504,7 +556,7 @@ class Client(object): scheme += 's' return ('{scheme}://{netloc}/{path}/?{query}' - '{sep}transport={transport}&EIO=3').format( + '{sep}transport={transport}&EIO=4').format( scheme=scheme, netloc=parsed_url.netloc, path=engineio_path, query=parsed_url.query, sep='&' if parsed_url.query else '', @@ -514,28 +566,6 @@ class Client(object): """Generate the Engine.IO query string timestamp.""" return '&t=' + str(time.time()) - def _ping_loop(self): - """This background task sends a PING to the server at the requested - interval. - """ - self.pong_received = True - if self.ping_loop_event is None: - self.ping_loop_event = self.create_event() - else: - self.ping_loop_event.clear() - while self.state == 'connected': - if not self.pong_received: - self.logger.info( - 'PONG response has not been received, aborting') - if self.ws: - self.ws.close(timeout=0) - self.queue.put(None) - break - self.pong_received = False - self._send_packet(packet.Packet(packet.PING)) - self.ping_loop_event.wait(timeout=self.ping_interval) - self.logger.info('Exiting ping task') - def _read_loop_polling(self): """Read packets by polling the Engine.IO server.""" while self.state == 'connected': @@ -555,7 +585,7 @@ class Client(object): self.queue.put(None) break try: - p = payload.Payload(encoded_payload=r.content) + p = payload.Payload(encoded_payload=r.content.decode('utf-8')) except ValueError: self.logger.warning( 'Unexpected packet from server, aborting') @@ -566,10 +596,6 @@ class Client(object): self.logger.info('Waiting for write loop task to end') self.write_loop_task.join() - self.logger.info('Waiting for ping loop task to end') - if self.ping_loop_event: # pragma: no cover - self.ping_loop_event.set() - self.ping_loop_task.join() if self.state == 'connected': self._trigger_event('disconnect', run_async=False) try: @@ -585,6 +611,11 @@ class Client(object): p = None try: p = self.ws.recv() + except websocket.WebSocketTimeoutException: + self.logger.warning( + 'Server has stopped communicating, aborting') + self.queue.put(None) + break except websocket.WebSocketConnectionClosedException: self.logger.warning( 'WebSocket connection was closed, aborting') @@ -592,20 +623,21 @@ class Client(object): break except Exception as e: self.logger.info( - 'Unexpected error "%s", aborting', str(e)) + 'Unexpected error receiving packet: "%s", aborting', + str(e)) + self.queue.put(None) + break + try: + pkt = packet.Packet(encoded_packet=p) + except Exception as e: # pragma: no cover + self.logger.info( + 'Unexpected error decoding packet: "%s", aborting', str(e)) self.queue.put(None) break - if isinstance(p, six.text_type): # pragma: no cover - p = p.encode('utf-8') - pkt = packet.Packet(encoded_packet=p) self._receive_packet(pkt) self.logger.info('Waiting for write loop task to end') self.write_loop_task.join() - self.logger.info('Waiting for ping loop task to end') - if self.ping_loop_event: # pragma: no cover - self.ping_loop_event.set() - self.ping_loop_task.join() if self.state == 'connected': self._trigger_event('disconnect', run_async=False) try: @@ -667,13 +699,14 @@ class Client(object): # websocket try: for pkt in packets: - encoded_packet = pkt.encode(always_bytes=False) + encoded_packet = pkt.encode() if pkt.binary: self.ws.send_binary(encoded_packet) else: self.ws.send(encoded_packet) self.queue.task_done() - except websocket.WebSocketConnectionClosedException: + except (websocket.WebSocketConnectionClosedException, + BrokenPipeError, OSError): self.logger.warning( 'WebSocket connection was closed, aborting') break diff --git a/libs/engineio/packet.py b/libs/engineio/packet.py index a3aa6d476..9dbd6c684 100644 --- a/libs/engineio/packet.py +++ b/libs/engineio/packet.py @@ -1,12 +1,10 @@ import base64 import json as _json -import six - (OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6) packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP'] -binary_types = (six.binary_type, bytearray) +binary_types = (bytes, bytearray) class Packet(object): @@ -14,79 +12,61 @@ class Packet(object): json = _json - def __init__(self, packet_type=NOOP, data=None, binary=None, - encoded_packet=None): + def __init__(self, packet_type=NOOP, data=None, encoded_packet=None): self.packet_type = packet_type self.data = data - if binary is not None: - self.binary = binary - elif isinstance(data, six.text_type): + if isinstance(data, str): self.binary = False elif isinstance(data, binary_types): self.binary = True else: self.binary = False + if self.binary and self.packet_type != MESSAGE: + raise ValueError('Binary packets can only be of type MESSAGE') if encoded_packet: self.decode(encoded_packet) - def encode(self, b64=False, always_bytes=True): + def encode(self, b64=False): """Encode the packet for transmission.""" - if self.binary and not b64: - encoded_packet = six.int2byte(self.packet_type) - else: - encoded_packet = six.text_type(self.packet_type) - if self.binary and b64: - encoded_packet = 'b' + encoded_packet if self.binary: if b64: - encoded_packet += base64.b64encode(self.data).decode('utf-8') + encoded_packet = 'b' + base64.b64encode(self.data).decode( + 'utf-8') else: + encoded_packet = self.data + else: + encoded_packet = str(self.packet_type) + if isinstance(self.data, str): encoded_packet += self.data - elif isinstance(self.data, six.string_types): - encoded_packet += self.data - elif isinstance(self.data, dict) or isinstance(self.data, list): - encoded_packet += self.json.dumps(self.data, - separators=(',', ':')) - elif self.data is not None: - encoded_packet += str(self.data) - if always_bytes and not isinstance(encoded_packet, binary_types): - encoded_packet = encoded_packet.encode('utf-8') + elif isinstance(self.data, dict) or isinstance(self.data, list): + encoded_packet += self.json.dumps(self.data, + separators=(',', ':')) + elif self.data is not None: + encoded_packet += str(self.data) return encoded_packet def decode(self, encoded_packet): """Decode a transmitted package.""" - b64 = False - if not isinstance(encoded_packet, binary_types): - encoded_packet = encoded_packet.encode('utf-8') - elif not isinstance(encoded_packet, bytes): - encoded_packet = bytes(encoded_packet) - self.packet_type = six.byte2int(encoded_packet[0:1]) - if self.packet_type == 98: # 'b' --> binary base64 encoded packet + self.binary = isinstance(encoded_packet, binary_types) + b64 = not self.binary and encoded_packet[0] == 'b' + if b64: self.binary = True - encoded_packet = encoded_packet[1:] - self.packet_type = six.byte2int(encoded_packet[0:1]) - self.packet_type -= 48 - b64 = True - elif self.packet_type >= 48: - self.packet_type -= 48 - self.binary = False + self.packet_type = MESSAGE + self.data = base64.b64decode(encoded_packet[1:]) else: - self.binary = True - self.data = None - if len(encoded_packet) > 1: + if self.binary and not isinstance(encoded_packet, bytes): + encoded_packet = bytes(encoded_packet) if self.binary: - if b64: - self.data = base64.b64decode(encoded_packet[1:]) - else: - self.data = encoded_packet[1:] + self.packet_type = MESSAGE + self.data = encoded_packet else: + self.packet_type = int(encoded_packet[0]) try: - self.data = self.json.loads( - encoded_packet[1:].decode('utf-8')) + self.data = self.json.loads(encoded_packet[1:]) if isinstance(self.data, int): # do not allow integer payloads, see # github.com/miguelgrinberg/python-engineio/issues/75 # for background on this decision raise ValueError except ValueError: - self.data = encoded_packet[1:].decode('utf-8') + self.data = encoded_packet[1:] diff --git a/libs/engineio/payload.py b/libs/engineio/payload.py index fbf9cbd27..f0e9e343d 100644 --- a/libs/engineio/payload.py +++ b/libs/engineio/payload.py @@ -1,9 +1,7 @@ -import six +import urllib from . import packet -from six.moves import urllib - class Payload(object): """Engine.IO payload.""" @@ -14,31 +12,19 @@ class Payload(object): if encoded_payload is not None: self.decode(encoded_payload) - def encode(self, b64=False, jsonp_index=None): + def encode(self, jsonp_index=None): """Encode the payload for transmission.""" - encoded_payload = b'' + encoded_payload = '' for pkt in self.packets: - encoded_packet = pkt.encode(b64=b64) - packet_len = len(encoded_packet) - if b64: - encoded_payload += str(packet_len).encode('utf-8') + b':' + \ - encoded_packet - else: - binary_len = b'' - while packet_len != 0: - binary_len = six.int2byte(packet_len % 10) + binary_len - packet_len = int(packet_len / 10) - if not pkt.binary: - encoded_payload += b'\0' - else: - encoded_payload += b'\1' - encoded_payload += binary_len + b'\xff' + encoded_packet + if encoded_payload: + encoded_payload += '\x1e' + encoded_payload += pkt.encode(b64=True) if jsonp_index is not None: - encoded_payload = b'___eio[' + \ - str(jsonp_index).encode() + \ - b']("' + \ - encoded_payload.replace(b'"', b'\\"') + \ - b'");' + encoded_payload = '___eio[' + \ + str(jsonp_index) + \ + ']("' + \ + encoded_payload.replace('"', '\\"') + \ + '");' return encoded_payload def decode(self, encoded_payload): @@ -49,33 +35,12 @@ class Payload(object): return # JSONP POST payload starts with 'd=' - if encoded_payload.startswith(b'd='): + if encoded_payload.startswith('d='): encoded_payload = urllib.parse.parse_qs( - encoded_payload)[b'd'][0] + encoded_payload)['d'][0] - i = 0 - if six.byte2int(encoded_payload[0:1]) <= 1: - # binary encoding - while i < len(encoded_payload): - if len(self.packets) >= self.max_decode_packets: - raise ValueError('Too many packets in payload') - packet_len = 0 - i += 1 - while six.byte2int(encoded_payload[i:i + 1]) != 255: - packet_len = packet_len * 10 + six.byte2int( - encoded_payload[i:i + 1]) - i += 1 - self.packets.append(packet.Packet( - encoded_packet=encoded_payload[i + 1:i + 1 + packet_len])) - i += packet_len + 1 - else: - # assume text encoding - encoded_payload = encoded_payload.decode('utf-8') - while i < len(encoded_payload): - if len(self.packets) >= self.max_decode_packets: - raise ValueError('Too many packets in payload') - j = encoded_payload.find(':', i) - packet_len = int(encoded_payload[i:j]) - pkt = encoded_payload[j + 1:j + 1 + packet_len] - self.packets.append(packet.Packet(encoded_packet=pkt)) - i = j + 1 + packet_len + encoded_packets = encoded_payload.split('\x1e') + if len(encoded_packets) > self.max_decode_packets: + raise ValueError('Too many packets in payload') + self.packets = [packet.Packet(encoded_packet=encoded_packet) + for encoded_packet in encoded_packets] diff --git a/libs/engineio/server.py b/libs/engineio/server.py index e1543c2dc..7498f3f6b 100644 --- a/libs/engineio/server.py +++ b/libs/engineio/server.py @@ -1,12 +1,12 @@ +import base64 import gzip import importlib +import io import logging -import uuid +import secrets +import urllib import zlib -import six -from six.moves import urllib - from . import exceptions from . import packet from . import payload @@ -29,17 +29,16 @@ class Server(object): "gevent_uwsgi", then "gevent", and finally "threading". The first async mode that has all its dependencies installed is the one that is chosen. - :param ping_timeout: The time in seconds that the client waits for the - server to respond before disconnecting. The default - is 60 seconds. - :param ping_interval: The interval in seconds at which the client pings - the server. The default is 25 seconds. For advanced + :param ping_interval: The interval in seconds at which the server pings + the client. The default is 25 seconds. For advanced control, a two element tuple can be given, where the first number is the ping interval and the second - is a grace period added by the server. The default - grace period is 5 seconds. + is a grace period added by the server. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. The default + is 5 seconds. :param max_http_buffer_size: The maximum size of a message when using the - polling transport. The default is 100,000,000 + polling transport. The default is 1,000,000 bytes. :param allow_upgrades: Whether to allow transport upgrades or not. The default is ``True``. @@ -48,9 +47,14 @@ class Server(object): :param compression_threshold: Only compress messages when their byte size is greater than this value. The default is 1024 bytes. - :param cookie: Name of the HTTP cookie that contains the client session - id. If set to ``None``, a cookie is not sent to the client. - The default is ``'io'``. + :param cookie: If set to a string, it is the name of the HTTP cookie the + server sends back tot he client containing the client + session id. If set to a dictionary, the ``'name'`` key + contains the cookie name and other keys define cookie + attributes, where the value of each attribute can be a + string, a callable with no arguments, or a boolean. If set + to ``None`` (the default), a cookie is not sent to the + client. :param cors_allowed_origins: Origin or list of origins that are allowed to connect to this server. Only the same origin is allowed by default. Set this argument to @@ -61,7 +65,8 @@ class Server(object): is ``True``. :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -79,11 +84,12 @@ class Server(object): compression_methods = ['gzip', 'deflate'] event_names = ['connect', 'disconnect', 'message'] _default_monitor_clients = True + sequence_number = 0 - def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25, - max_http_buffer_size=100000000, allow_upgrades=True, + def __init__(self, async_mode=None, ping_interval=25, ping_timeout=5, + max_http_buffer_size=1000000, allow_upgrades=True, http_compression=True, compression_threshold=1024, - cookie='io', cors_allowed_origins=None, + cookie=None, cors_allowed_origins=None, cors_credentials=True, logger=False, json=None, async_handlers=True, monitor_clients=None, **kwargs): self.ping_timeout = ping_timeout @@ -92,7 +98,7 @@ class Server(object): self.ping_interval_grace_period = ping_interval[1] else: self.ping_interval = ping_interval - self.ping_interval_grace_period = 5 + self.ping_interval_grace_period = 0 self.max_http_buffer_size = max_http_buffer_size self.allow_upgrades = allow_upgrades self.http_compression = http_compression @@ -103,6 +109,7 @@ class Server(object): self.async_handlers = async_handlers self.sockets = {} self.handlers = {} + self.log_message_keys = set() self.start_service_task = monitor_clients \ if monitor_clients is not None else self._default_monitor_clients if json is not None: @@ -111,8 +118,7 @@ class Server(object): self.logger = logger else: self.logger = default_logger - if not logging.root.handlers and \ - self.logger.level == logging.NOTSET: + if self.logger.level == logging.NOTSET: if logger: self.logger.setLevel(logging.INFO) else: @@ -196,17 +202,13 @@ class Server(object): return set_handler set_handler(handler) - def send(self, sid, data, binary=None): + def send(self, sid, data): """Send a message to a client. :param sid: The session id of the recipient client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. """ try: socket = self._get_socket(sid) @@ -214,7 +216,7 @@ class Server(object): # the socket is not available self.logger.warning('Cannot send to sid %s', sid) return - socket.send(packet.Packet(packet.MESSAGE, data=data, binary=binary)) + socket.send(packet.Packet(packet.MESSAGE, data=data)) def get_session(self, sid): """Return the user session for a client. @@ -292,7 +294,7 @@ class Server(object): if sid in self.sockets: # pragma: no cover del self.sockets[sid] else: - for client in six.itervalues(self.sockets): + for client in self.sockets.values(): client.close() self.sockets = {} @@ -329,22 +331,30 @@ class Server(object): allowed_origins = self._cors_allowed_origins(environ) if allowed_origins is not None and origin not in \ allowed_origins: - self.logger.info(origin + ' is not an accepted origin.') - r = self._bad_request() + self._log_error_once( + origin + ' is not an accepted origin.', 'bad-origin') + r = self._bad_request( + origin + ' is not an accepted origin.') start_response(r['status'], r['headers']) return [r['response']] method = environ['REQUEST_METHOD'] query = urllib.parse.parse_qs(environ.get('QUERY_STRING', '')) - - sid = query['sid'][0] if 'sid' in query else None - b64 = False jsonp = False jsonp_index = None - if 'b64' in query: - if query['b64'][0] == "1" or query['b64'][0].lower() == "true": - b64 = True + # make sure the client speaks a compatible Engine.IO version + sid = query['sid'][0] if 'sid' in query else None + if sid is None and query.get('EIO') != ['4']: + self._log_error_once( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols', 'bad-version') + r = self._bad_request( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols') + start_response(r['status'], r['headers']) + return [r['response']] + if 'j' in query: jsonp = True try: @@ -354,29 +364,35 @@ class Server(object): pass if jsonp and jsonp_index is None: - self.logger.warning('Invalid JSONP index number') - r = self._bad_request() + self._log_error_once('Invalid JSONP index number', + 'bad-jsonp-index') + r = self._bad_request('Invalid JSONP index number') elif method == 'GET': if sid is None: transport = query.get('transport', ['polling'])[0] - if transport != 'polling' and transport != 'websocket': - self.logger.warning('Invalid transport %s', transport) - r = self._bad_request() - else: + # transport must be one of 'polling' or 'websocket'. + # if 'websocket', the HTTP_UPGRADE header must match. + upgrade_header = environ.get('HTTP_UPGRADE').lower() \ + if 'HTTP_UPGRADE' in environ else None + if transport == 'polling' \ + or transport == upgrade_header == 'websocket': r = self._handle_connect(environ, start_response, - transport, b64, jsonp_index) + transport, jsonp_index) + else: + self._log_error_once('Invalid transport ' + transport, + 'bad-transport') + r = self._bad_request('Invalid transport ' + transport) else: if sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once('Invalid session ' + sid, 'bad-sid') + r = self._bad_request('Invalid session ' + sid) else: socket = self._get_socket(sid) try: packets = socket.handle_get_request( environ, start_response) if isinstance(packets, list): - r = self._ok(packets, b64=b64, - jsonp_index=jsonp_index) + r = self._ok(packets, jsonp_index=jsonp_index) else: r = packets except exceptions.EngineIOError: @@ -387,8 +403,9 @@ class Server(object): del self.sockets[sid] elif method == 'POST': if sid is None or sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once( + 'Invalid session ' + (sid or 'None'), 'bad-sid') + r = self._bad_request('Invalid session ' + (sid or 'None')) else: socket = self._get_socket(sid) try: @@ -481,11 +498,28 @@ class Server(object): """ return self._async['event'](*args, **kwargs) - def _generate_id(self): + def generate_id(self): """Generate a unique session id.""" - return uuid.uuid4().hex + id = base64.b64encode( + secrets.token_bytes(12) + self.sequence_number.to_bytes(3, 'big')) + self.sequence_number = (self.sequence_number + 1) & 0xffffff + return id.decode('utf-8').replace('/', '_').replace('+', '-') + + def _generate_sid_cookie(self, sid, attributes): + """Generate the sid cookie.""" + cookie = attributes.get('name', 'io') + '=' + sid + for attribute, value in attributes.items(): + if attribute == 'name': + continue + if callable(value): + value = value() + if value is True: + cookie += '; ' + attribute + else: + cookie += '; ' + attribute + '=' + value + return cookie - def _handle_connect(self, environ, start_response, transport, b64=False, + def _handle_connect(self, environ, start_response, transport, jsonp_index=None): """Handle a client connection request.""" if self.start_service_task: @@ -493,36 +527,53 @@ class Server(object): self.start_service_task = False self.start_background_task(self._service_task) - sid = self._generate_id() + sid = self.generate_id() s = socket.Socket(self, sid) self.sockets[sid] = s - pkt = packet.Packet( - packet.OPEN, {'sid': sid, - 'upgrades': self._upgrades(sid, transport), - 'pingTimeout': int(self.ping_timeout * 1000), - 'pingInterval': int(self.ping_interval * 1000)}) + pkt = packet.Packet(packet.OPEN, { + 'sid': sid, + 'upgrades': self._upgrades(sid, transport), + 'pingTimeout': int(self.ping_timeout * 1000), + 'pingInterval': int( + self.ping_interval + self.ping_interval_grace_period) * 1000}) s.send(pkt) + s.schedule_ping() + # NOTE: some sections below are marked as "no cover" to workaround + # what seems to be a bug in the coverage package. All the lines below + # are covered by tests, but some are not reported as such for some + # reason ret = self._trigger_event('connect', sid, environ, run_async=False) - if ret is False: + if ret is not None and ret is not True: # pragma: no cover del self.sockets[sid] self.logger.warning('Application rejected connection') - return self._unauthorized() + return self._unauthorized(ret or None) - if transport == 'websocket': + if transport == 'websocket': # pragma: no cover ret = s.handle_get_request(environ, start_response) - if s.closed: + if s.closed and sid in self.sockets: # websocket connection ended, so we are done del self.sockets[sid] return ret - else: + else: # pragma: no cover s.connected = True headers = None if self.cookie: - headers = [('Set-Cookie', self.cookie + '=' + sid)] + if isinstance(self.cookie, dict): + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, self.cookie) + )] + else: + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, { + 'name': self.cookie, 'path': '/', 'SameSite': 'Lax' + }) + )] try: - return self._ok(s.poll(), headers=headers, b64=b64, + return self._ok(s.poll(), headers=headers, jsonp_index=jsonp_index) except exceptions.QueueEmpty: return self._bad_request() @@ -561,29 +612,29 @@ class Server(object): raise KeyError('Session is disconnected') return s - def _ok(self, packets=None, headers=None, b64=False, jsonp_index=None): + def _ok(self, packets=None, headers=None, jsonp_index=None): """Generate a successful HTTP response.""" if packets is not None: if headers is None: headers = [] - if b64: - headers += [('Content-Type', 'text/plain; charset=UTF-8')] - else: - headers += [('Content-Type', 'application/octet-stream')] + headers += [('Content-Type', 'text/plain; charset=UTF-8')] return {'status': '200 OK', 'headers': headers, 'response': payload.Payload(packets=packets).encode( - b64=b64, jsonp_index=jsonp_index)} + jsonp_index=jsonp_index).encode('utf-8')} else: return {'status': '200 OK', 'headers': [('Content-Type', 'text/plain')], 'response': b'OK'} - def _bad_request(self): + def _bad_request(self, message=None): """Generate a bad request HTTP error response.""" + if message is None: + message = 'Bad Request' + message = packet.Packet.json.dumps(message) return {'status': '400 BAD REQUEST', 'headers': [('Content-Type', 'text/plain')], - 'response': b'Bad Request'} + 'response': message.encode('utf-8')} def _method_not_found(self): """Generate a method not found HTTP error response.""" @@ -591,11 +642,14 @@ class Server(object): 'headers': [('Content-Type', 'text/plain')], 'response': b'Method Not Found'} - def _unauthorized(self): + def _unauthorized(self, message=None): """Generate a unauthorized HTTP error response.""" + if message is None: + message = 'Unauthorized' + message = packet.Packet.json.dumps(message) return {'status': '401 UNAUTHORIZED', - 'headers': [('Content-Type', 'text/plain')], - 'response': b'Unauthorized'} + 'headers': [('Content-Type', 'application/json')], + 'response': message.encode('utf-8')} def _cors_allowed_origins(self, environ): default_origins = [] @@ -613,7 +667,7 @@ class Server(object): allowed_origins = default_origins elif self.cors_allowed_origins == '*': allowed_origins = None - elif isinstance(self.cors_allowed_origins, six.string_types): + elif isinstance(self.cors_allowed_origins, str): allowed_origins = [self.cors_allowed_origins] else: allowed_origins = self.cors_allowed_origins @@ -641,7 +695,7 @@ class Server(object): def _gzip(self, response): """Apply gzip compression to a response.""" - bytesio = six.BytesIO() + bytesio = io.BytesIO() with gzip.GzipFile(fileobj=bytesio, mode='w') as gz: gz.write(response) return bytesio.getvalue() @@ -650,6 +704,16 @@ class Server(object): """Apply deflate compression to a response.""" return zlib.compress(response) + def _log_error_once(self, message, message_key): + """Log message with logging.ERROR level the first time, then log + with given level.""" + if message_key not in self.log_message_keys: + self.logger.error(message + ' (further occurrences of this error ' + 'will be logged with level INFO)') + self.log_message_keys.add(message_key) + else: + self.logger.info(message) + def _service_task(self): # pragma: no cover """Monitor connected clients and clean up those that time out.""" while True: diff --git a/libs/engineio/socket.py b/libs/engineio/socket.py index 38593e7c7..1434b191d 100644 --- a/libs/engineio/socket.py +++ b/libs/engineio/socket.py @@ -1,4 +1,3 @@ -import six import sys import time @@ -15,11 +14,10 @@ class Socket(object): self.server = server self.sid = sid self.queue = self.server.create_queue() - self.last_ping = time.time() + self.last_ping = None self.connected = False self.upgrading = False self.upgraded = False - self.packet_backlog = [] self.closing = False self.closed = False self.session = {} @@ -28,7 +26,8 @@ class Socket(object): """Wait for packets to send to the client.""" queue_empty = self.server.get_queue_empty_exception() try: - packets = [self.queue.get(timeout=self.server.ping_timeout)] + packets = [self.queue.get( + timeout=self.server.ping_interval + self.server.ping_timeout)] self.queue.task_done() except queue_empty: raise exceptions.QueueEmpty() @@ -36,8 +35,12 @@ class Socket(object): return [] while True: try: - packets.append(self.queue.get(block=False)) + pkt = self.queue.get(block=False) self.queue.task_done() + if pkt is None: + self.queue.put(None) + break + packets.append(pkt) except queue_empty: break return packets @@ -50,9 +53,8 @@ class Socket(object): self.sid, packet_name, pkt.data if not isinstance(pkt.data, bytes) else '<binary>') - if pkt.packet_type == packet.PING: - self.last_ping = time.time() - self.send(packet.Packet(packet.PONG, pkt.data)) + if pkt.packet_type == packet.PONG: + self.schedule_ping() elif pkt.packet_type == packet.MESSAGE: self.server._trigger_event('message', self.sid, pkt.data, run_async=self.server.async_handlers) @@ -64,14 +66,11 @@ class Socket(object): raise exceptions.UnknownPacketError() def check_ping_timeout(self): - """Make sure the client is still sending pings. - - This helps detect disconnections for long-polling clients. - """ + """Make sure the client is still responding to pings.""" if self.closed: raise exceptions.SocketIsClosedError() - if time.time() - self.last_ping > self.server.ping_interval + \ - self.server.ping_interval_grace_period: + if self.last_ping and \ + time.time() - self.last_ping > self.server.ping_timeout: self.server.logger.info('%s: Client is gone, closing socket', self.sid) # Passing abort=False here will cause close() to write a @@ -85,8 +84,6 @@ class Socket(object): """Send a packet to the client.""" if not self.check_ping_timeout(): return - if self.upgrading: - self.packet_backlog.append(pkt) else: self.queue.put(pkt) self.server.logger.info('%s: Sending packet %s data %s', @@ -105,12 +102,16 @@ class Socket(object): self.sid, transport) return getattr(self, '_upgrade_' + transport)(environ, start_response) + if self.upgrading or self.upgraded: + # we are upgrading to WebSocket, do not return any more packets + # through the polling endpoint + return [packet.Packet(packet.NOOP)] try: packets = self.poll() except exceptions.QueueEmpty: exc = sys.exc_info() self.close(wait=False) - six.reraise(*exc) + raise exc[1].with_traceback(exc[2]) return packets def handle_post_request(self, environ): @@ -119,7 +120,7 @@ class Socket(object): if length > self.server.max_http_buffer_size: raise exceptions.ContentTooLongError() else: - body = environ['wsgi.input'].read(length) + body = environ['wsgi.input'].read(length).decode('utf-8') p = payload.Payload(encoded_payload=body) for pkt in p.packets: self.receive(pkt) @@ -136,6 +137,16 @@ class Socket(object): if wait: self.queue.join() + def schedule_ping(self): + def send_ping(): + self.last_ping = None + self.server.sleep(self.server.ping_interval) + if not self.closing and not self.closed: + self.last_ping = time.time() + self.send(packet.Packet(packet.PING)) + + self.server.start_background_task(send_ping) + def _upgrade_websocket(self, environ, start_response): """Upgrade the connection from polling to websocket.""" if self.upgraded: @@ -149,9 +160,11 @@ class Socket(object): def _websocket_handler(self, ws): """Engine.IO handler for websocket transport.""" # try to set a socket timeout matching the configured ping interval + # and timeout for attr in ['_sock', 'socket']: # pragma: no cover if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'): - getattr(ws, attr).settimeout(self.server.ping_timeout) + getattr(ws, attr).settimeout( + self.server.ping_interval + self.server.ping_timeout) if self.connected: # the socket was already connected, so this is an upgrade @@ -163,10 +176,9 @@ class Socket(object): decoded_pkt.data != 'probe': self.server.logger.info( '%s: Failed websocket upgrade, no PING packet', self.sid) + self.upgrading = False return [] - ws.send(packet.Packet( - packet.PONG, - data=six.text_type('probe')).encode(always_bytes=False)) + ws.send(packet.Packet(packet.PONG, data='probe').encode()) self.queue.put(packet.Packet(packet.NOOP)) # end poll pkt = ws.wait() @@ -177,13 +189,9 @@ class Socket(object): ('%s: Failed websocket upgrade, expected UPGRADE packet, ' 'received %s instead.'), self.sid, pkt) + self.upgrading = False return [] self.upgraded = True - - # flush any packets that were sent during the upgrade - for pkt in self.packet_backlog: - self.queue.put(pkt) - self.packet_backlog = [] self.upgrading = False else: self.connected = True @@ -202,7 +210,7 @@ class Socket(object): break try: for pkt in packets: - ws.send(pkt.encode(always_bytes=False)) + ws.send(pkt.encode()) except: break writer_task = self.server.start_background_task(writer) @@ -225,8 +233,6 @@ class Socket(object): if p is None: # connection closed by client break - if isinstance(p, six.text_type): # pragma: no cover - p = p.encode('utf-8') pkt = packet.Packet(encoded_packet=p) try: self.receive(pkt) |