diff options
Diffstat (limited to 'libs/websocket/_app.py')
-rw-r--r-- | libs/websocket/_app.py | 183 |
1 files changed, 130 insertions, 53 deletions
diff --git a/libs/websocket/_app.py b/libs/websocket/_app.py index 178f516a5..81aa1fcd9 100644 --- a/libs/websocket/_app.py +++ b/libs/websocket/_app.py @@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris) """ WebSocketApp provides higher level APIs. """ +import inspect import select import sys import threading @@ -39,6 +40,40 @@ from . import _logging __all__ = ["WebSocketApp"] +class Dispatcher: + def __init__(self, app, ping_timeout): + self.app = app + self.ping_timeout = ping_timeout + + def read(self, sock, read_callback, check_callback): + while self.app.sock.connected: + r, w, e = select.select( + (self.app.sock.sock, ), (), (), self.ping_timeout) + if r: + if not read_callback(): + break + check_callback() + +class SSLDispacther: + def __init__(self, app, ping_timeout): + self.app = app + self.ping_timeout = ping_timeout + + def read(self, sock, read_callback, check_callback): + while self.app.sock.connected: + r = self.select() + if r: + if not read_callback(): + break + check_callback() + + def select(self): + sock = self.app.sock.sock + if sock.pending(): + return [sock,] + + r, w, e = select.select((sock, ), (), (), self.ping_timeout) + return r class WebSocketApp(object): """ @@ -83,8 +118,7 @@ class WebSocketApp(object): The 2nd argument is utf-8 string which we get from the server. The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came. The 4th argument is continue flag. if 0, the data continue - keep_running: a boolean flag indicating whether the app's main loop - should keep running, defaults to True + keep_running: this parameter is obsolete and ignored. get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's docstring for more information subprotocols: array of available sub protocols. default is None. @@ -92,6 +126,7 @@ class WebSocketApp(object): self.url = url self.header = header if header is not None else [] self.cookie = cookie + self.on_open = on_open self.on_message = on_message self.on_data = on_data @@ -100,7 +135,7 @@ class WebSocketApp(object): self.on_ping = on_ping self.on_pong = on_pong self.on_cont_message = on_cont_message - self.keep_running = keep_running + self.keep_running = False self.get_mask_key = get_mask_key self.sock = None self.last_ping_tm = 0 @@ -126,6 +161,7 @@ class WebSocketApp(object): self.keep_running = False if self.sock: self.sock.close(**kwargs) + self.sock = None def _send_ping(self, interval, event): while not event.wait(interval): @@ -142,7 +178,8 @@ class WebSocketApp(object): http_proxy_host=None, http_proxy_port=None, http_no_proxy=None, http_proxy_auth=None, skip_utf8_validation=False, - host=None, origin=None): + host=None, origin=None, dispatcher=None, + suppress_origin = False, proxy_type=None): """ run event loop for WebSocket framework. This loop is infinite loop and is alive during websocket is available. @@ -160,33 +197,64 @@ class WebSocketApp(object): skip_utf8_validation: skip utf8 validation. host: update host header. origin: update origin header. + dispatcher: customize reading data from socket. + suppress_origin: suppress outputting origin header. + + Returns + ------- + False if caught KeyboardInterrupt + True if other exception was raised during a loop """ - if not ping_timeout or ping_timeout <= 0: + if ping_timeout is not None and ping_timeout <= 0: ping_timeout = None if ping_timeout and ping_interval and ping_interval <= ping_timeout: raise WebSocketException("Ensure ping_interval > ping_timeout") - if sockopt is None: + if not sockopt: sockopt = [] - if sslopt is None: + if not sslopt: sslopt = {} if self.sock: raise WebSocketException("socket is already opened") thread = None - close_frame = None + self.keep_running = True + self.last_ping_tm = 0 + self.last_pong_tm = 0 + + def teardown(close_frame=None): + """ + Tears down the connection. + If close_frame is set, we will invoke the on_close handler with the + statusCode and reason from there. + """ + if thread and thread.isAlive(): + event.set() + thread.join() + self.keep_running = False + if self.sock: + self.sock.close() + close_args = self._get_close_args( + close_frame.data if close_frame else None) + self._callback(self.on_close, *close_args) + self.sock = None try: self.sock = WebSocket( self.get_mask_key, sockopt=sockopt, sslopt=sslopt, - fire_cont_frame=self.on_cont_message and True or False, - skip_utf8_validation=skip_utf8_validation) + fire_cont_frame=self.on_cont_message is not None, + skip_utf8_validation=skip_utf8_validation, + enable_multithread=True if ping_interval else False) self.sock.settimeout(getdefaulttimeout()) self.sock.connect( self.url, header=self.header, cookie=self.cookie, http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols, - host=host, origin=origin) + host=host, origin=origin, suppress_origin=suppress_origin, + proxy_type=proxy_type) + if not dispatcher: + dispatcher = self.create_dispatcher(ping_timeout) + self._callback(self.on_open) if ping_interval: @@ -196,58 +264,63 @@ class WebSocketApp(object): thread.setDaemon(True) thread.start() - while self.sock.connected: - r, w, e = select.select( - (self.sock.sock, ), (), (), ping_timeout or 10) # Use a 10 second timeout to avoid to wait forever on close + def read(): if not self.keep_running: - break + return teardown() - if r: - op_code, frame = self.sock.recv_data_frame(True) - if op_code == ABNF.OPCODE_CLOSE: - close_frame = frame - break - elif op_code == ABNF.OPCODE_PING: - self._callback(self.on_ping, frame.data) - elif op_code == ABNF.OPCODE_PONG: - self.last_pong_tm = time.time() - self._callback(self.on_pong, frame.data) - elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: - self._callback(self.on_data, data, - frame.opcode, frame.fin) - self._callback(self.on_cont_message, - frame.data, frame.fin) - else: - data = frame.data - if six.PY3 and op_code == ABNF.OPCODE_TEXT: - data = data.decode("utf-8") - self._callback(self.on_data, data, frame.opcode, True) - self._callback(self.on_message, data) - - if ping_timeout and self.last_ping_tm \ - and time.time() - self.last_ping_tm > ping_timeout \ - and self.last_ping_tm - self.last_pong_tm > ping_timeout: - raise WebSocketTimeoutException("ping/pong timed out") + op_code, frame = self.sock.recv_data_frame(True) + if op_code == ABNF.OPCODE_CLOSE: + return teardown(frame) + elif op_code == ABNF.OPCODE_PING: + self._callback(self.on_ping, frame.data) + elif op_code == ABNF.OPCODE_PONG: + self.last_pong_tm = time.time() + self._callback(self.on_pong, frame.data) + elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: + self._callback(self.on_data, frame.data, + frame.opcode, frame.fin) + self._callback(self.on_cont_message, + frame.data, frame.fin) + else: + data = frame.data + if six.PY3 and op_code == ABNF.OPCODE_TEXT: + data = data.decode("utf-8") + self._callback(self.on_data, data, frame.opcode, True) + self._callback(self.on_message, data) + + return True + + def check(): + if (ping_timeout): + has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout + has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 + has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout + + if (self.last_ping_tm + and has_timeout_expired + and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)): + raise WebSocketTimeoutException("ping/pong timed out") + return True + + dispatcher.read(self.sock.sock, read, check) except (Exception, KeyboardInterrupt, SystemExit) as e: self._callback(self.on_error, e) if isinstance(e, SystemExit): # propagate SystemExit further raise - finally: - if thread and thread.isAlive(): - event.set() - thread.join() - self.keep_running = False - self.sock.close() - close_args = self._get_close_args( - close_frame.data if close_frame else None) - self._callback(self.on_close, *close_args) - self.sock = None + teardown() + return not isinstance(e, KeyboardInterrupt) + + def create_dispatcher(self, ping_timeout): + timeout = ping_timeout or 10 + if self.sock.is_ssl(): + return SSLDispacther(self, timeout) + + return Dispatcher(self, timeout) def _get_close_args(self, data): """ this functions extracts the code, reason from the close body if they exists, and if the self.on_close except three arguments """ - import inspect # if the on_close callback is "old", just return empty list if sys.version_info < (3, 0): if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3: @@ -266,7 +339,11 @@ class WebSocketApp(object): def _callback(self, callback, *args): if callback: try: - callback(self, *args) + if inspect.ismethod(callback): + callback(*args) + else: + callback(self, *args) + except Exception as e: _logging.error("error from callback {}: {}".format(callback, e)) if _logging.isEnabledForDebug(): |