diff options
Diffstat (limited to 'Stomps/SimpleWebSocketServer/SimpleWebSocketServer.py')
-rw-r--r-- | Stomps/SimpleWebSocketServer/SimpleWebSocketServer.py | 735 |
1 files changed, 735 insertions, 0 deletions
diff --git a/Stomps/SimpleWebSocketServer/SimpleWebSocketServer.py b/Stomps/SimpleWebSocketServer/SimpleWebSocketServer.py new file mode 100644 index 0000000..ea528ee --- /dev/null +++ b/Stomps/SimpleWebSocketServer/SimpleWebSocketServer.py @@ -0,0 +1,735 @@ +''' +The MIT License (MIT) +Copyright (c) 2013 Dave P. +''' +import sys +VER = sys.version_info[0] +if VER >= 3: + import socketserver + from http.server import BaseHTTPRequestHandler + from io import StringIO, BytesIO +else: + import SocketServer + from BaseHTTPServer import BaseHTTPRequestHandler + from StringIO import StringIO + +import hashlib +import base64 +import socket +import struct +import ssl +import errno +import codecs +from collections import deque +from select import select + +__all__ = ['WebSocket', + 'SimpleWebSocketServer', + 'SimpleSSLWebSocketServer'] + +def _check_unicode(val): + if VER >= 3: + return isinstance(val, str) + else: + return isinstance(val, basestring) + +class HTTPRequest(BaseHTTPRequestHandler): + def __init__(self, request_text): + if VER >= 3: + self.rfile = BytesIO(request_text) + else: + self.rfile = StringIO(request_text) + self.raw_requestline = self.rfile.readline() + self.error_code = self.error_message = None + self.parse_request() + +_VALID_STATUS_CODES = [1000, 1001, 1002, 1003, 1007, 1008, + 1009, 1010, 1011, 3000, 3999, 4000, 4999] + +HANDSHAKE_STR = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %(acceptstr)s\r\n\r\n" +) + +FAILED_HANDSHAKE_STR = ( + "HTTP/1.1 426 Upgrade Required\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Content-Type: text/plain\r\n\r\n" + "This service requires use of the WebSocket protocol\r\n" +) + +GUID_STR = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +STREAM = 0x0 +TEXT = 0x1 +BINARY = 0x2 +CLOSE = 0x8 +PING = 0x9 +PONG = 0xA + +HEADERB1 = 1 +HEADERB2 = 3 +LENGTHSHORT = 4 +LENGTHLONG = 5 +MASK = 6 +PAYLOAD = 7 + +MAXHEADER = 65536 +MAXPAYLOAD = 33554432 + +class WebSocket(object): + + def __init__(self, server, sock, address): + self.server = server + self.client = sock + self.address = address + + self.handshaked = False + self.headerbuffer = bytearray() + self.headertoread = 2048 + + self.fin = 0 + self.data = bytearray() + self.opcode = 0 + self.hasmask = 0 + self.maskarray = None + self.length = 0 + self.lengtharray = None + self.index = 0 + self.request = None + self.usingssl = False + + self.frag_start = False + self.frag_type = BINARY + self.frag_buffer = None + self.frag_decoder = codecs.getincrementaldecoder('utf-8')(errors='strict') + self.closed = False + self.sendq = deque() + + self.state = HEADERB1 + + # restrict the size of header and payload for security reasons + self.maxheader = MAXHEADER + self.maxpayload = MAXPAYLOAD + + def handleMessage(self): + """ + Called when websocket frame is received. + To access the frame data call self.data. + + If the frame is Text then self.data is a unicode object. + If the frame is Binary then self.data is a bytearray object. + """ + pass + + def handleConnected(self): + """ + Called when a websocket client connects to the server. + """ + pass + + def handleClose(self): + """ + Called when a websocket server gets a Close frame from a client. + """ + pass + + def _handlePacket(self): + if self.opcode == CLOSE: + pass + elif self.opcode == STREAM: + pass + elif self.opcode == TEXT: + pass + elif self.opcode == BINARY: + pass + elif self.opcode == PONG or self.opcode == PING: + if len(self.data) > 125: + raise Exception('control frame length can not be > 125') + else: + # unknown or reserved opcode so just close + raise Exception('unknown opcode') + + if self.opcode == CLOSE: + status = 1000 + reason = u'' + length = len(self.data) + + if length == 0: + pass + elif length >= 2: + status = struct.unpack_from('!H', self.data[:2])[0] + reason = self.data[2:] + + if status not in _VALID_STATUS_CODES: + status = 1002 + + if len(reason) > 0: + try: + reason = reason.decode('utf8', errors='strict') + except: + status = 1002 + else: + status = 1002 + + self.close(status, reason) + return + + elif self.fin == 0: + if self.opcode != STREAM: + if self.opcode == PING or self.opcode == PONG: + raise Exception('control messages can not be fragmented') + + self.frag_type = self.opcode + self.frag_start = True + self.frag_decoder.reset() + + if self.frag_type == TEXT: + self.frag_buffer = [] + utf_str = self.frag_decoder.decode(self.data, final = False) + if utf_str: + self.frag_buffer.append(utf_str) + else: + self.frag_buffer = bytearray() + self.frag_buffer.extend(self.data) + + else: + if self.frag_start is False: + raise Exception('fragmentation protocol error') + + if self.frag_type == TEXT: + utf_str = self.frag_decoder.decode(self.data, final = False) + if utf_str: + self.frag_buffer.append(utf_str) + else: + self.frag_buffer.extend(self.data) + + else: + if self.opcode == STREAM: + if self.frag_start is False: + raise Exception('fragmentation protocol error') + + if self.frag_type == TEXT: + utf_str = self.frag_decoder.decode(self.data, final = True) + self.frag_buffer.append(utf_str) + self.data = u''.join(self.frag_buffer) + else: + self.frag_buffer.extend(self.data) + self.data = self.frag_buffer + + self.handleMessage() + + self.frag_decoder.reset() + self.frag_type = BINARY + self.frag_start = False + self.frag_buffer = None + + elif self.opcode == PING: + self._sendMessage(False, PONG, self.data) + + elif self.opcode == PONG: + pass + + else: + if self.frag_start is True: + raise Exception('fragmentation protocol error') + + if self.opcode == TEXT: + try: + self.data = self.data.decode('utf8', errors='strict') + except Exception as exp: + raise Exception('invalid utf-8 payload') + + self.handleMessage() + + + def _handleData(self): + # do the HTTP header and handshake + if self.handshaked is False: + + try: + data = self.client.recv(self.headertoread) + except (ssl.SSLWantReadError, ssl.SSLWantWriteError): + # SSL socket not ready to read yet, wait and try again + return + if not data: + raise Exception('remote socket closed') + + else: + # accumulate + self.headerbuffer.extend(data) + + if len(self.headerbuffer) >= self.maxheader: + raise Exception('header exceeded allowable size') + + # indicates end of HTTP header + if b'\r\n\r\n' in self.headerbuffer: + self.request = HTTPRequest(self.headerbuffer) + + # handshake rfc 6455 + try: + key = self.request.headers['Sec-WebSocket-Key'] + k = key.encode('ascii') + GUID_STR.encode('ascii') + k_s = base64.b64encode(hashlib.sha1(k).digest()).decode('ascii') + hStr = HANDSHAKE_STR % {'acceptstr': k_s} + self.sendq.append((BINARY, hStr.encode('ascii'))) + self.handshaked = True + self.handleConnected() + except Exception as e: + hStr = FAILED_HANDSHAKE_STR + self._sendBuffer(hStr.encode('ascii'), True) + self.client.close() + raise Exception('handshake failed: %s', str(e)) + + # else do normal data + else: + try: + data = self.client.recv(16384) + except (ssl.SSLWantReadError, ssl.SSLWantWriteError): + # SSL socket not ready to read yet, wait and try again + return + if not data: + raise Exception("remote socket closed") + + if VER >= 3: + for d in data: + self._parseMessage(d) + else: + for d in data: + self._parseMessage(ord(d)) + + def close(self, status = 1000, reason = u''): + """ + Send Close frame to the client. The underlying socket is only closed + when the client acknowledges the Close frame. + + status is the closing identifier. + reason is the reason for the close. + """ + try: + if self.closed is False: + close_msg = bytearray() + close_msg.extend(struct.pack("!H", status)) + if _check_unicode(reason): + close_msg.extend(reason.encode('utf-8')) + else: + close_msg.extend(reason) + + self._sendMessage(False, CLOSE, close_msg) + + finally: + self.closed = True + + + def _sendBuffer(self, buff, send_all = False): + size = len(buff) + tosend = size + already_sent = 0 + + while tosend > 0: + try: + # i should be able to send a bytearray + sent = self.client.send(buff[already_sent:]) + if sent == 0: + raise RuntimeError('socket connection broken') + + already_sent += sent + tosend -= sent + + except (ssl.SSLWantReadError, ssl.SSLWantWriteError): + # SSL socket not ready to send yet, wait and try again + if send_all: + continue + return buff[already_sent:] + + except socket.error as e: + # if we have full buffers then wait for them to drain and try again + if e.errno in [errno.EAGAIN, errno.EWOULDBLOCK]: + if send_all: + continue + return buff[already_sent:] + else: + raise e + + return None + + def sendFragmentStart(self, data): + """ + Send the start of a data fragment stream to a websocket client. + Subsequent data should be sent using sendFragment(). + A fragment stream is completed when sendFragmentEnd() is called. + + If data is a unicode object then the frame is sent as Text. + If the data is a bytearray object then the frame is sent as Binary. + """ + opcode = BINARY + if _check_unicode(data): + opcode = TEXT + self._sendMessage(True, opcode, data) + + def sendFragment(self, data): + """ + see sendFragmentStart() + + If data is a unicode object then the frame is sent as Text. + If the data is a bytearray object then the frame is sent as Binary. + """ + self._sendMessage(True, STREAM, data) + + def sendFragmentEnd(self, data): + """ + see sendFragmentEnd() + + If data is a unicode object then the frame is sent as Text. + If the data is a bytearray object then the frame is sent as Binary. + """ + self._sendMessage(False, STREAM, data) + + def sendMessage(self, data): + """ + Send websocket data frame to the client. + + If data is a unicode object then the frame is sent as Text. + If the data is a bytearray object then the frame is sent as Binary. + """ + opcode = BINARY + if _check_unicode(data): + opcode = TEXT + self._sendMessage(False, opcode, data) + + + def _sendMessage(self, fin, opcode, data): + + payload = bytearray() + + b1 = 0 + b2 = 0 + if fin is False: + b1 |= 0x80 + b1 |= opcode + + if _check_unicode(data): + data = data.encode('utf-8') + + length = len(data) + payload.append(b1) + + if length <= 125: + b2 |= length + payload.append(b2) + + elif length >= 126 and length <= 65535: + b2 |= 126 + payload.append(b2) + payload.extend(struct.pack("!H", length)) + + else: + b2 |= 127 + payload.append(b2) + payload.extend(struct.pack("!Q", length)) + + if length > 0: + payload.extend(data) + + self.sendq.append((opcode, payload)) + + + def _parseMessage(self, byte): + # read in the header + if self.state == HEADERB1: + + self.fin = byte & 0x80 + self.opcode = byte & 0x0F + self.state = HEADERB2 + + self.index = 0 + self.length = 0 + self.lengtharray = bytearray() + self.data = bytearray() + + rsv = byte & 0x70 + if rsv != 0: + raise Exception('RSV bit must be 0') + + elif self.state == HEADERB2: + mask = byte & 0x80 + length = byte & 0x7F + + if self.opcode == PING and length > 125: + raise Exception('ping packet is too large') + + if mask == 128: + self.hasmask = True + else: + self.hasmask = False + + if length <= 125: + self.length = length + + # if we have a mask we must read it + if self.hasmask is True: + self.maskarray = bytearray() + self.state = MASK + else: + # if there is no mask and no payload we are done + if self.length <= 0: + try: + self._handlePacket() + finally: + self.state = HEADERB1 + self.data = bytearray() + + # we have no mask and some payload + else: + #self.index = 0 + self.data = bytearray() + self.state = PAYLOAD + + elif length == 126: + self.lengtharray = bytearray() + self.state = LENGTHSHORT + + elif length == 127: + self.lengtharray = bytearray() + self.state = LENGTHLONG + + + elif self.state == LENGTHSHORT: + self.lengtharray.append(byte) + + if len(self.lengtharray) > 2: + raise Exception('short length exceeded allowable size') + + if len(self.lengtharray) == 2: + self.length = struct.unpack_from('!H', self.lengtharray)[0] + + if self.hasmask is True: + self.maskarray = bytearray() + self.state = MASK + else: + # if there is no mask and no payload we are done + if self.length <= 0: + try: + self._handlePacket() + finally: + self.state = HEADERB1 + self.data = bytearray() + + # we have no mask and some payload + else: + #self.index = 0 + self.data = bytearray() + self.state = PAYLOAD + + elif self.state == LENGTHLONG: + + self.lengtharray.append(byte) + + if len(self.lengtharray) > 8: + raise Exception('long length exceeded allowable size') + + if len(self.lengtharray) == 8: + self.length = struct.unpack_from('!Q', self.lengtharray)[0] + + if self.hasmask is True: + self.maskarray = bytearray() + self.state = MASK + else: + # if there is no mask and no payload we are done + if self.length <= 0: + try: + self._handlePacket() + finally: + self.state = HEADERB1 + self.data = bytearray() + + # we have no mask and some payload + else: + #self.index = 0 + self.data = bytearray() + self.state = PAYLOAD + + # MASK STATE + elif self.state == MASK: + self.maskarray.append(byte) + + if len(self.maskarray) > 4: + raise Exception('mask exceeded allowable size') + + if len(self.maskarray) == 4: + # if there is no mask and no payload we are done + if self.length <= 0: + try: + self._handlePacket() + finally: + self.state = HEADERB1 + self.data = bytearray() + + # we have no mask and some payload + else: + #self.index = 0 + self.data = bytearray() + self.state = PAYLOAD + + # PAYLOAD STATE + elif self.state == PAYLOAD: + if self.hasmask is True: + self.data.append( byte ^ self.maskarray[self.index % 4] ) + else: + self.data.append( byte ) + + # if length exceeds allowable size then we except and remove the connection + if len(self.data) >= self.maxpayload: + raise Exception('payload exceeded allowable size') + + # check if we have processed length bytes; if so we are done + if (self.index+1) == self.length: + try: + self._handlePacket() + finally: + #self.index = 0 + self.state = HEADERB1 + self.data = bytearray() + else: + self.index += 1 + + +class SimpleWebSocketServer(object): + def __init__(self, host, port, websocketclass, selectInterval = 0.1): + self.websocketclass = websocketclass + + if (host == ''): + host = None + + hostInfo = socket.getaddrinfo(host, port, socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_PASSIVE) + self.serversocket = socket.socket(socket.AF_INET, hostInfo[0][1], hostInfo[0][2]) + self.serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.serversocket.bind(hostInfo[0][4]) + self.serversocket.listen(5) + self.selectInterval = selectInterval + self.connections = {} + self.listeners = [self.serversocket] + + def _decorateSocket(self, sock): + return sock + + def _constructWebSocket(self, sock, address): + return self.websocketclass(self, sock, address) + + def close(self): + self.serversocket.close() + + for desc, conn in self.connections.items(): + conn.close() + self._handleClose(conn) + + def _handleClose(self, client): + client.client.close() + # only call handleClose when we have a successful websocket connection + if client.handshaked: + try: + client.handleClose() + except: + pass + + def serveonce(self): + writers = [] + for fileno in self.listeners: + if fileno == self.serversocket: + continue + client = self.connections[fileno] + if client.sendq: + writers.append(fileno) + + rList, wList, xList = select(self.listeners, writers, self.listeners, self.selectInterval) + + for ready in wList: + client = self.connections[ready] + try: + while client.sendq: + opcode, payload = client.sendq.popleft() + remaining = client._sendBuffer(payload) + if remaining is not None: + client.sendq.appendleft((opcode, remaining)) + break + else: + if opcode == CLOSE: + raise Exception('received client close') + + except Exception as n: + self._handleClose(client) + del self.connections[ready] + self.listeners.remove(ready) + + for ready in rList: + if ready == self.serversocket: + sock = None + try: + sock, address = self.serversocket.accept() + newsock = self._decorateSocket(sock) + newsock.setblocking(0) + fileno = newsock.fileno() + self.connections[fileno] = self._constructWebSocket(newsock, address) + self.listeners.append(fileno) + except Exception as n: + if sock is not None: + sock.close() + else: + if ready not in self.connections: + continue + client = self.connections[ready] + try: + client._handleData() + except Exception as n: + self._handleClose(client) + del self.connections[ready] + self.listeners.remove(ready) + + for failed in xList: + if failed == self.serversocket: + self.close() + raise Exception('server socket failed') + else: + if failed not in self.connections: + continue + client = self.connections[failed] + self._handleClose(client) + del self.connections[failed] + self.listeners.remove(failed) + + def serveforever(self): + while True: + self.serveonce() + +class SimpleSSLWebSocketServer(SimpleWebSocketServer): + + def __init__(self, host, port, websocketclass, certfile = None, + keyfile = None, version = ssl.PROTOCOL_TLSv1_2, selectInterval = 0.1, ssl_context = None): + + SimpleWebSocketServer.__init__(self, host, port, + websocketclass, selectInterval) + + if ssl_context is None: + self.context = ssl.SSLContext(version) + self.context.load_cert_chain(certfile, keyfile) + else: + self.context = ssl_context + + def close(self): + super(SimpleSSLWebSocketServer, self).close() + + def _decorateSocket(self, sock): + sslsock = self.context.wrap_socket(sock, server_side=True) + return sslsock + + def _constructWebSocket(self, sock, address): + ws = self.websocketclass(self, sock, address) + ws.usingssl = True + return ws + + def serveforever(self): + super(SimpleSSLWebSocketServer, self).serveforever() |