diff options
author | Louis Vézina <[email protected]> | 2020-04-15 00:02:44 -0400 |
---|---|---|
committer | Louis Vézina <[email protected]> | 2020-04-15 00:02:44 -0400 |
commit | 1b0e721a9d4b88bfbfea823798de92713d50826b (patch) | |
tree | 9139c5ca46fe1391f540a0190170edf08d30a588 /libs/waitress | |
parent | 02551f2486531cfdb83576ced380b72507fb2da0 (diff) | |
download | bazarr-1b0e721a9d4b88bfbfea823798de92713d50826b.tar.gz bazarr-1b0e721a9d4b88bfbfea823798de92713d50826b.zip |
WIP
Diffstat (limited to 'libs/waitress')
45 files changed, 14427 insertions, 0 deletions
diff --git a/libs/waitress/__init__.py b/libs/waitress/__init__.py new file mode 100644 index 000000000..e6e5911a5 --- /dev/null +++ b/libs/waitress/__init__.py @@ -0,0 +1,45 @@ +from waitress.server import create_server +import logging + + +def serve(app, **kw): + _server = kw.pop("_server", create_server) # test shim + _quiet = kw.pop("_quiet", False) # test shim + _profile = kw.pop("_profile", False) # test shim + if not _quiet: # pragma: no cover + # idempotent if logging has already been set up + logging.basicConfig() + server = _server(app, **kw) + if not _quiet: # pragma: no cover + server.print_listen("Serving on http://{}:{}") + if _profile: # pragma: no cover + profile("server.run()", globals(), locals(), (), False) + else: + server.run() + + +def serve_paste(app, global_conf, **kw): + serve(app, **kw) + return 0 + + +def profile(cmd, globals, locals, sort_order, callers): # pragma: no cover + # runs a command under the profiler and print profiling output at shutdown + import os + import profile + import pstats + import tempfile + + fd, fn = tempfile.mkstemp() + try: + profile.runctx(cmd, globals, locals, fn) + stats = pstats.Stats(fn) + stats.strip_dirs() + # calls,time,cumulative and cumulative,calls,time are useful + stats.sort_stats(*(sort_order or ("cumulative", "calls", "time"))) + if callers: + stats.print_callers(0.3) + else: + stats.print_stats(0.3) + finally: + os.remove(fn) diff --git a/libs/waitress/__main__.py b/libs/waitress/__main__.py new file mode 100644 index 000000000..9bcd07e59 --- /dev/null +++ b/libs/waitress/__main__.py @@ -0,0 +1,3 @@ +from waitress.runner import run # pragma nocover + +run() # pragma nocover diff --git a/libs/waitress/adjustments.py b/libs/waitress/adjustments.py new file mode 100644 index 000000000..93439eab8 --- /dev/null +++ b/libs/waitress/adjustments.py @@ -0,0 +1,515 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Adjustments are tunable parameters. +""" +import getopt +import socket +import warnings + +from .proxy_headers import PROXY_HEADERS +from .compat import ( + PY2, + WIN, + string_types, + HAS_IPV6, +) + +truthy = frozenset(("t", "true", "y", "yes", "on", "1")) + +KNOWN_PROXY_HEADERS = frozenset( + header.lower().replace("_", "-") for header in PROXY_HEADERS +) + + +def asbool(s): + """ Return the boolean value ``True`` if the case-lowered value of string + input ``s`` is any of ``t``, ``true``, ``y``, ``on``, or ``1``, otherwise + return the boolean value ``False``. If ``s`` is the value ``None``, + return ``False``. If ``s`` is already one of the boolean values ``True`` + or ``False``, return it.""" + if s is None: + return False + if isinstance(s, bool): + return s + s = str(s).strip() + return s.lower() in truthy + + +def asoctal(s): + """Convert the given octal string to an actual number.""" + return int(s, 8) + + +def aslist_cronly(value): + if isinstance(value, string_types): + value = filter(None, [x.strip() for x in value.splitlines()]) + return list(value) + + +def aslist(value): + """ Return a list of strings, separating the input based on newlines + and, if flatten=True (the default), also split on spaces within + each line.""" + values = aslist_cronly(value) + result = [] + for value in values: + subvalues = value.split() + result.extend(subvalues) + return result + + +def asset(value): + return set(aslist(value)) + + +def slash_fixed_str(s): + s = s.strip() + if s: + # always have a leading slash, replace any number of leading slashes + # with a single slash, and strip any trailing slashes + s = "/" + s.lstrip("/").rstrip("/") + return s + + +def str_iftruthy(s): + return str(s) if s else None + + +def as_socket_list(sockets): + """Checks if the elements in the list are of type socket and + removes them if not.""" + return [sock for sock in sockets if isinstance(sock, socket.socket)] + + +class _str_marker(str): + pass + + +class _int_marker(int): + pass + + +class _bool_marker(object): + pass + + +class Adjustments(object): + """This class contains tunable parameters. + """ + + _params = ( + ("host", str), + ("port", int), + ("ipv4", asbool), + ("ipv6", asbool), + ("listen", aslist), + ("threads", int), + ("trusted_proxy", str_iftruthy), + ("trusted_proxy_count", int), + ("trusted_proxy_headers", asset), + ("log_untrusted_proxy_headers", asbool), + ("clear_untrusted_proxy_headers", asbool), + ("url_scheme", str), + ("url_prefix", slash_fixed_str), + ("backlog", int), + ("recv_bytes", int), + ("send_bytes", int), + ("outbuf_overflow", int), + ("outbuf_high_watermark", int), + ("inbuf_overflow", int), + ("connection_limit", int), + ("cleanup_interval", int), + ("channel_timeout", int), + ("log_socket_errors", asbool), + ("max_request_header_size", int), + ("max_request_body_size", int), + ("expose_tracebacks", asbool), + ("ident", str_iftruthy), + ("asyncore_loop_timeout", int), + ("asyncore_use_poll", asbool), + ("unix_socket", str), + ("unix_socket_perms", asoctal), + ("sockets", as_socket_list), + ) + + _param_map = dict(_params) + + # hostname or IP address to listen on + host = _str_marker("0.0.0.0") + + # TCP port to listen on + port = _int_marker(8080) + + listen = ["{}:{}".format(host, port)] + + # number of threads available for tasks + threads = 4 + + # Host allowed to overrid ``wsgi.url_scheme`` via header + trusted_proxy = None + + # How many proxies we trust when chained + # + # X-Forwarded-For: 192.0.2.1, "[2001:db8::1]" + # + # or + # + # Forwarded: for=192.0.2.1, For="[2001:db8::1]" + # + # means there were (potentially), two proxies involved. If we know there is + # only 1 valid proxy, then that initial IP address "192.0.2.1" is not + # trusted and we completely ignore it. If there are two trusted proxies in + # the path, this value should be set to a higher number. + trusted_proxy_count = None + + # Which of the proxy headers should we trust, this is a set where you + # either specify forwarded or one or more of forwarded-host, forwarded-for, + # forwarded-proto, forwarded-port. + trusted_proxy_headers = set() + + # Would you like waitress to log warnings about untrusted proxy headers + # that were encountered while processing the proxy headers? This only makes + # sense to set when you have a trusted_proxy, and you expect the upstream + # proxy server to filter invalid headers + log_untrusted_proxy_headers = False + + # Should waitress clear any proxy headers that are not deemed trusted from + # the environ? Change to True by default in 2.x + clear_untrusted_proxy_headers = _bool_marker + + # default ``wsgi.url_scheme`` value + url_scheme = "http" + + # default ``SCRIPT_NAME`` value, also helps reset ``PATH_INFO`` + # when nonempty + url_prefix = "" + + # server identity (sent in Server: header) + ident = "waitress" + + # backlog is the value waitress passes to pass to socket.listen() This is + # the maximum number of incoming TCP connections that will wait in an OS + # queue for an available channel. From listen(1): "If a connection + # request arrives when the queue is full, the client may receive an error + # with an indication of ECONNREFUSED or, if the underlying protocol + # supports retransmission, the request may be ignored so that a later + # reattempt at connection succeeds." + backlog = 1024 + + # recv_bytes is the argument to pass to socket.recv(). + recv_bytes = 8192 + + # deprecated setting controls how many bytes will be buffered before + # being flushed to the socket + send_bytes = 1 + + # A tempfile should be created if the pending output is larger than + # outbuf_overflow, which is measured in bytes. The default is 1MB. This + # is conservative. + outbuf_overflow = 1048576 + + # The app_iter will pause when pending output is larger than this value + # in bytes. + outbuf_high_watermark = 16777216 + + # A tempfile should be created if the pending input is larger than + # inbuf_overflow, which is measured in bytes. The default is 512K. This + # is conservative. + inbuf_overflow = 524288 + + # Stop creating new channels if too many are already active (integer). + # Each channel consumes at least one file descriptor, and, depending on + # the input and output body sizes, potentially up to three. The default + # is conservative, but you may need to increase the number of file + # descriptors available to the Waitress process on most platforms in + # order to safely change it (see ``ulimit -a`` "open files" setting). + # Note that this doesn't control the maximum number of TCP connections + # that can be waiting for processing; the ``backlog`` argument controls + # that. + connection_limit = 100 + + # Minimum seconds between cleaning up inactive channels. + cleanup_interval = 30 + + # Maximum seconds to leave an inactive connection open. + channel_timeout = 120 + + # Boolean: turn off to not log premature client disconnects. + log_socket_errors = True + + # maximum number of bytes of all request headers combined (256K default) + max_request_header_size = 262144 + + # maximum number of bytes in request body (1GB default) + max_request_body_size = 1073741824 + + # expose tracebacks of uncaught exceptions + expose_tracebacks = False + + # Path to a Unix domain socket to use. + unix_socket = None + + # Path to a Unix domain socket to use. + unix_socket_perms = 0o600 + + # The socket options to set on receiving a connection. It is a list of + # (level, optname, value) tuples. TCP_NODELAY disables the Nagle + # algorithm for writes (Waitress already buffers its writes). + socket_options = [ + (socket.SOL_TCP, socket.TCP_NODELAY, 1), + ] + + # The asyncore.loop timeout value + asyncore_loop_timeout = 1 + + # The asyncore.loop flag to use poll() instead of the default select(). + asyncore_use_poll = False + + # Enable IPv4 by default + ipv4 = True + + # Enable IPv6 by default + ipv6 = True + + # A list of sockets that waitress will use to accept connections. They can + # be used for e.g. socket activation + sockets = [] + + def __init__(self, **kw): + + if "listen" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if listen is set.") + + if "listen" in kw and "sockets" in kw: + raise ValueError("socket may not be set if listen is set.") + + if "sockets" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if sockets is set.") + + if "sockets" in kw and "unix_socket" in kw: + raise ValueError("unix_socket may not be set if sockets is set") + + if "unix_socket" in kw and ("host" in kw or "port" in kw): + raise ValueError("unix_socket may not be set if host or port is set") + + if "unix_socket" in kw and "listen" in kw: + raise ValueError("unix_socket may not be set if listen is set") + + if "send_bytes" in kw: + warnings.warn( + "send_bytes will be removed in a future release", DeprecationWarning, + ) + + for k, v in kw.items(): + if k not in self._param_map: + raise ValueError("Unknown adjustment %r" % k) + setattr(self, k, self._param_map[k](v)) + + if not isinstance(self.host, _str_marker) or not isinstance( + self.port, _int_marker + ): + self.listen = ["{}:{}".format(self.host, self.port)] + + enabled_families = socket.AF_UNSPEC + + if not self.ipv4 and not HAS_IPV6: # pragma: no cover + raise ValueError( + "IPv4 is disabled but IPv6 is not available. Cowardly refusing to start." + ) + + if self.ipv4 and not self.ipv6: + enabled_families = socket.AF_INET + + if not self.ipv4 and self.ipv6 and HAS_IPV6: + enabled_families = socket.AF_INET6 + + wanted_sockets = [] + hp_pairs = [] + for i in self.listen: + if ":" in i: + (host, port) = i.rsplit(":", 1) + + # IPv6 we need to make sure that we didn't split on the address + if "]" in port: # pragma: nocover + (host, port) = (i, str(self.port)) + else: + (host, port) = (i, str(self.port)) + + if WIN and PY2: # pragma: no cover + try: + # Try turning the port into an integer + port = int(port) + + except Exception: + raise ValueError( + "Windows does not support service names instead of port numbers" + ) + + try: + if "[" in host and "]" in host: # pragma: nocover + host = host.strip("[").rstrip("]") + + if host == "*": + host = None + + for s in socket.getaddrinfo( + host, + port, + enabled_families, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE, + ): + (family, socktype, proto, _, sockaddr) = s + + # It seems that getaddrinfo() may sometimes happily return + # the same result multiple times, this of course makes + # bind() very unhappy... + # + # Split on %, and drop the zone-index from the host in the + # sockaddr. Works around a bug in OS X whereby + # getaddrinfo() returns the same link-local interface with + # two different zone-indices (which makes no sense what so + # ever...) yet treats them equally when we attempt to bind(). + if ( + sockaddr[1] == 0 + or (sockaddr[0].split("%", 1)[0], sockaddr[1]) not in hp_pairs + ): + wanted_sockets.append((family, socktype, proto, sockaddr)) + hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) + + except Exception: + raise ValueError("Invalid host/port specified.") + + if self.trusted_proxy_count is not None and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_count has no meaning without setting " "trusted_proxy" + ) + + elif self.trusted_proxy_count is None: + self.trusted_proxy_count = 1 + + if self.trusted_proxy_headers and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_headers has no meaning without setting " "trusted_proxy" + ) + + if self.trusted_proxy_headers: + self.trusted_proxy_headers = { + header.lower() for header in self.trusted_proxy_headers + } + + unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS + if unknown_values: + raise ValueError( + "Received unknown trusted_proxy_headers value (%s) expected one " + "of %s" + % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) + ) + + if ( + "forwarded" in self.trusted_proxy_headers + and self.trusted_proxy_headers - {"forwarded"} + ): + raise ValueError( + "The Forwarded proxy header and the " + "X-Forwarded-{By,Host,Proto,Port,For} headers are mutually " + "exclusive. Can't trust both!" + ) + + elif self.trusted_proxy is not None: + warnings.warn( + "No proxy headers were marked as trusted, but trusted_proxy was set. " + "Implicitly trusting X-Forwarded-Proto for backwards compatibility. " + "This will be removed in future versions of waitress.", + DeprecationWarning, + ) + self.trusted_proxy_headers = {"x-forwarded-proto"} + + if self.clear_untrusted_proxy_headers is _bool_marker: + warnings.warn( + "In future versions of Waitress clear_untrusted_proxy_headers will be " + "set to True by default. You may opt-out by setting this value to " + "False, or opt-in explicitly by setting this to True.", + DeprecationWarning, + ) + self.clear_untrusted_proxy_headers = False + + self.listen = wanted_sockets + + self.check_sockets(self.sockets) + + @classmethod + def parse_args(cls, argv): + """Pre-parse command line arguments for input into __init__. Note that + this does not cast values into adjustment types, it just creates a + dictionary suitable for passing into __init__, where __init__ does the + casting. + """ + long_opts = ["help", "call"] + for opt, cast in cls._params: + opt = opt.replace("_", "-") + if cast is asbool: + long_opts.append(opt) + long_opts.append("no-" + opt) + else: + long_opts.append(opt + "=") + + kw = { + "help": False, + "call": False, + } + + opts, args = getopt.getopt(argv, "", long_opts) + for opt, value in opts: + param = opt.lstrip("-").replace("-", "_") + + if param == "listen": + kw["listen"] = "{} {}".format(kw.get("listen", ""), value) + continue + + if param.startswith("no_"): + param = param[3:] + kw[param] = "false" + elif param in ("help", "call"): + kw[param] = True + elif cls._param_map[param] is asbool: + kw[param] = "true" + else: + kw[param] = value + + return kw, args + + @classmethod + def check_sockets(cls, sockets): + has_unix_socket = False + has_inet_socket = False + has_unsupported_socket = False + for sock in sockets: + if ( + sock.family == socket.AF_INET or sock.family == socket.AF_INET6 + ) and sock.type == socket.SOCK_STREAM: + has_inet_socket = True + elif ( + hasattr(socket, "AF_UNIX") + and sock.family == socket.AF_UNIX + and sock.type == socket.SOCK_STREAM + ): + has_unix_socket = True + else: + has_unsupported_socket = True + if has_unix_socket and has_inet_socket: + raise ValueError("Internet and UNIX sockets may not be mixed.") + if has_unsupported_socket: + raise ValueError("Only Internet or UNIX stream sockets may be used.") diff --git a/libs/waitress/buffers.py b/libs/waitress/buffers.py new file mode 100644 index 000000000..04f6b4274 --- /dev/null +++ b/libs/waitress/buffers.py @@ -0,0 +1,308 @@ +############################################################################## +# +# Copyright (c) 2001-2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Buffers +""" +from io import BytesIO + +# copy_bytes controls the size of temp. strings for shuffling data around. +COPY_BYTES = 1 << 18 # 256K + +# The maximum number of bytes to buffer in a simple string. +STRBUF_LIMIT = 8192 + + +class FileBasedBuffer(object): + + remain = 0 + + def __init__(self, file, from_buffer=None): + self.file = file + if from_buffer is not None: + from_file = from_buffer.getfile() + read_pos = from_file.tell() + from_file.seek(0) + while True: + data = from_file.read(COPY_BYTES) + if not data: + break + file.write(data) + self.remain = int(file.tell() - read_pos) + from_file.seek(read_pos) + file.seek(read_pos) + + def __len__(self): + return self.remain + + def __nonzero__(self): + return True + + __bool__ = __nonzero__ # py3 + + def append(self, s): + file = self.file + read_pos = file.tell() + file.seek(0, 2) + file.write(s) + file.seek(read_pos) + self.remain = self.remain + len(s) + + def get(self, numbytes=-1, skip=False): + file = self.file + if not skip: + read_pos = file.tell() + if numbytes < 0: + # Read all + res = file.read() + else: + res = file.read(numbytes) + if skip: + self.remain -= len(res) + else: + file.seek(read_pos) + return res + + def skip(self, numbytes, allow_prune=0): + if self.remain < numbytes: + raise ValueError( + "Can't skip %d bytes in buffer of %d bytes" % (numbytes, self.remain) + ) + self.file.seek(numbytes, 1) + self.remain = self.remain - numbytes + + def newfile(self): + raise NotImplementedError() + + def prune(self): + file = self.file + if self.remain == 0: + read_pos = file.tell() + file.seek(0, 2) + sz = file.tell() + file.seek(read_pos) + if sz == 0: + # Nothing to prune. + return + nf = self.newfile() + while True: + data = file.read(COPY_BYTES) + if not data: + break + nf.write(data) + self.file = nf + + def getfile(self): + return self.file + + def close(self): + if hasattr(self.file, "close"): + self.file.close() + self.remain = 0 + + +class TempfileBasedBuffer(FileBasedBuffer): + def __init__(self, from_buffer=None): + FileBasedBuffer.__init__(self, self.newfile(), from_buffer) + + def newfile(self): + from tempfile import TemporaryFile + + return TemporaryFile("w+b") + + +class BytesIOBasedBuffer(FileBasedBuffer): + def __init__(self, from_buffer=None): + if from_buffer is not None: + FileBasedBuffer.__init__(self, BytesIO(), from_buffer) + else: + # Shortcut. :-) + self.file = BytesIO() + + def newfile(self): + return BytesIO() + + +def _is_seekable(fp): + if hasattr(fp, "seekable"): + return fp.seekable() + return hasattr(fp, "seek") and hasattr(fp, "tell") + + +class ReadOnlyFileBasedBuffer(FileBasedBuffer): + # used as wsgi.file_wrapper + + def __init__(self, file, block_size=32768): + self.file = file + self.block_size = block_size # for __iter__ + + def prepare(self, size=None): + if _is_seekable(self.file): + start_pos = self.file.tell() + self.file.seek(0, 2) + end_pos = self.file.tell() + self.file.seek(start_pos) + fsize = end_pos - start_pos + if size is None: + self.remain = fsize + else: + self.remain = min(fsize, size) + return self.remain + + def get(self, numbytes=-1, skip=False): + # never read more than self.remain (it can be user-specified) + if numbytes == -1 or numbytes > self.remain: + numbytes = self.remain + file = self.file + if not skip: + read_pos = file.tell() + res = file.read(numbytes) + if skip: + self.remain -= len(res) + else: + file.seek(read_pos) + return res + + def __iter__(self): # called by task if self.filelike has no seek/tell + return self + + def next(self): + val = self.file.read(self.block_size) + if not val: + raise StopIteration + return val + + __next__ = next # py3 + + def append(self, s): + raise NotImplementedError + + +class OverflowableBuffer(object): + """ + This buffer implementation has four stages: + - No data + - Bytes-based buffer + - BytesIO-based buffer + - Temporary file storage + The first two stages are fastest for simple transfers. + """ + + overflowed = False + buf = None + strbuf = b"" # Bytes-based buffer. + + def __init__(self, overflow): + # overflow is the maximum to be stored in a StringIO buffer. + self.overflow = overflow + + def __len__(self): + buf = self.buf + if buf is not None: + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + return buf.__len__() + else: + return self.strbuf.__len__() + + def __nonzero__(self): + # use self.__len__ rather than len(self) FBO of not getting + # OverflowError on Python 2 + return self.__len__() > 0 + + __bool__ = __nonzero__ # py3 + + def _create_buffer(self): + strbuf = self.strbuf + if len(strbuf) >= self.overflow: + self._set_large_buffer() + else: + self._set_small_buffer() + buf = self.buf + if strbuf: + buf.append(self.strbuf) + self.strbuf = b"" + return buf + + def _set_small_buffer(self): + self.buf = BytesIOBasedBuffer(self.buf) + self.overflowed = False + + def _set_large_buffer(self): + self.buf = TempfileBasedBuffer(self.buf) + self.overflowed = True + + def append(self, s): + buf = self.buf + if buf is None: + strbuf = self.strbuf + if len(strbuf) + len(s) < STRBUF_LIMIT: + self.strbuf = strbuf + s + return + buf = self._create_buffer() + buf.append(s) + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + sz = buf.__len__() + if not self.overflowed: + if sz >= self.overflow: + self._set_large_buffer() + + def get(self, numbytes=-1, skip=False): + buf = self.buf + if buf is None: + strbuf = self.strbuf + if not skip: + return strbuf + buf = self._create_buffer() + return buf.get(numbytes, skip) + + def skip(self, numbytes, allow_prune=False): + buf = self.buf + if buf is None: + if allow_prune and numbytes == len(self.strbuf): + # We could slice instead of converting to + # a buffer, but that would eat up memory in + # large transfers. + self.strbuf = b"" + return + buf = self._create_buffer() + buf.skip(numbytes, allow_prune) + + def prune(self): + """ + A potentially expensive operation that removes all data + already retrieved from the buffer. + """ + buf = self.buf + if buf is None: + self.strbuf = b"" + return + buf.prune() + if self.overflowed: + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + sz = buf.__len__() + if sz < self.overflow: + # Revert to a faster buffer. + self._set_small_buffer() + + def getfile(self): + buf = self.buf + if buf is None: + buf = self._create_buffer() + return buf.getfile() + + def close(self): + buf = self.buf + if buf is not None: + buf.close() diff --git a/libs/waitress/channel.py b/libs/waitress/channel.py new file mode 100644 index 000000000..a8bc76f74 --- /dev/null +++ b/libs/waitress/channel.py @@ -0,0 +1,414 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +import socket +import threading +import time +import traceback + +from waitress.buffers import ( + OverflowableBuffer, + ReadOnlyFileBasedBuffer, +) + +from waitress.parser import HTTPRequestParser + +from waitress.task import ( + ErrorTask, + WSGITask, +) + +from waitress.utilities import InternalServerError + +from . import wasyncore + + +class ClientDisconnected(Exception): + """ Raised when attempting to write to a closed socket.""" + + +class HTTPChannel(wasyncore.dispatcher, object): + """ + Setting self.requests = [somerequest] prevents more requests from being + received until the out buffers have been flushed. + + Setting self.requests = [] allows more requests to be received. + """ + + task_class = WSGITask + error_task_class = ErrorTask + parser_class = HTTPRequestParser + + request = None # A request parser instance + last_activity = 0 # Time of last activity + will_close = False # set to True to close the socket. + close_when_flushed = False # set to True to close the socket when flushed + requests = () # currently pending requests + sent_continue = False # used as a latch after sending 100 continue + total_outbufs_len = 0 # total bytes ready to send + current_outbuf_count = 0 # total bytes written to current outbuf + + # + # ASYNCHRONOUS METHODS (including __init__) + # + + def __init__( + self, server, sock, addr, adj, map=None, + ): + self.server = server + self.adj = adj + self.outbufs = [OverflowableBuffer(adj.outbuf_overflow)] + self.creation_time = self.last_activity = time.time() + self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) + + # task_lock used to push/pop requests + self.task_lock = threading.Lock() + # outbuf_lock used to access any outbuf (expected to use an RLock) + self.outbuf_lock = threading.Condition() + + wasyncore.dispatcher.__init__(self, sock, map=map) + + # Don't let wasyncore.dispatcher throttle self.addr on us. + self.addr = addr + + def writable(self): + # if there's data in the out buffer or we've been instructed to close + # the channel (possibly by our server maintenance logic), run + # handle_write + return self.total_outbufs_len or self.will_close or self.close_when_flushed + + def handle_write(self): + # Precondition: there's data in the out buffer to be sent, or + # there's a pending will_close request + if not self.connected: + # we dont want to close the channel twice + return + + # try to flush any pending output + if not self.requests: + # 1. There are no running tasks, so we don't need to try to lock + # the outbuf before sending + # 2. The data in the out buffer should be sent as soon as possible + # because it's either data left over from task output + # or a 100 Continue line sent within "received". + flush = self._flush_some + elif self.total_outbufs_len >= self.adj.send_bytes: + # 1. There's a running task, so we need to try to lock + # the outbuf before sending + # 2. Only try to send if the data in the out buffer is larger + # than self.adj_bytes to avoid TCP fragmentation + flush = self._flush_some_if_lockable + else: + # 1. There's not enough data in the out buffer to bother to send + # right now. + flush = None + + if flush: + try: + flush() + except socket.error: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.will_close = True + except Exception: + self.logger.exception("Unexpected exception when flushing") + self.will_close = True + + if self.close_when_flushed and not self.total_outbufs_len: + self.close_when_flushed = False + self.will_close = True + + if self.will_close: + self.handle_close() + + def readable(self): + # We might want to create a new task. We can only do this if: + # 1. We're not already about to close the connection. + # 2. There's no already currently running task(s). + # 3. There's no data in the output buffer that needs to be sent + # before we potentially create a new task. + return not (self.will_close or self.requests or self.total_outbufs_len) + + def handle_read(self): + try: + data = self.recv(self.adj.recv_bytes) + except socket.error: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.handle_close() + return + if data: + self.last_activity = time.time() + self.received(data) + + def received(self, data): + """ + Receives input asynchronously and assigns one or more requests to the + channel. + """ + # Preconditions: there's no task(s) already running + request = self.request + requests = [] + + if not data: + return False + + while data: + if request is None: + request = self.parser_class(self.adj) + n = request.received(data) + if request.expect_continue and request.headers_finished: + # guaranteed by parser to be a 1.1 request + request.expect_continue = False + if not self.sent_continue: + # there's no current task, so we don't need to try to + # lock the outbuf to append to it. + outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n" + self.outbufs[-1].append(outbuf_payload) + self.current_outbuf_count += len(outbuf_payload) + self.total_outbufs_len += len(outbuf_payload) + self.sent_continue = True + self._flush_some() + request.completed = False + if request.completed: + # The request (with the body) is ready to use. + self.request = None + if not request.empty: + requests.append(request) + request = None + else: + self.request = request + if n >= len(data): + break + data = data[n:] + + if requests: + self.requests = requests + self.server.add_task(self) + + return True + + def _flush_some_if_lockable(self): + # Since our task may be appending to the outbuf, we try to acquire + # the lock, but we don't block if we can't. + if self.outbuf_lock.acquire(False): + try: + self._flush_some() + + if self.total_outbufs_len < self.adj.outbuf_high_watermark: + self.outbuf_lock.notify() + finally: + self.outbuf_lock.release() + + def _flush_some(self): + # Send as much data as possible to our client + + sent = 0 + dobreak = False + + while True: + outbuf = self.outbufs[0] + # use outbuf.__len__ rather than len(outbuf) FBO of not getting + # OverflowError on 32-bit Python + outbuflen = outbuf.__len__() + while outbuflen > 0: + chunk = outbuf.get(self.sendbuf_len) + num_sent = self.send(chunk) + if num_sent: + outbuf.skip(num_sent, True) + outbuflen -= num_sent + sent += num_sent + self.total_outbufs_len -= num_sent + else: + # failed to write anything, break out entirely + dobreak = True + break + else: + # self.outbufs[-1] must always be a writable outbuf + if len(self.outbufs) > 1: + toclose = self.outbufs.pop(0) + try: + toclose.close() + except Exception: + self.logger.exception("Unexpected error when closing an outbuf") + else: + # caught up, done flushing for now + dobreak = True + + if dobreak: + break + + if sent: + self.last_activity = time.time() + return True + + return False + + def handle_close(self): + with self.outbuf_lock: + for outbuf in self.outbufs: + try: + outbuf.close() + except Exception: + self.logger.exception( + "Unknown exception while trying to close outbuf" + ) + self.total_outbufs_len = 0 + self.connected = False + self.outbuf_lock.notify() + wasyncore.dispatcher.close(self) + + def add_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of opened channels. + """ + wasyncore.dispatcher.add_channel(self, map) + self.server.active_channels[self._fileno] = self + + def del_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of closed channels. + """ + fd = self._fileno # next line sets this to None + wasyncore.dispatcher.del_channel(self, map) + ac = self.server.active_channels + if fd in ac: + del ac[fd] + + # + # SYNCHRONOUS METHODS + # + + def write_soon(self, data): + if not self.connected: + # if the socket is closed then interrupt the task so that it + # can cleanup possibly before the app_iter is exhausted + raise ClientDisconnected + if data: + # the async mainloop might be popping data off outbuf; we can + # block here waiting for it because we're in a task thread + with self.outbuf_lock: + self._flush_outbufs_below_high_watermark() + if not self.connected: + raise ClientDisconnected + num_bytes = len(data) + if data.__class__ is ReadOnlyFileBasedBuffer: + # they used wsgi.file_wrapper + self.outbufs.append(data) + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + else: + if self.current_outbuf_count > self.adj.outbuf_high_watermark: + # rotate to a new buffer if the current buffer has hit + # the watermark to avoid it growing unbounded + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + self.outbufs[-1].append(data) + self.current_outbuf_count += num_bytes + self.total_outbufs_len += num_bytes + if self.total_outbufs_len >= self.adj.send_bytes: + self.server.pull_trigger() + return num_bytes + return 0 + + def _flush_outbufs_below_high_watermark(self): + # check first to avoid locking if possible + if self.total_outbufs_len > self.adj.outbuf_high_watermark: + with self.outbuf_lock: + while ( + self.connected + and self.total_outbufs_len > self.adj.outbuf_high_watermark + ): + self.server.pull_trigger() + self.outbuf_lock.wait() + + def service(self): + """Execute all pending requests """ + with self.task_lock: + while self.requests: + request = self.requests[0] + if request.error: + task = self.error_task_class(self, request) + else: + task = self.task_class(self, request) + try: + task.service() + except ClientDisconnected: + self.logger.info( + "Client disconnected while serving %s" % task.request.path + ) + task.close_on_finish = True + except Exception: + self.logger.exception( + "Exception while serving %s" % task.request.path + ) + if not task.wrote_header: + if self.adj.expose_tracebacks: + body = traceback.format_exc() + else: + body = ( + "The server encountered an unexpected " + "internal server error" + ) + req_version = request.version + req_headers = request.headers + request = self.parser_class(self.adj) + request.error = InternalServerError(body) + # copy some original request attributes to fulfill + # HTTP 1.1 requirements + request.version = req_version + try: + request.headers["CONNECTION"] = req_headers["CONNECTION"] + except KeyError: + pass + task = self.error_task_class(self, request) + try: + task.service() # must not fail + except ClientDisconnected: + task.close_on_finish = True + else: + task.close_on_finish = True + # we cannot allow self.requests to drop to empty til + # here; otherwise the mainloop gets confused + if task.close_on_finish: + self.close_when_flushed = True + for request in self.requests: + request.close() + self.requests = [] + else: + # before processing a new request, ensure there is not too + # much data in the outbufs waiting to be flushed + # NB: currently readable() returns False while we are + # flushing data so we know no new requests will come in + # that we need to account for, otherwise it'd be better + # to do this check at the start of the request instead of + # at the end to account for consecutive service() calls + if len(self.requests) > 1: + self._flush_outbufs_below_high_watermark() + request = self.requests.pop(0) + request.close() + + if self.connected: + self.server.pull_trigger() + self.last_activity = time.time() + + def cancel(self): + """ Cancels all pending / active requests """ + self.will_close = True + self.connected = False + self.last_activity = time.time() + self.requests = [] diff --git a/libs/waitress/compat.py b/libs/waitress/compat.py new file mode 100644 index 000000000..fe72a7610 --- /dev/null +++ b/libs/waitress/compat.py @@ -0,0 +1,179 @@ +import os +import sys +import types +import platform +import warnings + +try: + import urlparse +except ImportError: # pragma: no cover + from urllib import parse as urlparse + +try: + import fcntl +except ImportError: # pragma: no cover + fcntl = None # windows + +# True if we are running on Python 3. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 + +# True if we are running on Windows +WIN = platform.system() == "Windows" + +if PY3: # pragma: no cover + string_types = (str,) + integer_types = (int,) + class_types = (type,) + text_type = str + binary_type = bytes + long = int +else: + string_types = (basestring,) + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + long = long + +if PY3: # pragma: no cover + from urllib.parse import unquote_to_bytes + + def unquote_bytes_to_wsgi(bytestring): + return unquote_to_bytes(bytestring).decode("latin-1") + + +else: + from urlparse import unquote as unquote_to_bytes + + def unquote_bytes_to_wsgi(bytestring): + return unquote_to_bytes(bytestring) + + +def text_(s, encoding="latin-1", errors="strict"): + """ If ``s`` is an instance of ``binary_type``, return + ``s.decode(encoding, errors)``, otherwise return ``s``""" + if isinstance(s, binary_type): + return s.decode(encoding, errors) + return s # pragma: no cover + + +if PY3: # pragma: no cover + + def tostr(s): + if isinstance(s, text_type): + s = s.encode("latin-1") + return str(s, "latin-1", "strict") + + def tobytes(s): + return bytes(s, "latin-1") + + +else: + tostr = str + + def tobytes(s): + return s + + +if PY3: # pragma: no cover + import builtins + + exec_ = getattr(builtins, "exec") + + def reraise(tp, value, tb=None): + if value is None: + value = tp + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + del builtins + +else: # pragma: no cover + + def exec_(code, globs=None, locs=None): + """Execute code in a namespace.""" + if globs is None: + frame = sys._getframe(1) + globs = frame.f_globals + if locs is None: + locs = frame.f_locals + del frame + elif locs is None: + locs = globs + exec("""exec code in globs, locs""") + + exec_( + """def reraise(tp, value, tb=None): + raise tp, value, tb +""" + ) + +try: + from StringIO import StringIO as NativeIO +except ImportError: # pragma: no cover + from io import StringIO as NativeIO + +try: + import httplib +except ImportError: # pragma: no cover + from http import client as httplib + +try: + MAXINT = sys.maxint +except AttributeError: # pragma: no cover + MAXINT = sys.maxsize + + +# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, +# Python on Windows may not define IPPROTO_IPV6 in socket. +import socket + +HAS_IPV6 = socket.has_ipv6 + +if hasattr(socket, "IPPROTO_IPV6") and hasattr(socket, "IPV6_V6ONLY"): + IPPROTO_IPV6 = socket.IPPROTO_IPV6 + IPV6_V6ONLY = socket.IPV6_V6ONLY +else: # pragma: no cover + if WIN: + IPPROTO_IPV6 = 41 + IPV6_V6ONLY = 27 + else: + warnings.warn( + "OS does not support required IPv6 socket flags. This is requirement " + "for Waitress. Please open an issue at https://github.com/Pylons/waitress. " + "IPv6 support has been disabled.", + RuntimeWarning, + ) + HAS_IPV6 = False + + +def set_nonblocking(fd): # pragma: no cover + if PY3 and sys.version_info[1] >= 5: + os.set_blocking(fd, False) + elif fcntl is None: + raise RuntimeError("no fcntl module present") + else: + flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +if PY3: + ResourceWarning = ResourceWarning +else: + ResourceWarning = UserWarning + + +def qualname(cls): + if PY3: + return cls.__qualname__ + return cls.__name__ + + +try: + import thread +except ImportError: + # py3 + import _thread as thread diff --git a/libs/waitress/parser.py b/libs/waitress/parser.py new file mode 100644 index 000000000..fef8a3da6 --- /dev/null +++ b/libs/waitress/parser.py @@ -0,0 +1,413 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""HTTP Request Parser + +This server uses asyncore to accept connections and do initial +processing but threads to do work. +""" +import re +from io import BytesIO + +from waitress.buffers import OverflowableBuffer +from waitress.compat import tostr, unquote_bytes_to_wsgi, urlparse +from waitress.receiver import ChunkedReceiver, FixedStreamReceiver +from waitress.utilities import ( + BadRequest, + RequestEntityTooLarge, + RequestHeaderFieldsTooLarge, + ServerNotImplemented, + find_double_newline, +) +from .rfc7230 import HEADER_FIELD + + +class ParsingError(Exception): + pass + + +class TransferEncodingNotImplemented(Exception): + pass + +class HTTPRequestParser(object): + """A structure that collects the HTTP request. + + Once the stream is completed, the instance is passed to + a server task constructor. + """ + + completed = False # Set once request is completed. + empty = False # Set if no request was made. + expect_continue = False # client sent "Expect: 100-continue" header + headers_finished = False # True when headers have been read + header_plus = b"" + chunked = False + content_length = 0 + header_bytes_received = 0 + body_bytes_received = 0 + body_rcv = None + version = "1.0" + error = None + connection_close = False + + # Other attributes: first_line, header, headers, command, uri, version, + # path, query, fragment + + def __init__(self, adj): + """ + adj is an Adjustments object. + """ + # headers is a mapping containing keys translated to uppercase + # with dashes turned into underscores. + self.headers = {} + self.adj = adj + + def received(self, data): + """ + Receives the HTTP stream for one request. Returns the number of + bytes consumed. Sets the completed flag once both the header and the + body have been received. + """ + if self.completed: + return 0 # Can't consume any more. + + datalen = len(data) + br = self.body_rcv + if br is None: + # In header. + max_header = self.adj.max_request_header_size + + s = self.header_plus + data + index = find_double_newline(s) + consumed = 0 + + if index >= 0: + # If the headers have ended, and we also have part of the body + # message in data we still want to validate we aren't going + # over our limit for received headers. + self.header_bytes_received += index + consumed = datalen - (len(s) - index) + else: + self.header_bytes_received += datalen + consumed = datalen + + # If the first line + headers is over the max length, we return a + # RequestHeaderFieldsTooLarge error rather than continuing to + # attempt to parse the headers. + if self.header_bytes_received >= max_header: + self.parse_header(b"GET / HTTP/1.0\r\n") + self.error = RequestHeaderFieldsTooLarge( + "exceeds max_header of %s" % max_header + ) + self.completed = True + return consumed + + if index >= 0: + # Header finished. + header_plus = s[:index] + + # Remove preceeding blank lines. This is suggested by + # https://tools.ietf.org/html/rfc7230#section-3.5 to support + # clients sending an extra CR LF after another request when + # using HTTP pipelining + header_plus = header_plus.lstrip() + + if not header_plus: + self.empty = True + self.completed = True + else: + try: + self.parse_header(header_plus) + except ParsingError as e: + self.error = BadRequest(e.args[0]) + self.completed = True + except TransferEncodingNotImplemented as e: + self.error = ServerNotImplemented(e.args[0]) + self.completed = True + else: + if self.body_rcv is None: + # no content-length header and not a t-e: chunked + # request + self.completed = True + + if self.content_length > 0: + max_body = self.adj.max_request_body_size + # we won't accept this request if the content-length + # is too large + + if self.content_length >= max_body: + self.error = RequestEntityTooLarge( + "exceeds max_body of %s" % max_body + ) + self.completed = True + self.headers_finished = True + + return consumed + + # Header not finished yet. + self.header_plus = s + + return datalen + else: + # In body. + consumed = br.received(data) + self.body_bytes_received += consumed + max_body = self.adj.max_request_body_size + + if self.body_bytes_received >= max_body: + # this will only be raised during t-e: chunked requests + self.error = RequestEntityTooLarge("exceeds max_body of %s" % max_body) + self.completed = True + elif br.error: + # garbage in chunked encoding input probably + self.error = br.error + self.completed = True + elif br.completed: + # The request (with the body) is ready to use. + self.completed = True + + if self.chunked: + # We've converted the chunked transfer encoding request + # body into a normal request body, so we know its content + # length; set the header here. We already popped the + # TRANSFER_ENCODING header in parse_header, so this will + # appear to the client to be an entirely non-chunked HTTP + # request with a valid content-length. + self.headers["CONTENT_LENGTH"] = str(br.__len__()) + + return consumed + + def parse_header(self, header_plus): + """ + Parses the header_plus block of text (the headers plus the + first line of the request). + """ + index = header_plus.find(b"\r\n") + if index >= 0: + first_line = header_plus[:index].rstrip() + header = header_plus[index + 2 :] + else: + raise ParsingError("HTTP message header invalid") + + if b"\r" in first_line or b"\n" in first_line: + raise ParsingError("Bare CR or LF found in HTTP message") + + self.first_line = first_line # for testing + + lines = get_header_lines(header) + + headers = self.headers + for line in lines: + header = HEADER_FIELD.match(line) + + if not header: + raise ParsingError("Invalid header") + + key, value = header.group("name", "value") + + if b"_" in key: + # TODO(xistence): Should we drop this request instead? + continue + + # Only strip off whitespace that is considered valid whitespace by + # RFC7230, don't strip the rest + value = value.strip(b" \t") + key1 = tostr(key.upper().replace(b"-", b"_")) + # If a header already exists, we append subsequent values + # seperated by a comma. Applications already need to handle + # the comma seperated values, as HTTP front ends might do + # the concatenation for you (behavior specified in RFC2616). + try: + headers[key1] += tostr(b", " + value) + except KeyError: + headers[key1] = tostr(value) + + # command, uri, version will be bytes + command, uri, version = crack_first_line(first_line) + version = tostr(version) + command = tostr(command) + self.command = command + self.version = version + ( + self.proxy_scheme, + self.proxy_netloc, + self.path, + self.query, + self.fragment, + ) = split_uri(uri) + self.url_scheme = self.adj.url_scheme + connection = headers.get("CONNECTION", "") + + if version == "1.0": + if connection.lower() != "keep-alive": + self.connection_close = True + + if version == "1.1": + # since the server buffers data from chunked transfers and clients + # never need to deal with chunked requests, downstream clients + # should not see the HTTP_TRANSFER_ENCODING header; we pop it + # here + te = headers.pop("TRANSFER_ENCODING", "") + + # NB: We can not just call bare strip() here because it will also + # remove other non-printable characters that we explicitly do not + # want removed so that if someone attempts to smuggle a request + # with these characters we don't fall prey to it. + # + # For example \x85 is stripped by default, but it is not considered + # valid whitespace to be stripped by RFC7230. + encodings = [ + encoding.strip(" \t").lower() for encoding in te.split(",") if encoding + ] + + for encoding in encodings: + # Out of the transfer-codings listed in + # https://tools.ietf.org/html/rfc7230#section-4 we only support + # chunked at this time. + + # Note: the identity transfer-coding was removed in RFC7230: + # https://tools.ietf.org/html/rfc7230#appendix-A.2 and is thus + # not supported + if encoding not in {"chunked"}: + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." + ) + + if encodings and encodings[-1] == "chunked": + self.chunked = True + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = ChunkedReceiver(buf) + elif encodings: # pragma: nocover + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." + ) + + expect = headers.get("EXPECT", "").lower() + self.expect_continue = expect == "100-continue" + if connection.lower() == "close": + self.connection_close = True + + if not self.chunked: + try: + cl = int(headers.get("CONTENT_LENGTH", 0)) + except ValueError: + raise ParsingError("Content-Length is invalid") + + self.content_length = cl + if cl > 0: + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = FixedStreamReceiver(cl, buf) + + def get_body_stream(self): + body_rcv = self.body_rcv + if body_rcv is not None: + return body_rcv.getfile() + else: + return BytesIO() + + def close(self): + body_rcv = self.body_rcv + if body_rcv is not None: + body_rcv.getbuf().close() + + +def split_uri(uri): + # urlsplit handles byte input by returning bytes on py3, so + # scheme, netloc, path, query, and fragment are bytes + + scheme = netloc = path = query = fragment = b"" + + # urlsplit below will treat this as a scheme-less netloc, thereby losing + # the original intent of the request. Here we shamelessly stole 4 lines of + # code from the CPython stdlib to parse out the fragment and query but + # leave the path alone. See + # https://github.com/python/cpython/blob/8c9e9b0cd5b24dfbf1424d1f253d02de80e8f5ef/Lib/urllib/parse.py#L465-L468 + # and https://github.com/Pylons/waitress/issues/260 + + if uri[:2] == b"//": + path = uri + + if b"#" in path: + path, fragment = path.split(b"#", 1) + + if b"?" in path: + path, query = path.split(b"?", 1) + else: + try: + scheme, netloc, path, query, fragment = urlparse.urlsplit(uri) + except UnicodeError: + raise ParsingError("Bad URI") + + return ( + tostr(scheme), + tostr(netloc), + unquote_bytes_to_wsgi(path), + tostr(query), + tostr(fragment), + ) + + +def get_header_lines(header): + """ + Splits the header into lines, putting multi-line headers together. + """ + r = [] + lines = header.split(b"\r\n") + for line in lines: + if not line: + continue + + if b"\r" in line or b"\n" in line: + raise ParsingError('Bare CR or LF found in header line "%s"' % tostr(line)) + + if line.startswith((b" ", b"\t")): + if not r: + # https://corte.si/posts/code/pathod/pythonservers/index.html + raise ParsingError('Malformed header line "%s"' % tostr(line)) + r[-1] += line + else: + r.append(line) + return r + + +first_line_re = re.compile( + b"([^ ]+) " + b"((?:[^ :?#]+://[^ ?#/]*(?:[0-9]{1,5})?)?[^ ]+)" + b"(( HTTP/([0-9.]+))$|$)" +) + + +def crack_first_line(line): + m = first_line_re.match(line) + if m is not None and m.end() == len(line): + if m.group(3): + version = m.group(5) + else: + version = b"" + method = m.group(1) + + # the request methods that are currently defined are all uppercase: + # https://www.iana.org/assignments/http-methods/http-methods.xhtml and + # the request method is case sensitive according to + # https://tools.ietf.org/html/rfc7231#section-4.1 + + # By disallowing anything but uppercase methods we save poor + # unsuspecting souls from sending lowercase HTTP methods to waitress + # and having the request complete, while servers like nginx drop the + # request onto the floor. + if method != method.upper(): + raise ParsingError('Malformed HTTP method "%s"' % tostr(method)) + uri = m.group(2) + return method, uri, version + else: + return b"", b"", b"" diff --git a/libs/waitress/proxy_headers.py b/libs/waitress/proxy_headers.py new file mode 100644 index 000000000..1df8b8eba --- /dev/null +++ b/libs/waitress/proxy_headers.py @@ -0,0 +1,333 @@ +from collections import namedtuple + +from .utilities import logger, undquote, BadRequest + + +PROXY_HEADERS = frozenset( + { + "X_FORWARDED_FOR", + "X_FORWARDED_HOST", + "X_FORWARDED_PROTO", + "X_FORWARDED_PORT", + "X_FORWARDED_BY", + "FORWARDED", + } +) + +Forwarded = namedtuple("Forwarded", ["by", "for_", "host", "proto"]) + + +class MalformedProxyHeader(Exception): + def __init__(self, header, reason, value): + self.header = header + self.reason = reason + self.value = value + super(MalformedProxyHeader, self).__init__(header, reason, value) + + +def proxy_headers_middleware( + app, + trusted_proxy=None, + trusted_proxy_count=1, + trusted_proxy_headers=None, + clear_untrusted=True, + log_untrusted=False, + logger=logger, +): + def translate_proxy_headers(environ, start_response): + untrusted_headers = PROXY_HEADERS + remote_peer = environ["REMOTE_ADDR"] + if trusted_proxy == "*" or remote_peer == trusted_proxy: + try: + untrusted_headers = parse_proxy_headers( + environ, + trusted_proxy_count=trusted_proxy_count, + trusted_proxy_headers=trusted_proxy_headers, + logger=logger, + ) + except MalformedProxyHeader as ex: + logger.warning( + 'Malformed proxy header "%s" from "%s": %s value: %s', + ex.header, + remote_peer, + ex.reason, + ex.value, + ) + error = BadRequest('Header "{0}" malformed.'.format(ex.header)) + return error.wsgi_response(environ, start_response) + + # Clear out the untrusted proxy headers + if clear_untrusted: + clear_untrusted_headers( + environ, untrusted_headers, log_warning=log_untrusted, logger=logger, + ) + + return app(environ, start_response) + + return translate_proxy_headers + + +def parse_proxy_headers( + environ, trusted_proxy_count, trusted_proxy_headers, logger=logger, +): + if trusted_proxy_headers is None: + trusted_proxy_headers = set() + + forwarded_for = [] + forwarded_host = forwarded_proto = forwarded_port = forwarded = "" + client_addr = None + untrusted_headers = set(PROXY_HEADERS) + + def raise_for_multiple_values(): + raise ValueError("Unspecified behavior for multiple values found in header",) + + if "x-forwarded-for" in trusted_proxy_headers and "HTTP_X_FORWARDED_FOR" in environ: + try: + forwarded_for = [] + + for forward_hop in environ["HTTP_X_FORWARDED_FOR"].split(","): + forward_hop = forward_hop.strip() + forward_hop = undquote(forward_hop) + + # Make sure that all IPv6 addresses are surrounded by brackets, + # this is assuming that the IPv6 representation here does not + # include a port number. + + if "." not in forward_hop and ( + ":" in forward_hop and forward_hop[-1] != "]" + ): + forwarded_for.append("[{}]".format(forward_hop)) + else: + forwarded_for.append(forward_hop) + + forwarded_for = forwarded_for[-trusted_proxy_count:] + client_addr = forwarded_for[0] + + untrusted_headers.remove("X_FORWARDED_FOR") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-For", str(ex), environ["HTTP_X_FORWARDED_FOR"], + ) + + if ( + "x-forwarded-host" in trusted_proxy_headers + and "HTTP_X_FORWARDED_HOST" in environ + ): + try: + forwarded_host_multiple = [] + + for forward_host in environ["HTTP_X_FORWARDED_HOST"].split(","): + forward_host = forward_host.strip() + forward_host = undquote(forward_host) + forwarded_host_multiple.append(forward_host) + + forwarded_host_multiple = forwarded_host_multiple[-trusted_proxy_count:] + forwarded_host = forwarded_host_multiple[0] + + untrusted_headers.remove("X_FORWARDED_HOST") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Host", str(ex), environ["HTTP_X_FORWARDED_HOST"], + ) + + if "x-forwarded-proto" in trusted_proxy_headers: + try: + forwarded_proto = undquote(environ.get("HTTP_X_FORWARDED_PROTO", "")) + if "," in forwarded_proto: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PROTO") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Proto", str(ex), environ["HTTP_X_FORWARDED_PROTO"], + ) + + if "x-forwarded-port" in trusted_proxy_headers: + try: + forwarded_port = undquote(environ.get("HTTP_X_FORWARDED_PORT", "")) + if "," in forwarded_port: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PORT") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Port", str(ex), environ["HTTP_X_FORWARDED_PORT"], + ) + + if "x-forwarded-by" in trusted_proxy_headers: + # Waitress itself does not use X-Forwarded-By, but we can not + # remove it so it can get set in the environ + untrusted_headers.remove("X_FORWARDED_BY") + + if "forwarded" in trusted_proxy_headers: + forwarded = environ.get("HTTP_FORWARDED", None) + untrusted_headers = PROXY_HEADERS - {"FORWARDED"} + + # If the Forwarded header exists, it gets priority + if forwarded: + proxies = [] + try: + for forwarded_element in forwarded.split(","): + # Remove whitespace that may have been introduced when + # appending a new entry + forwarded_element = forwarded_element.strip() + + forwarded_for = forwarded_host = forwarded_proto = "" + forwarded_port = forwarded_by = "" + + for pair in forwarded_element.split(";"): + pair = pair.lower() + + if not pair: + continue + + token, equals, value = pair.partition("=") + + if equals != "=": + raise ValueError('Invalid forwarded-pair missing "="') + + if token.strip() != token: + raise ValueError("Token may not be surrounded by whitespace") + + if value.strip() != value: + raise ValueError("Value may not be surrounded by whitespace") + + if token == "by": + forwarded_by = undquote(value) + + elif token == "for": + forwarded_for = undquote(value) + + elif token == "host": + forwarded_host = undquote(value) + + elif token == "proto": + forwarded_proto = undquote(value) + + else: + logger.warning("Unknown Forwarded token: %s" % token) + + proxies.append( + Forwarded( + forwarded_by, forwarded_for, forwarded_host, forwarded_proto + ) + ) + except Exception as ex: + raise MalformedProxyHeader( + "Forwarded", str(ex), environ["HTTP_FORWARDED"], + ) + + proxies = proxies[-trusted_proxy_count:] + + # Iterate backwards and fill in some values, the oldest entry that + # contains the information we expect is the one we use. We expect + # that intermediate proxies may re-write the host header or proto, + # but the oldest entry is the one that contains the information the + # client expects when generating URL's + # + # Forwarded: for="[2001:db8::1]";host="example.com:8443";proto="https" + # Forwarded: for=192.0.2.1;host="example.internal:8080" + # + # (After HTTPS header folding) should mean that we use as values: + # + # Host: example.com + # Protocol: https + # Port: 8443 + + for proxy in proxies[::-1]: + client_addr = proxy.for_ or client_addr + forwarded_host = proxy.host or forwarded_host + forwarded_proto = proxy.proto or forwarded_proto + + if forwarded_proto: + forwarded_proto = forwarded_proto.lower() + + if forwarded_proto not in {"http", "https"}: + raise MalformedProxyHeader( + "Forwarded Proto=" if forwarded else "X-Forwarded-Proto", + "unsupported proto value", + forwarded_proto, + ) + + # Set the URL scheme to the proxy provided proto + environ["wsgi.url_scheme"] = forwarded_proto + + if not forwarded_port: + if forwarded_proto == "http": + forwarded_port = "80" + + if forwarded_proto == "https": + forwarded_port = "443" + + if forwarded_host: + if ":" in forwarded_host and forwarded_host[-1] != "]": + host, port = forwarded_host.rsplit(":", 1) + host, port = host.strip(), str(port) + + # We trust the port in the Forwarded Host/X-Forwarded-Host over + # X-Forwarded-Port, or whatever we got from Forwarded + # Proto/X-Forwarded-Proto. + + if forwarded_port != port: + forwarded_port = port + + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = host + environ["HTTP_HOST"] = forwarded_host + else: + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = forwarded_host + environ["HTTP_HOST"] = forwarded_host + + if forwarded_port: + if forwarded_port not in {"443", "80"}: + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "80" and environ["wsgi.url_scheme"] != "http": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "443" and environ["wsgi.url_scheme"] != "https": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + + if forwarded_port: + environ["SERVER_PORT"] = str(forwarded_port) + + if client_addr: + if ":" in client_addr and client_addr[-1] != "]": + addr, port = client_addr.rsplit(":", 1) + environ["REMOTE_ADDR"] = strip_brackets(addr.strip()) + environ["REMOTE_PORT"] = port.strip() + else: + environ["REMOTE_ADDR"] = strip_brackets(client_addr.strip()) + environ["REMOTE_HOST"] = environ["REMOTE_ADDR"] + + return untrusted_headers + + +def strip_brackets(addr): + if addr[0] == "[" and addr[-1] == "]": + return addr[1:-1] + return addr + + +def clear_untrusted_headers( + environ, untrusted_headers, log_warning=False, logger=logger +): + untrusted_headers_removed = [ + header + for header in untrusted_headers + if environ.pop("HTTP_" + header, False) is not False + ] + + if log_warning and untrusted_headers_removed: + untrusted_headers_removed = [ + "-".join(x.capitalize() for x in header.split("_")) + for header in untrusted_headers_removed + ] + logger.warning( + "Removed untrusted headers (%s). Waitress recommends these be " + "removed upstream.", + ", ".join(untrusted_headers_removed), + ) diff --git a/libs/waitress/receiver.py b/libs/waitress/receiver.py new file mode 100644 index 000000000..5d1568d51 --- /dev/null +++ b/libs/waitress/receiver.py @@ -0,0 +1,186 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Data Chunk Receiver +""" + +from waitress.utilities import BadRequest, find_double_newline + + +class FixedStreamReceiver(object): + + # See IStreamConsumer + completed = False + error = None + + def __init__(self, cl, buf): + self.remain = cl + self.buf = buf + + def __len__(self): + return self.buf.__len__() + + def received(self, data): + "See IStreamConsumer" + rm = self.remain + + if rm < 1: + self.completed = True # Avoid any chance of spinning + + return 0 + datalen = len(data) + + if rm <= datalen: + self.buf.append(data[:rm]) + self.remain = 0 + self.completed = True + + return rm + else: + self.buf.append(data) + self.remain -= datalen + + return datalen + + def getfile(self): + return self.buf.getfile() + + def getbuf(self): + return self.buf + + +class ChunkedReceiver(object): + + chunk_remainder = 0 + validate_chunk_end = False + control_line = b"" + chunk_end = b"" + all_chunks_received = False + trailer = b"" + completed = False + error = None + + # max_control_line = 1024 + # max_trailer = 65536 + + def __init__(self, buf): + self.buf = buf + + def __len__(self): + return self.buf.__len__() + + def received(self, s): + # Returns the number of bytes consumed. + + if self.completed: + return 0 + orig_size = len(s) + + while s: + rm = self.chunk_remainder + + if rm > 0: + # Receive the remainder of a chunk. + to_write = s[:rm] + self.buf.append(to_write) + written = len(to_write) + s = s[written:] + + self.chunk_remainder -= written + + if self.chunk_remainder == 0: + self.validate_chunk_end = True + elif self.validate_chunk_end: + s = self.chunk_end + s + + pos = s.find(b"\r\n") + + if pos < 0 and len(s) < 2: + self.chunk_end = s + s = b"" + else: + self.chunk_end = b"" + if pos == 0: + # Chop off the terminating CR LF from the chunk + s = s[2:] + else: + self.error = BadRequest("Chunk not properly terminated") + self.all_chunks_received = True + + # Always exit this loop + self.validate_chunk_end = False + elif not self.all_chunks_received: + # Receive a control line. + s = self.control_line + s + pos = s.find(b"\r\n") + + if pos < 0: + # Control line not finished. + self.control_line = s + s = b"" + else: + # Control line finished. + line = s[:pos] + s = s[pos + 2 :] + self.control_line = b"" + line = line.strip() + + if line: + # Begin a new chunk. + semi = line.find(b";") + + if semi >= 0: + # discard extension info. + line = line[:semi] + try: + sz = int(line.strip(), 16) # hexadecimal + except ValueError: # garbage in input + self.error = BadRequest("garbage in chunked encoding input") + sz = 0 + + if sz > 0: + # Start a new chunk. + self.chunk_remainder = sz + else: + # Finished chunks. + self.all_chunks_received = True + # else expect a control line. + else: + # Receive the trailer. + trailer = self.trailer + s + + if trailer.startswith(b"\r\n"): + # No trailer. + self.completed = True + + return orig_size - (len(trailer) - 2) + pos = find_double_newline(trailer) + + if pos < 0: + # Trailer not finished. + self.trailer = trailer + s = b"" + else: + # Finished the trailer. + self.completed = True + self.trailer = trailer[:pos] + + return orig_size - (len(trailer) - pos) + + return orig_size + + def getfile(self): + return self.buf.getfile() + + def getbuf(self): + return self.buf diff --git a/libs/waitress/rfc7230.py b/libs/waitress/rfc7230.py new file mode 100644 index 000000000..cd33c9064 --- /dev/null +++ b/libs/waitress/rfc7230.py @@ -0,0 +1,52 @@ +""" +This contains a bunch of RFC7230 definitions and regular expressions that are +needed to properly parse HTTP messages. +""" + +import re + +from .compat import tobytes + +WS = "[ \t]" +OWS = WS + "{0,}?" +RWS = WS + "{1,}?" +BWS = OWS + +# RFC 7230 Section 3.2.6 "Field Value Components": +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# obs-text = %x80-FF +TCHAR = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]" +OBS_TEXT = r"\x80-\xff" + +TOKEN = TCHAR + "{1,}" + +# RFC 5234 Appendix B.1 "Core Rules": +# VCHAR = %x21-7E +# ; visible (printing) characters +VCHAR = r"\x21-\x7e" + +# header-field = field-name ":" OWS field-value OWS +# field-name = token +# field-value = *( field-content / obs-fold ) +# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +# field-vchar = VCHAR / obs-text + +# Errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 +# changes field-content to: +# +# field-content = field-vchar [ 1*( SP / HTAB / field-vchar ) +# field-vchar ] + +FIELD_VCHAR = "[" + VCHAR + OBS_TEXT + "]" +# Field content is more greedy than the ABNF, in that it will match the whole value +FIELD_CONTENT = FIELD_VCHAR + "+(?:[ \t]+" + FIELD_VCHAR + "+)*" +# Which allows the field value here to just see if there is even a value in the first place +FIELD_VALUE = "(?:" + FIELD_CONTENT + ")?" + +HEADER_FIELD = re.compile( + tobytes( + "^(?P<name>" + TOKEN + "):" + OWS + "(?P<value>" + FIELD_VALUE + ")" + OWS + "$" + ) +) diff --git a/libs/waitress/runner.py b/libs/waitress/runner.py new file mode 100644 index 000000000..2495084f0 --- /dev/null +++ b/libs/waitress/runner.py @@ -0,0 +1,286 @@ +############################################################################## +# +# Copyright (c) 2013 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Command line runner. +""" + +from __future__ import print_function, unicode_literals + +import getopt +import os +import os.path +import re +import sys + +from waitress import serve +from waitress.adjustments import Adjustments + +HELP = """\ +Usage: + + {0} [OPTS] MODULE:OBJECT + +Standard options: + + --help + Show this information. + + --call + Call the given object to get the WSGI application. + + --host=ADDR + Hostname or IP address on which to listen, default is '0.0.0.0', + which means "all IP addresses on this host". + + Note: May not be used together with --listen + + --port=PORT + TCP port on which to listen, default is '8080' + + Note: May not be used together with --listen + + --listen=ip:port + Tell waitress to listen on an ip port combination. + + Example: + + --listen=127.0.0.1:8080 + --listen=[::1]:8080 + --listen=*:8080 + + This option may be used multiple times to listen on multiple sockets. + A wildcard for the hostname is also supported and will bind to both + IPv4/IPv6 depending on whether they are enabled or disabled. + + --[no-]ipv4 + Toggle on/off IPv4 support. + + Example: + + --no-ipv4 + + This will disable IPv4 socket support. This affects wildcard matching + when generating the list of sockets. + + --[no-]ipv6 + Toggle on/off IPv6 support. + + Example: + + --no-ipv6 + + This will turn on IPv6 socket support. This affects wildcard matching + when generating a list of sockets. + + --unix-socket=PATH + Path of Unix socket. If a socket path is specified, a Unix domain + socket is made instead of the usual inet domain socket. + + Not available on Windows. + + --unix-socket-perms=PERMS + Octal permissions to use for the Unix domain socket, default is + '600'. + + --url-scheme=STR + Default wsgi.url_scheme value, default is 'http'. + + --url-prefix=STR + The ``SCRIPT_NAME`` WSGI environment value. Setting this to anything + except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be + the value passed minus any trailing slashes you add, and it will cause + the ``PATH_INFO`` of any request which is prefixed with this value to + be stripped of the prefix. Default is the empty string. + + --ident=STR + Server identity used in the 'Server' header in responses. Default + is 'waitress'. + +Tuning options: + + --threads=INT + Number of threads used to process application logic, default is 4. + + --backlog=INT + Connection backlog for the server. Default is 1024. + + --recv-bytes=INT + Number of bytes to request when calling socket.recv(). Default is + 8192. + + --send-bytes=INT + Number of bytes to send to socket.send(). Default is 18000. + Multiples of 9000 should avoid partly-filled TCP packets. + + --outbuf-overflow=INT + A temporary file should be created if the pending output is larger + than this. Default is 1048576 (1MB). + + --outbuf-high-watermark=INT + The app_iter will pause when pending output is larger than this value + and will resume once enough data is written to the socket to fall below + this threshold. Default is 16777216 (16MB). + + --inbuf-overflow=INT + A temporary file should be created if the pending input is larger + than this. Default is 524288 (512KB). + + --connection-limit=INT + Stop creating new channels if too many are already active. + Default is 100. + + --cleanup-interval=INT + Minimum seconds between cleaning up inactive channels. Default + is 30. See '--channel-timeout'. + + --channel-timeout=INT + Maximum number of seconds to leave inactive connections open. + Default is 120. 'Inactive' is defined as 'has received no data + from the client and has sent no data to the client'. + + --[no-]log-socket-errors + Toggle whether premature client disconnect tracebacks ought to be + logged. On by default. + + --max-request-header-size=INT + Maximum size of all request headers combined. Default is 262144 + (256KB). + + --max-request-body-size=INT + Maximum size of request body. Default is 1073741824 (1GB). + + --[no-]expose-tracebacks + Toggle whether to expose tracebacks of unhandled exceptions to the + client. Off by default. + + --asyncore-loop-timeout=INT + The timeout value in seconds passed to asyncore.loop(). Default is 1. + + --asyncore-use-poll + The use_poll argument passed to ``asyncore.loop()``. Helps overcome + open file descriptors limit. Default is False. + +""" + +RUNNER_PATTERN = re.compile( + r""" + ^ + (?P<module> + [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* + ) + : + (?P<object> + [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* + ) + $ + """, + re.I | re.X, +) + + +def match(obj_name): + matches = RUNNER_PATTERN.match(obj_name) + if not matches: + raise ValueError("Malformed application '{0}'".format(obj_name)) + return matches.group("module"), matches.group("object") + + +def resolve(module_name, object_name): + """Resolve a named object in a module.""" + # We cast each segments due to an issue that has been found to manifest + # in Python 2.6.6, but not 2.6.8, and may affect other revisions of Python + # 2.6 and 2.7, whereby ``__import__`` chokes if the list passed in the + # ``fromlist`` argument are unicode strings rather than 8-bit strings. + # The error triggered is "TypeError: Item in ``fromlist '' not a string". + # My guess is that this was fixed by checking against ``basestring`` + # rather than ``str`` sometime between the release of 2.6.6 and 2.6.8, + # but I've yet to go over the commits. I know, however, that the NEWS + # file makes no mention of such a change to the behaviour of + # ``__import__``. + segments = [str(segment) for segment in object_name.split(".")] + obj = __import__(module_name, fromlist=segments[:1]) + for segment in segments: + obj = getattr(obj, segment) + return obj + + +def show_help(stream, name, error=None): # pragma: no cover + if error is not None: + print("Error: {0}\n".format(error), file=stream) + print(HELP.format(name), file=stream) + + +def show_exception(stream): + exc_type, exc_value = sys.exc_info()[:2] + args = getattr(exc_value, "args", None) + print( + ("There was an exception ({0}) importing your module.\n").format( + exc_type.__name__, + ), + file=stream, + ) + if args: + print("It had these arguments: ", file=stream) + for idx, arg in enumerate(args, start=1): + print("{0}. {1}\n".format(idx, arg), file=stream) + else: + print("It had no arguments.", file=stream) + + +def run(argv=sys.argv, _serve=serve): + """Command line runner.""" + name = os.path.basename(argv[0]) + + try: + kw, args = Adjustments.parse_args(argv[1:]) + except getopt.GetoptError as exc: + show_help(sys.stderr, name, str(exc)) + return 1 + + if kw["help"]: + show_help(sys.stdout, name) + return 0 + + if len(args) != 1: + show_help(sys.stderr, name, "Specify one application only") + return 1 + + try: + module, obj_name = match(args[0]) + except ValueError as exc: + show_help(sys.stderr, name, str(exc)) + show_exception(sys.stderr) + return 1 + + # Add the current directory onto sys.path + sys.path.append(os.getcwd()) + + # Get the WSGI function. + try: + app = resolve(module, obj_name) + except ImportError: + show_help(sys.stderr, name, "Bad module '{0}'".format(module)) + show_exception(sys.stderr) + return 1 + except AttributeError: + show_help(sys.stderr, name, "Bad object name '{0}'".format(obj_name)) + show_exception(sys.stderr) + return 1 + if kw["call"]: + app = app() + + # These arguments are specific to the runner, not waitress itself. + del kw["call"], kw["help"] + + _serve(app, **kw) + return 0 diff --git a/libs/waitress/server.py b/libs/waitress/server.py new file mode 100644 index 000000000..ae566994f --- /dev/null +++ b/libs/waitress/server.py @@ -0,0 +1,436 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import os +import os.path +import socket +import time + +from waitress import trigger +from waitress.adjustments import Adjustments +from waitress.channel import HTTPChannel +from waitress.task import ThreadedTaskDispatcher +from waitress.utilities import cleanup_unix_socket + +from waitress.compat import ( + IPPROTO_IPV6, + IPV6_V6ONLY, +) +from . import wasyncore +from .proxy_headers import proxy_headers_middleware + + +def create_server( + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + _dispatcher=None, # test shim + **kw # adjustments +): + """ + if __name__ == '__main__': + server = create_server(app) + server.run() + """ + if application is None: + raise ValueError( + 'The "app" passed to ``create_server`` was ``None``. You forgot ' + "to return a WSGI app within your application." + ) + adj = Adjustments(**kw) + + if map is None: # pragma: nocover + map = {} + + dispatcher = _dispatcher + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(adj.threads) + + if adj.unix_socket and hasattr(socket, "AF_UNIX"): + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + return UnixWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + + effective_listen = [] + last_serv = None + if not adj.sockets: + for sockinfo in adj.listen: + # When TcpWSGIServer is called, it registers itself in the map. This + # side-effect is all we need it for, so we don't store a reference to + # or return it to the user. + last_serv = TcpWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + for sock in adj.sockets: + sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) + if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + last_serv = TcpWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: + last_serv = UnixWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + # We are running a single server, so we can just return the last server, + # saves us from having to create one more object + if len(effective_listen) == 1: + # In this case we have no need to use a MultiSocketServer + return last_serv + + # Return a class that has a utility function to print out the sockets it's + # listening on, and has a .run() function. All of the TcpWSGIServers + # registered themselves in the map above. + return MultiSocketServer(map, adj, effective_listen, dispatcher) + + +# This class is only ever used if we have multiple listen sockets. It allows +# the serve() API to call .run() which starts the wasyncore loop, and catches +# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. +class MultiSocketServer(object): + asyncore = wasyncore # test shim + + def __init__( + self, map=None, adj=None, effective_listen=None, dispatcher=None, + ): + self.adj = adj + self.map = map + self.effective_listen = effective_listen + self.task_dispatcher = dispatcher + + def print_listen(self, format_str): # pragma: nocover + for l in self.effective_listen: + l = list(l) + + if ":" in l[0]: + l[0] = "[{}]".format(l[0]) + + print(format_str.format(*l)) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self.map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.close() + + def close(self): + self.task_dispatcher.shutdown() + wasyncore.close_all(self.map) + + +class BaseWSGIServer(wasyncore.dispatcher, object): + + channel_class = HTTPChannel + next_channel_cleanup = 0 + socketmod = socket # test shim + asyncore = wasyncore # test shim + + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + bind_socket=True, + **kw + ): + if adj is None: + adj = Adjustments(**kw) + + if adj.trusted_proxy or adj.clear_untrusted_proxy_headers: + # wrap the application to deal with proxy headers + # we wrap it here because webtest subclasses the TcpWSGIServer + # directly and thus doesn't run any code that's in create_server + application = proxy_headers_middleware( + application, + trusted_proxy=adj.trusted_proxy, + trusted_proxy_count=adj.trusted_proxy_count, + trusted_proxy_headers=adj.trusted_proxy_headers, + clear_untrusted=adj.clear_untrusted_proxy_headers, + log_untrusted=adj.log_untrusted_proxy_headers, + logger=self.logger, + ) + + if map is None: + # use a nonglobal socket map by default to hopefully prevent + # conflicts with apps and libs that use the wasyncore global socket + # map ala https://github.com/Pylons/waitress/issues/63 + map = {} + if sockinfo is None: + sockinfo = adj.listen[0] + + self.sockinfo = sockinfo + self.family = sockinfo[0] + self.socktype = sockinfo[1] + self.application = application + self.adj = adj + self.trigger = trigger.trigger(map) + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(self.adj.threads) + + self.task_dispatcher = dispatcher + self.asyncore.dispatcher.__init__(self, _sock, map=map) + if _sock is None: + self.create_socket(self.family, self.socktype) + if self.family == socket.AF_INET6: # pragma: nocover + self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) + + self.set_reuse_addr() + + if bind_socket: + self.bind_server_socket() + + self.effective_host, self.effective_port = self.getsockname() + self.server_name = self.get_server_name(self.effective_host) + self.active_channels = {} + if _start: + self.accept_connections() + + def bind_server_socket(self): + raise NotImplementedError # pragma: no cover + + def get_server_name(self, ip): + """Given an IP or hostname, try to determine the server name.""" + + if not ip: + raise ValueError("Requires an IP to get the server name") + + server_name = str(ip) + + # If we are bound to all IP's, just return the current hostname, only + # fall-back to "localhost" if we fail to get the hostname + if server_name == "0.0.0.0" or server_name == "::": + try: + return str(self.socketmod.gethostname()) + except (socket.error, UnicodeDecodeError): # pragma: no cover + # We also deal with UnicodeDecodeError in case of Windows with + # non-ascii hostname + return "localhost" + + # Now let's try and convert the IP address to a proper hostname + try: + server_name = self.socketmod.gethostbyaddr(server_name)[0] + except (socket.error, UnicodeDecodeError): # pragma: no cover + # We also deal with UnicodeDecodeError in case of Windows with + # non-ascii hostname + pass + + # If it contains an IPv6 literal, make sure to surround it with + # brackets + if ":" in server_name and "[" not in server_name: + server_name = "[{}]".format(server_name) + + return server_name + + def getsockname(self): + raise NotImplementedError # pragma: no cover + + def accept_connections(self): + self.accepting = True + self.socket.listen(self.adj.backlog) # Get around asyncore NT limit + + def add_task(self, task): + self.task_dispatcher.add_task(task) + + def readable(self): + now = time.time() + if now >= self.next_channel_cleanup: + self.next_channel_cleanup = now + self.adj.cleanup_interval + self.maintenance(now) + return self.accepting and len(self._map) < self.adj.connection_limit + + def writable(self): + return False + + def handle_read(self): + pass + + def handle_connect(self): + pass + + def handle_accept(self): + try: + v = self.accept() + if v is None: + return + conn, addr = v + except socket.error: + # Linux: On rare occasions we get a bogus socket back from + # accept. socketmodule.c:makesockaddr complains that the + # address family is unknown. We don't want the whole server + # to shut down because of this. + if self.adj.log_socket_errors: + self.logger.warning("server accept() threw an exception", exc_info=True) + return + self.set_socket_options(conn) + addr = self.fix_addr(addr) + self.channel_class(self, conn, addr, self.adj, map=self._map) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self._map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.task_dispatcher.shutdown() + + def pull_trigger(self): + self.trigger.pull_trigger() + + def set_socket_options(self, conn): + pass + + def fix_addr(self, addr): + return addr + + def maintenance(self, now): + """ + Closes channels that have not had any activity in a while. + + The timeout is configured through adj.channel_timeout (seconds). + """ + cutoff = now - self.adj.channel_timeout + for channel in self.active_channels.values(): + if (not channel.requests) and channel.last_activity < cutoff: + channel.will_close = True + + def print_listen(self, format_str): # pragma: nocover + print(format_str.format(self.effective_host, self.effective_port)) + + def close(self): + self.trigger.close() + return wasyncore.dispatcher.close(self) + + +class TcpWSGIServer(BaseWSGIServer): + def bind_server_socket(self): + (_, _, _, sockaddr) = self.sockinfo + self.bind(sockaddr) + + def getsockname(self): + try: + return self.socketmod.getnameinfo( + self.socket.getsockname(), self.socketmod.NI_NUMERICSERV + ) + except: # pragma: no cover + # This only happens on Linux because a DNS issue is considered a + # temporary failure that will raise (even when NI_NAMEREQD is not + # set). Instead we try again, but this time we just ask for the + # numerichost and the numericserv (port) and return those. It is + # better than nothing. + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV, + ) + + def set_socket_options(self, conn): + for (level, optname, value) in self.adj.socket_options: + conn.setsockopt(level, optname, value) + + +if hasattr(socket, "AF_UNIX"): + + class UnixWSGIServer(BaseWSGIServer): + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + **kw + ): + if sockinfo is None: + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + + super(UnixWSGIServer, self).__init__( + application, + map=map, + _start=_start, + _sock=_sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + **kw + ) + + def bind_server_socket(self): + cleanup_unix_socket(self.adj.unix_socket) + self.bind(self.adj.unix_socket) + if os.path.exists(self.adj.unix_socket): + os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms) + + def getsockname(self): + return ("unix", self.socket.getsockname()) + + def fix_addr(self, addr): + return ("localhost", None) + + def get_server_name(self, ip): + return "localhost" + + +# Compatibility alias. +WSGIServer = TcpWSGIServer diff --git a/libs/waitress/task.py b/libs/waitress/task.py new file mode 100644 index 000000000..8e7ab1888 --- /dev/null +++ b/libs/waitress/task.py @@ -0,0 +1,570 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import socket +import sys +import threading +import time +from collections import deque + +from .buffers import ReadOnlyFileBasedBuffer +from .compat import reraise, tobytes +from .utilities import build_http_date, logger, queue_logger + +rename_headers = { # or keep them without the HTTP_ prefix added + "CONTENT_LENGTH": "CONTENT_LENGTH", + "CONTENT_TYPE": "CONTENT_TYPE", +} + +hop_by_hop = frozenset( + ( + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ) +) + + +class ThreadedTaskDispatcher(object): + """A Task Dispatcher that creates a thread for each task. + """ + + stop_count = 0 # Number of threads that will stop soon. + active_count = 0 # Number of currently active threads + logger = logger + queue_logger = queue_logger + + def __init__(self): + self.threads = set() + self.queue = deque() + self.lock = threading.Lock() + self.queue_cv = threading.Condition(self.lock) + self.thread_exit_cv = threading.Condition(self.lock) + + def start_new_thread(self, target, args): + t = threading.Thread(target=target, name="waitress", args=args) + t.daemon = True + t.start() + + def handler_thread(self, thread_no): + while True: + with self.lock: + while not self.queue and self.stop_count == 0: + # Mark ourselves as idle before waiting to be + # woken up, then we will once again be active + self.active_count -= 1 + self.queue_cv.wait() + self.active_count += 1 + + if self.stop_count > 0: + self.active_count -= 1 + self.stop_count -= 1 + self.threads.discard(thread_no) + self.thread_exit_cv.notify() + break + + task = self.queue.popleft() + try: + task.service() + except BaseException: + self.logger.exception("Exception when servicing %r", task) + + def set_thread_count(self, count): + with self.lock: + threads = self.threads + thread_no = 0 + running = len(threads) - self.stop_count + while running < count: + # Start threads. + while thread_no in threads: + thread_no = thread_no + 1 + threads.add(thread_no) + running += 1 + self.start_new_thread(self.handler_thread, (thread_no,)) + self.active_count += 1 + thread_no = thread_no + 1 + if running > count: + # Stop threads. + self.stop_count += running - count + self.queue_cv.notify_all() + + def add_task(self, task): + with self.lock: + self.queue.append(task) + self.queue_cv.notify() + queue_size = len(self.queue) + idle_threads = len(self.threads) - self.stop_count - self.active_count + if queue_size > idle_threads: + self.queue_logger.warning( + "Task queue depth is %d", queue_size - idle_threads + ) + + def shutdown(self, cancel_pending=True, timeout=5): + self.set_thread_count(0) + # Ensure the threads shut down. + threads = self.threads + expiration = time.time() + timeout + with self.lock: + while threads: + if time.time() >= expiration: + self.logger.warning("%d thread(s) still running", len(threads)) + break + self.thread_exit_cv.wait(0.1) + if cancel_pending: + # Cancel remaining tasks. + queue = self.queue + if len(queue) > 0: + self.logger.warning("Canceling %d pending task(s)", len(queue)) + while queue: + task = queue.popleft() + task.cancel() + self.queue_cv.notify_all() + return True + return False + + +class Task(object): + close_on_finish = False + status = "200 OK" + wrote_header = False + start_time = 0 + content_length = None + content_bytes_written = 0 + logged_write_excess = False + logged_write_no_body = False + complete = False + chunked_response = False + logger = logger + + def __init__(self, channel, request): + self.channel = channel + self.request = request + self.response_headers = [] + version = request.version + if version not in ("1.0", "1.1"): + # fall back to a version we support. + version = "1.0" + self.version = version + + def service(self): + try: + try: + self.start() + self.execute() + self.finish() + except socket.error: + self.close_on_finish = True + if self.channel.adj.log_socket_errors: + raise + finally: + pass + + @property + def has_body(self): + return not ( + self.status.startswith("1") + or self.status.startswith("204") + or self.status.startswith("304") + ) + + def build_response_header(self): + version = self.version + # Figure out whether the connection should be closed. + connection = self.request.headers.get("CONNECTION", "").lower() + response_headers = [] + content_length_header = None + date_header = None + server_header = None + connection_close_header = None + + for (headername, headerval) in self.response_headers: + headername = "-".join([x.capitalize() for x in headername.split("-")]) + + if headername == "Content-Length": + if self.has_body: + content_length_header = headerval + else: + continue # pragma: no cover + + if headername == "Date": + date_header = headerval + + if headername == "Server": + server_header = headerval + + if headername == "Connection": + connection_close_header = headerval.lower() + # replace with properly capitalized version + response_headers.append((headername, headerval)) + + if ( + content_length_header is None + and self.content_length is not None + and self.has_body + ): + content_length_header = str(self.content_length) + response_headers.append(("Content-Length", content_length_header)) + + def close_on_finish(): + if connection_close_header is None: + response_headers.append(("Connection", "close")) + self.close_on_finish = True + + if version == "1.0": + if connection == "keep-alive": + if not content_length_header: + close_on_finish() + else: + response_headers.append(("Connection", "Keep-Alive")) + else: + close_on_finish() + + elif version == "1.1": + if connection == "close": + close_on_finish() + + if not content_length_header: + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx, 204 or 304. + + if self.has_body: + response_headers.append(("Transfer-Encoding", "chunked")) + self.chunked_response = True + + if not self.close_on_finish: + close_on_finish() + + # under HTTP 1.1 keep-alive is default, no need to set the header + else: + raise AssertionError("neither HTTP/1.0 or HTTP/1.1") + + # Set the Server and Date field, if not yet specified. This is needed + # if the server is used as a proxy. + ident = self.channel.server.adj.ident + + if not server_header: + if ident: + response_headers.append(("Server", ident)) + else: + response_headers.append(("Via", ident or "waitress")) + + if not date_header: + response_headers.append(("Date", build_http_date(self.start_time))) + + self.response_headers = response_headers + + first_line = "HTTP/%s %s" % (self.version, self.status) + # NB: sorting headers needs to preserve same-named-header order + # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; + # rely on stable sort to keep relative position of same-named headers + next_lines = [ + "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0]) + ] + lines = [first_line] + next_lines + res = "%s\r\n\r\n" % "\r\n".join(lines) + + return tobytes(res) + + def remove_content_length_header(self): + response_headers = [] + + for header_name, header_value in self.response_headers: + if header_name.lower() == "content-length": + continue # pragma: nocover + response_headers.append((header_name, header_value)) + + self.response_headers = response_headers + + def start(self): + self.start_time = time.time() + + def finish(self): + if not self.wrote_header: + self.write(b"") + if self.chunked_response: + # not self.write, it will chunk it! + self.channel.write_soon(b"0\r\n\r\n") + + def write(self, data): + if not self.complete: + raise RuntimeError("start_response was not called before body written") + channel = self.channel + if not self.wrote_header: + rh = self.build_response_header() + channel.write_soon(rh) + self.wrote_header = True + + if data and self.has_body: + towrite = data + cl = self.content_length + if self.chunked_response: + # use chunked encoding response + towrite = tobytes(hex(len(data))[2:].upper()) + b"\r\n" + towrite += data + b"\r\n" + elif cl is not None: + towrite = data[: cl - self.content_bytes_written] + self.content_bytes_written += len(towrite) + if towrite != data and not self.logged_write_excess: + self.logger.warning( + "application-written content exceeded the number of " + "bytes specified by Content-Length header (%s)" % cl + ) + self.logged_write_excess = True + if towrite: + channel.write_soon(towrite) + elif data: + # Cheat, and tell the application we have written all of the bytes, + # even though the response shouldn't have a body and we are + # ignoring it entirely. + self.content_bytes_written += len(data) + + if not self.logged_write_no_body: + self.logger.warning( + "application-written content was ignored due to HTTP " + "response that may not contain a message-body: (%s)" % self.status + ) + self.logged_write_no_body = True + + +class ErrorTask(Task): + """ An error task produces an error response + """ + + complete = True + + def execute(self): + e = self.request.error + status, headers, body = e.to_response() + self.status = status + self.response_headers.extend(headers) + # We need to explicitly tell the remote client we are closing the + # connection, because self.close_on_finish is set, and we are going to + # slam the door in the clients face. + self.response_headers.append(("Connection", "close")) + self.close_on_finish = True + self.content_length = len(body) + self.write(tobytes(body)) + + +class WSGITask(Task): + """A WSGI task produces a response from a WSGI application. + """ + + environ = None + + def execute(self): + environ = self.get_environment() + + def start_response(status, headers, exc_info=None): + if self.complete and not exc_info: + raise AssertionError( + "start_response called a second time without providing exc_info." + ) + if exc_info: + try: + if self.wrote_header: + # higher levels will catch and handle raised exception: + # 1. "service" method in task.py + # 2. "service" method in channel.py + # 3. "handler_thread" method in task.py + reraise(exc_info[0], exc_info[1], exc_info[2]) + else: + # As per WSGI spec existing headers must be cleared + self.response_headers = [] + finally: + exc_info = None + + self.complete = True + + if not status.__class__ is str: + raise AssertionError("status %s is not a string" % status) + if "\n" in status or "\r" in status: + raise ValueError( + "carriage return/line feed character present in status" + ) + + self.status = status + + # Prepare the headers for output + for k, v in headers: + if not k.__class__ is str: + raise AssertionError( + "Header name %r is not a string in %r" % (k, (k, v)) + ) + if not v.__class__ is str: + raise AssertionError( + "Header value %r is not a string in %r" % (v, (k, v)) + ) + + if "\n" in v or "\r" in v: + raise ValueError( + "carriage return/line feed character present in header value" + ) + if "\n" in k or "\r" in k: + raise ValueError( + "carriage return/line feed character present in header name" + ) + + kl = k.lower() + if kl == "content-length": + self.content_length = int(v) + elif kl in hop_by_hop: + raise AssertionError( + '%s is a "hop-by-hop" header; it cannot be used by ' + "a WSGI application (see PEP 3333)" % k + ) + + self.response_headers.extend(headers) + + # Return a method used to write the response data. + return self.write + + # Call the application to handle the request and write a response + app_iter = self.channel.server.application(environ, start_response) + + can_close_app_iter = True + try: + if app_iter.__class__ is ReadOnlyFileBasedBuffer: + cl = self.content_length + size = app_iter.prepare(cl) + if size: + if cl != size: + if cl is not None: + self.remove_content_length_header() + self.content_length = size + self.write(b"") # generate headers + # if the write_soon below succeeds then the channel will + # take over closing the underlying file via the channel's + # _flush_some or handle_close so we intentionally avoid + # calling close in the finally block + self.channel.write_soon(app_iter) + can_close_app_iter = False + return + + first_chunk_len = None + for chunk in app_iter: + if first_chunk_len is None: + first_chunk_len = len(chunk) + # Set a Content-Length header if one is not supplied. + # start_response may not have been called until first + # iteration as per PEP, so we must reinterrogate + # self.content_length here + if self.content_length is None: + app_iter_len = None + if hasattr(app_iter, "__len__"): + app_iter_len = len(app_iter) + if app_iter_len == 1: + self.content_length = first_chunk_len + # transmit headers only after first iteration of the iterable + # that returns a non-empty bytestring (PEP 3333) + if chunk: + self.write(chunk) + + cl = self.content_length + if cl is not None: + if self.content_bytes_written != cl: + # close the connection so the client isn't sitting around + # waiting for more data when there are too few bytes + # to service content-length + self.close_on_finish = True + if self.request.command != "HEAD": + self.logger.warning( + "application returned too few bytes (%s) " + "for specified Content-Length (%s) via app_iter" + % (self.content_bytes_written, cl), + ) + finally: + if can_close_app_iter and hasattr(app_iter, "close"): + app_iter.close() + + def get_environment(self): + """Returns a WSGI environment.""" + environ = self.environ + if environ is not None: + # Return the cached copy. + return environ + + request = self.request + path = request.path + channel = self.channel + server = channel.server + url_prefix = server.adj.url_prefix + + if path.startswith("/"): + # strip extra slashes at the beginning of a path that starts + # with any number of slashes + path = "/" + path.lstrip("/") + + if url_prefix: + # NB: url_prefix is guaranteed by the configuration machinery to + # be either the empty string or a string that starts with a single + # slash and ends without any slashes + if path == url_prefix: + # if the path is the same as the url prefix, the SCRIPT_NAME + # should be the url_prefix and PATH_INFO should be empty + path = "" + else: + # if the path starts with the url prefix plus a slash, + # the SCRIPT_NAME should be the url_prefix and PATH_INFO should + # the value of path from the slash until its end + url_prefix_with_trailing_slash = url_prefix + "/" + if path.startswith(url_prefix_with_trailing_slash): + path = path[len(url_prefix) :] + + environ = { + "REMOTE_ADDR": channel.addr[0], + # Nah, we aren't actually going to look up the reverse DNS for + # REMOTE_ADDR, but we will happily set this environment variable + # for the WSGI application. Spec says we can just set this to + # REMOTE_ADDR, so we do. + "REMOTE_HOST": channel.addr[0], + # try and set the REMOTE_PORT to something useful, but maybe None + "REMOTE_PORT": str(channel.addr[1]), + "REQUEST_METHOD": request.command.upper(), + "SERVER_PORT": str(server.effective_port), + "SERVER_NAME": server.server_name, + "SERVER_SOFTWARE": server.adj.ident, + "SERVER_PROTOCOL": "HTTP/%s" % self.version, + "SCRIPT_NAME": url_prefix, + "PATH_INFO": path, + "QUERY_STRING": request.query, + "wsgi.url_scheme": request.url_scheme, + # the following environment variables are required by the WSGI spec + "wsgi.version": (1, 0), + # apps should use the logging module + "wsgi.errors": sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "wsgi.input": request.get_body_stream(), + "wsgi.file_wrapper": ReadOnlyFileBasedBuffer, + "wsgi.input_terminated": True, # wsgi.input is EOF terminated + } + + for key, value in dict(request.headers).items(): + value = value.strip() + mykey = rename_headers.get(key, None) + if mykey is None: + mykey = "HTTP_" + key + if mykey not in environ: + environ[mykey] = value + + # cache the environ for this request + self.environ = environ + return environ diff --git a/libs/waitress/tests/__init__.py b/libs/waitress/tests/__init__.py new file mode 100644 index 000000000..b711d3609 --- /dev/null +++ b/libs/waitress/tests/__init__.py @@ -0,0 +1,2 @@ +# +# This file is necessary to make this directory a package. diff --git a/libs/waitress/tests/fixtureapps/__init__.py b/libs/waitress/tests/fixtureapps/__init__.py new file mode 100644 index 000000000..f215a2b90 --- /dev/null +++ b/libs/waitress/tests/fixtureapps/__init__.py @@ -0,0 +1 @@ +# package (for -m) diff --git a/libs/waitress/tests/fixtureapps/badcl.py b/libs/waitress/tests/fixtureapps/badcl.py new file mode 100644 index 000000000..24067de41 --- /dev/null +++ b/libs/waitress/tests/fixtureapps/badcl.py @@ -0,0 +1,11 @@ +def app(environ, start_response): # pragma: no cover + body = b"abcdefghi" + cl = len(body) + if environ["PATH_INFO"] == "/short_body": + cl = len(body) + 1 + if environ["PATH_INFO"] == "/long_body": + cl = len(body) - 1 + start_response( + "200 OK", [("Content-Length", str(cl)), ("Content-Type", "text/plain")] + ) + return [body] diff --git a/libs/waitress/tests/fixtureapps/echo.py b/libs/waitress/tests/fixtureapps/echo.py new file mode 100644 index 000000000..813bdacea --- /dev/null +++ b/libs/waitress/tests/fixtureapps/echo.py @@ -0,0 +1,56 @@ +from collections import namedtuple +import json + + +def app_body_only(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + cl = str(len(body)) + start_response("200 OK", [("Content-Length", cl), ("Content-Type", "text/plain"),]) + return [body] + + +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + request_body = environ["wsgi.input"].read(cl) + cl = str(len(request_body)) + meta = { + "method": environ["REQUEST_METHOD"], + "path_info": environ["PATH_INFO"], + "script_name": environ["SCRIPT_NAME"], + "query_string": environ["QUERY_STRING"], + "content_length": cl, + "scheme": environ["wsgi.url_scheme"], + "remote_addr": environ["REMOTE_ADDR"], + "remote_host": environ["REMOTE_HOST"], + "server_port": environ["SERVER_PORT"], + "server_name": environ["SERVER_NAME"], + "headers": { + k[len("HTTP_") :]: v for k, v in environ.items() if k.startswith("HTTP_") + }, + } + response = json.dumps(meta).encode("utf8") + b"\r\n\r\n" + request_body + start_response( + "200 OK", + [("Content-Length", str(len(response))), ("Content-Type", "text/plain"),], + ) + return [response] + + +Echo = namedtuple( + "Echo", + ( + "method path_info script_name query_string content_length scheme " + "remote_addr remote_host server_port server_name headers body" + ), +) + + +def parse_response(response): + meta, body = response.split(b"\r\n\r\n", 1) + meta = json.loads(meta.decode("utf8")) + return Echo(body=body, **meta) diff --git a/libs/waitress/tests/fixtureapps/error.py b/libs/waitress/tests/fixtureapps/error.py new file mode 100644 index 000000000..5afb1c542 --- /dev/null +++ b/libs/waitress/tests/fixtureapps/error.py @@ -0,0 +1,21 @@ +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + cl = str(len(body)) + if environ["PATH_INFO"] == "/before_start_response": + raise ValueError("wrong") + write = start_response( + "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")] + ) + if environ["PATH_INFO"] == "/after_write_cb": + write("abc") + if environ["PATH_INFO"] == "/in_generator": + + def foo(): + yield "abc" + raise ValueError + + return foo() + raise ValueError("wrong") diff --git a/libs/waitress/tests/fixtureapps/filewrapper.py b/libs/waitress/tests/fixtureapps/filewrapper.py new file mode 100644 index 000000000..63df5a6dc --- /dev/null +++ b/libs/waitress/tests/fixtureapps/filewrapper.py @@ -0,0 +1,93 @@ +import io +import os + +here = os.path.dirname(os.path.abspath(__file__)) +fn = os.path.join(here, "groundhog1.jpg") + + +class KindaFilelike(object): # pragma: no cover + def __init__(self, bytes): + self.bytes = bytes + + def read(self, n): + bytes = self.bytes[:n] + self.bytes = self.bytes[n:] + return bytes + + +class UnseekableIOBase(io.RawIOBase): # pragma: no cover + def __init__(self, bytes): + self.buf = io.BytesIO(bytes) + + def writable(self): + return False + + def readable(self): + return True + + def seekable(self): + return False + + def read(self, n): + return self.buf.read(n) + + +def app(environ, start_response): # pragma: no cover + path_info = environ["PATH_INFO"] + if path_info.startswith("/filelike"): + f = open(fn, "rb") + f.seek(0, 2) + cl = f.tell() + f.seek(0) + if path_info == "/filelike": + headers = [ + ("Content-Length", str(cl)), + ("Content-Type", "image/jpeg"), + ] + elif path_info == "/filelike_nocl": + headers = [("Content-Type", "image/jpeg")] + elif path_info == "/filelike_shortcl": + # short content length + headers = [ + ("Content-Length", "1"), + ("Content-Type", "image/jpeg"), + ] + else: + # long content length (/filelike_longcl) + headers = [ + ("Content-Length", str(cl + 10)), + ("Content-Type", "image/jpeg"), + ] + else: + with open(fn, "rb") as fp: + data = fp.read() + cl = len(data) + f = KindaFilelike(data) + if path_info == "/notfilelike": + headers = [ + ("Content-Length", str(len(data))), + ("Content-Type", "image/jpeg"), + ] + elif path_info == "/notfilelike_iobase": + headers = [ + ("Content-Length", str(len(data))), + ("Content-Type", "image/jpeg"), + ] + f = UnseekableIOBase(data) + elif path_info == "/notfilelike_nocl": + headers = [("Content-Type", "image/jpeg")] + elif path_info == "/notfilelike_shortcl": + # short content length + headers = [ + ("Content-Length", "1"), + ("Content-Type", "image/jpeg"), + ] + else: + # long content length (/notfilelike_longcl) + headers = [ + ("Content-Length", str(cl + 10)), + ("Content-Type", "image/jpeg"), + ] + + start_response("200 OK", headers) + return environ["wsgi.file_wrapper"](f, 8192) diff --git a/libs/waitress/tests/fixtureapps/getline.py b/libs/waitress/tests/fixtureapps/getline.py new file mode 100644 index 000000000..5e0ad3ae5 --- /dev/null +++ b/libs/waitress/tests/fixtureapps/getline.py @@ -0,0 +1,17 @@ +import sys + +if __name__ == "__main__": + try: + from urllib.request import urlopen, URLError + except ImportError: + from urllib2 import urlopen, URLError + + url = sys.argv[1] + headers = {"Content-Type": "text/plain; charset=utf-8"} + try: + resp = urlopen(url) + line = resp.readline().decode("ascii") # py3 + except URLError: + line = "failed to read %s" % url + sys.stdout.write(line) + sys.stdout.flush() diff --git a/libs/waitress/tests/fixtureapps/groundhog1.jpg b/libs/waitress/tests/fixtureapps/groundhog1.jpg Binary files differnew file mode 100644 index 000000000..90f610ea0 --- /dev/null +++ b/libs/waitress/tests/fixtureapps/groundhog1.jpg diff --git a/libs/waitress/tests/fixtureapps/nocl.py b/libs/waitress/tests/fixtureapps/nocl.py new file mode 100644 index 000000000..f82bba0c8 --- /dev/null +++ b/libs/waitress/tests/fixtureapps/nocl.py @@ -0,0 +1,23 @@ +def chunks(l, n): # pragma: no cover + """ Yield successive n-sized chunks from l. + """ + for i in range(0, len(l), n): + yield l[i : i + n] + + +def gen(body): # pragma: no cover + for chunk in chunks(body, 10): + yield chunk + + +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + start_response("200 OK", [("Content-Type", "text/plain")]) + if environ["PATH_INFO"] == "/list": + return [body] + if environ["PATH_INFO"] == "/list_lentwo": + return [body[0:1], body[1:]] + return gen(body) diff --git a/libs/waitress/tests/fixtureapps/runner.py b/libs/waitress/tests/fixtureapps/runner.py new file mode 100644 index 000000000..1d66ad1cc --- /dev/null +++ b/libs/waitress/tests/fixtureapps/runner.py @@ -0,0 +1,6 @@ +def app(): # pragma: no cover + return None + + +def returns_app(): # pragma: no cover + return app diff --git a/libs/waitress/tests/fixtureapps/sleepy.py b/libs/waitress/tests/fixtureapps/sleepy.py new file mode 100644 index 000000000..2d171d8be --- /dev/null +++ b/libs/waitress/tests/fixtureapps/sleepy.py @@ -0,0 +1,12 @@ +import time + + +def app(environ, start_response): # pragma: no cover + if environ["PATH_INFO"] == "/sleepy": + time.sleep(2) + body = b"sleepy returned" + else: + body = b"notsleepy returned" + cl = str(len(body)) + start_response("200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]) + return [body] diff --git a/libs/waitress/tests/fixtureapps/toolarge.py b/libs/waitress/tests/fixtureapps/toolarge.py new file mode 100644 index 000000000..a0f36d2cc --- /dev/null +++ b/libs/waitress/tests/fixtureapps/toolarge.py @@ -0,0 +1,7 @@ +def app(environ, start_response): # pragma: no cover + body = b"abcdef" + cl = len(body) + start_response( + "200 OK", [("Content-Length", str(cl)), ("Content-Type", "text/plain")] + ) + return [body] diff --git a/libs/waitress/tests/fixtureapps/writecb.py b/libs/waitress/tests/fixtureapps/writecb.py new file mode 100644 index 000000000..e1d2792e6 --- /dev/null +++ b/libs/waitress/tests/fixtureapps/writecb.py @@ -0,0 +1,14 @@ +def app(environ, start_response): # pragma: no cover + path_info = environ["PATH_INFO"] + if path_info == "/no_content_length": + headers = [] + else: + headers = [("Content-Length", "9")] + write = start_response("200 OK", headers) + if path_info == "/long_body": + write(b"abcdefghij") + elif path_info == "/short_body": + write(b"abcdefgh") + else: + write(b"abcdefghi") + return [] diff --git a/libs/waitress/tests/test_adjustments.py b/libs/waitress/tests/test_adjustments.py new file mode 100644 index 000000000..303c1aa3a --- /dev/null +++ b/libs/waitress/tests/test_adjustments.py @@ -0,0 +1,481 @@ +import sys +import socket +import warnings + +from waitress.compat import ( + PY2, + WIN, +) + +if sys.version_info[:2] == (2, 6): # pragma: no cover + import unittest2 as unittest +else: # pragma: no cover + import unittest + + +class Test_asbool(unittest.TestCase): + def _callFUT(self, s): + from waitress.adjustments import asbool + + return asbool(s) + + def test_s_is_None(self): + result = self._callFUT(None) + self.assertEqual(result, False) + + def test_s_is_True(self): + result = self._callFUT(True) + self.assertEqual(result, True) + + def test_s_is_False(self): + result = self._callFUT(False) + self.assertEqual(result, False) + + def test_s_is_true(self): + result = self._callFUT("True") + self.assertEqual(result, True) + + def test_s_is_false(self): + result = self._callFUT("False") + self.assertEqual(result, False) + + def test_s_is_yes(self): + result = self._callFUT("yes") + self.assertEqual(result, True) + + def test_s_is_on(self): + result = self._callFUT("on") + self.assertEqual(result, True) + + def test_s_is_1(self): + result = self._callFUT(1) + self.assertEqual(result, True) + + +class Test_as_socket_list(unittest.TestCase): + def test_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + ] + if hasattr(socket, "AF_UNIX"): + sockets.append(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)) + new_sockets = as_socket_list(sockets) + self.assertEqual(sockets, new_sockets) + for sock in sockets: + sock.close() + + def test_not_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + {"something": "else"}, + ] + new_sockets = as_socket_list(sockets) + self.assertEqual(new_sockets, [sockets[0], sockets[1]]) + for sock in [sock for sock in sockets if isinstance(sock, socket.socket)]: + sock.close() + + +class TestAdjustments(unittest.TestCase): + def _hasIPv6(self): # pragma: nocover + if not socket.has_ipv6: + return False + + try: + socket.getaddrinfo( + "::1", + 0, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG, + ) + + return True + except socket.gaierror as e: + # Check to see what the error is + if e.errno == socket.EAI_ADDRFAMILY: + return False + else: + raise e + + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_goodvars(self): + inst = self._makeOne( + host="localhost", + port="8080", + threads="5", + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"forwarded"}, + trusted_proxy_count=2, + log_untrusted_proxy_headers=True, + url_scheme="https", + backlog="20", + recv_bytes="200", + send_bytes="300", + outbuf_overflow="400", + inbuf_overflow="500", + connection_limit="1000", + cleanup_interval="1100", + channel_timeout="1200", + log_socket_errors="true", + max_request_header_size="1300", + max_request_body_size="1400", + expose_tracebacks="true", + ident="abc", + asyncore_loop_timeout="5", + asyncore_use_poll=True, + unix_socket_perms="777", + url_prefix="///foo/", + ipv4=True, + ipv6=False, + ) + + self.assertEqual(inst.host, "localhost") + self.assertEqual(inst.port, 8080) + self.assertEqual(inst.threads, 5) + self.assertEqual(inst.trusted_proxy, "192.168.1.1") + self.assertEqual(inst.trusted_proxy_headers, {"forwarded"}) + self.assertEqual(inst.trusted_proxy_count, 2) + self.assertEqual(inst.log_untrusted_proxy_headers, True) + self.assertEqual(inst.url_scheme, "https") + self.assertEqual(inst.backlog, 20) + self.assertEqual(inst.recv_bytes, 200) + self.assertEqual(inst.send_bytes, 300) + self.assertEqual(inst.outbuf_overflow, 400) + self.assertEqual(inst.inbuf_overflow, 500) + self.assertEqual(inst.connection_limit, 1000) + self.assertEqual(inst.cleanup_interval, 1100) + self.assertEqual(inst.channel_timeout, 1200) + self.assertEqual(inst.log_socket_errors, True) + self.assertEqual(inst.max_request_header_size, 1300) + self.assertEqual(inst.max_request_body_size, 1400) + self.assertEqual(inst.expose_tracebacks, True) + self.assertEqual(inst.asyncore_loop_timeout, 5) + self.assertEqual(inst.asyncore_use_poll, True) + self.assertEqual(inst.ident, "abc") + self.assertEqual(inst.unix_socket_perms, 0o777) + self.assertEqual(inst.url_prefix, "/foo") + self.assertEqual(inst.ipv4, True) + self.assertEqual(inst.ipv6, False) + + bind_pairs = [ + sockaddr[:2] + for (family, _, _, sockaddr) in inst.listen + if family == socket.AF_INET + ] + + # On Travis, somehow we start listening to two sockets when resolving + # localhost... + self.assertEqual(("127.0.0.1", 8080), bind_pairs[0]) + + def test_goodvar_listen(self): + inst = self._makeOne(listen="127.0.0.1") + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 8080)]) + + def test_default_listen(self): + inst = self._makeOne() + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [("0.0.0.0", 8080)]) + + def test_multiple_listen(self): + inst = self._makeOne(listen="127.0.0.1:9090 127.0.0.1:8080") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 9090), ("127.0.0.1", 8080)]) + + def test_wildcard_listen(self): + inst = self._makeOne(listen="*:8080") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertTrue(len(bind_pairs) >= 1) + + def test_ipv6_no_port(self): # pragma: nocover + if not self._hasIPv6(): + return + + inst = self._makeOne(listen="[::1]") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("::1", 8080)]) + + def test_bad_port(self): + self.assertRaises(ValueError, self._makeOne, listen="127.0.0.1:test") + + def test_service_port(self): + if WIN and PY2: # pragma: no cover + # On Windows and Python 2 this is broken, so we raise a ValueError + self.assertRaises( + ValueError, self._makeOne, listen="127.0.0.1:http", + ) + return + + inst = self._makeOne(listen="127.0.0.1:http 0.0.0.0:https") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 80), ("0.0.0.0", 443)]) + + def test_dont_mix_host_port_listen(self): + self.assertRaises( + ValueError, + self._makeOne, + host="localhost", + port="8080", + listen="127.0.0.1:8080", + ) + + def test_good_sockets(self): + sockets = [ + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ] + inst = self._makeOne(sockets=sockets) + self.assertEqual(inst.sockets, sockets) + sockets[0].close() + sockets[1].close() + + def test_dont_mix_sockets_and_listen(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, listen="127.0.0.1:8080", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_sockets_and_host_port(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, host="localhost", port="8080", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_sockets_and_unix_socket(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, unix_socket="./tmp/test", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_unix_socket_and_host_port(self): + self.assertRaises( + ValueError, + self._makeOne, + unix_socket="./tmp/test", + host="localhost", + port="8080", + ) + + def test_dont_mix_unix_socket_and_listen(self): + self.assertRaises( + ValueError, self._makeOne, unix_socket="./tmp/test", listen="127.0.0.1:8080" + ) + + def test_dont_use_unsupported_socket_types(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + + def test_dont_mix_forwarded_with_x_forwarded(self): + with self.assertRaises(ValueError) as cm: + self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers={"forwarded", "x-forwarded-for"}, + ) + + self.assertIn("The Forwarded proxy header", str(cm.exception)) + + def test_unknown_trusted_proxy_header(self): + with self.assertRaises(ValueError) as cm: + self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers={"forwarded", "x-forwarded-unknown"}, + ) + + self.assertIn( + "unknown trusted_proxy_headers value (x-forwarded-unknown)", + str(cm.exception), + ) + + def test_trusted_proxy_count_no_trusted_proxy(self): + with self.assertRaises(ValueError) as cm: + self._makeOne(trusted_proxy_count=1) + + self.assertIn("trusted_proxy_count has no meaning", str(cm.exception)) + + def test_trusted_proxy_headers_no_trusted_proxy(self): + with self.assertRaises(ValueError) as cm: + self._makeOne(trusted_proxy_headers={"forwarded"}) + + self.assertIn("trusted_proxy_headers has no meaning", str(cm.exception)) + + def test_trusted_proxy_headers_string_list(self): + inst = self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers="x-forwarded-for x-forwarded-by", + ) + self.assertEqual( + inst.trusted_proxy_headers, {"x-forwarded-for", "x-forwarded-by"} + ) + + def test_trusted_proxy_headers_string_list_newlines(self): + inst = self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers="x-forwarded-for\nx-forwarded-by\nx-forwarded-host", + ) + self.assertEqual( + inst.trusted_proxy_headers, + {"x-forwarded-for", "x-forwarded-by", "x-forwarded-host"}, + ) + + def test_no_trusted_proxy_headers_trusted_proxy(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne(trusted_proxy="localhost") + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("Implicitly trusting X-Forwarded-Proto", str(w[0])) + + def test_clear_untrusted_proxy_headers(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne( + trusted_proxy="localhost", trusted_proxy_headers={"x-forwarded-for"} + ) + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn( + "clear_untrusted_proxy_headers will be set to True", str(w[0]) + ) + + def test_deprecated_send_bytes(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne(send_bytes=1) + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("send_bytes", str(w[0])) + + def test_badvar(self): + self.assertRaises(ValueError, self._makeOne, nope=True) + + def test_ipv4_disabled(self): + self.assertRaises( + ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080" + ) + + def test_ipv6_disabled(self): + self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080") + + def test_server_header_removable(self): + inst = self._makeOne(ident=None) + self.assertEqual(inst.ident, None) + + inst = self._makeOne(ident="") + self.assertEqual(inst.ident, None) + + inst = self._makeOne(ident="specific_header") + self.assertEqual(inst.ident, "specific_header") + + +class TestCLI(unittest.TestCase): + def parse(self, argv): + from waitress.adjustments import Adjustments + + return Adjustments.parse_args(argv) + + def test_noargs(self): + opts, args = self.parse([]) + self.assertDictEqual(opts, {"call": False, "help": False}) + self.assertSequenceEqual(args, []) + + def test_help(self): + opts, args = self.parse(["--help"]) + self.assertDictEqual(opts, {"call": False, "help": True}) + self.assertSequenceEqual(args, []) + + def test_call(self): + opts, args = self.parse(["--call"]) + self.assertDictEqual(opts, {"call": True, "help": False}) + self.assertSequenceEqual(args, []) + + def test_both(self): + opts, args = self.parse(["--call", "--help"]) + self.assertDictEqual(opts, {"call": True, "help": True}) + self.assertSequenceEqual(args, []) + + def test_positive_boolean(self): + opts, args = self.parse(["--expose-tracebacks"]) + self.assertDictContainsSubset({"expose_tracebacks": "true"}, opts) + self.assertSequenceEqual(args, []) + + def test_negative_boolean(self): + opts, args = self.parse(["--no-expose-tracebacks"]) + self.assertDictContainsSubset({"expose_tracebacks": "false"}, opts) + self.assertSequenceEqual(args, []) + + def test_cast_params(self): + opts, args = self.parse( + ["--host=localhost", "--port=80", "--unix-socket-perms=777"] + ) + self.assertDictContainsSubset( + {"host": "localhost", "port": "80", "unix_socket_perms": "777",}, opts + ) + self.assertSequenceEqual(args, []) + + def test_listen_params(self): + opts, args = self.parse(["--listen=test:80",]) + + self.assertDictContainsSubset({"listen": " test:80"}, opts) + self.assertSequenceEqual(args, []) + + def test_multiple_listen_params(self): + opts, args = self.parse(["--listen=test:80", "--listen=test:8080",]) + + self.assertDictContainsSubset({"listen": " test:80 test:8080"}, opts) + self.assertSequenceEqual(args, []) + + def test_bad_param(self): + import getopt + + self.assertRaises(getopt.GetoptError, self.parse, ["--no-host"]) + + +if hasattr(socket, "AF_UNIX"): + + class TestUnixSocket(unittest.TestCase): + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_dont_mix_internet_and_unix_sockets(self): + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + ] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + sockets[1].close() diff --git a/libs/waitress/tests/test_buffers.py b/libs/waitress/tests/test_buffers.py new file mode 100644 index 000000000..a1330ac1b --- /dev/null +++ b/libs/waitress/tests/test_buffers.py @@ -0,0 +1,523 @@ +import unittest +import io + + +class TestFileBasedBuffer(unittest.TestCase): + def _makeOne(self, file=None, from_buffer=None): + from waitress.buffers import FileBasedBuffer + + buf = FileBasedBuffer(file, from_buffer=from_buffer) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test_ctor_from_buffer_None(self): + inst = self._makeOne("file") + self.assertEqual(inst.file, "file") + + def test_ctor_from_buffer(self): + from_buffer = io.BytesIO(b"data") + from_buffer.getfile = lambda *x: from_buffer + f = io.BytesIO() + inst = self._makeOne(f, from_buffer) + self.assertEqual(inst.file, f) + del from_buffer.getfile + self.assertEqual(inst.remain, 4) + from_buffer.close() + + def test___len__(self): + inst = self._makeOne() + inst.remain = 10 + self.assertEqual(len(inst), 10) + + def test___nonzero__(self): + inst = self._makeOne() + inst.remain = 10 + self.assertEqual(bool(inst), True) + inst.remain = 0 + self.assertEqual(bool(inst), True) + + def test_append(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + inst.append(b"data2") + self.assertEqual(f.getvalue(), b"datadata2") + self.assertEqual(inst.remain, 5) + + def test_get_skip_true(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + result = inst.get(100, skip=True) + self.assertEqual(result, b"data") + self.assertEqual(inst.remain, -4) + + def test_get_skip_false(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + result = inst.get(100, skip=False) + self.assertEqual(result, b"data") + self.assertEqual(inst.remain, 0) + + def test_get_skip_bytes_less_than_zero(self): + f = io.BytesIO(b"data") + inst = self._makeOne(f) + result = inst.get(-1, skip=False) + self.assertEqual(result, b"data") + self.assertEqual(inst.remain, 0) + + def test_skip_remain_gt_bytes(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + inst.remain = 1 + inst.skip(1) + self.assertEqual(inst.remain, 0) + + def test_skip_remain_lt_bytes(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + inst.remain = 1 + self.assertRaises(ValueError, inst.skip, 2) + + def test_newfile(self): + inst = self._makeOne() + self.assertRaises(NotImplementedError, inst.newfile) + + def test_prune_remain_notzero(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + inst.remain = 1 + nf = io.BytesIO() + inst.newfile = lambda *x: nf + inst.prune() + self.assertTrue(inst.file is not f) + self.assertEqual(nf.getvalue(), b"d") + + def test_prune_remain_zero_tell_notzero(self): + f = io.BytesIO(b"d") + inst = self._makeOne(f) + nf = io.BytesIO(b"d") + inst.newfile = lambda *x: nf + inst.remain = 0 + inst.prune() + self.assertTrue(inst.file is not f) + self.assertEqual(nf.getvalue(), b"d") + + def test_prune_remain_zero_tell_zero(self): + f = io.BytesIO() + inst = self._makeOne(f) + inst.remain = 0 + inst.prune() + self.assertTrue(inst.file is f) + + def test_close(self): + f = io.BytesIO() + inst = self._makeOne(f) + inst.close() + self.assertTrue(f.closed) + self.buffers_to_close.remove(inst) + + +class TestTempfileBasedBuffer(unittest.TestCase): + def _makeOne(self, from_buffer=None): + from waitress.buffers import TempfileBasedBuffer + + buf = TempfileBasedBuffer(from_buffer=from_buffer) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test_newfile(self): + inst = self._makeOne() + r = inst.newfile() + self.assertTrue(hasattr(r, "fileno")) # file + r.close() + + +class TestBytesIOBasedBuffer(unittest.TestCase): + def _makeOne(self, from_buffer=None): + from waitress.buffers import BytesIOBasedBuffer + + return BytesIOBasedBuffer(from_buffer=from_buffer) + + def test_ctor_from_buffer_not_None(self): + f = io.BytesIO() + f.getfile = lambda *x: f + inst = self._makeOne(f) + self.assertTrue(hasattr(inst.file, "read")) + + def test_ctor_from_buffer_None(self): + inst = self._makeOne() + self.assertTrue(hasattr(inst.file, "read")) + + def test_newfile(self): + inst = self._makeOne() + r = inst.newfile() + self.assertTrue(hasattr(r, "read")) + + +class TestReadOnlyFileBasedBuffer(unittest.TestCase): + def _makeOne(self, file, block_size=8192): + from waitress.buffers import ReadOnlyFileBasedBuffer + + buf = ReadOnlyFileBasedBuffer(file, block_size) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test_prepare_not_seekable(self): + f = KindaFilelike(b"abc") + inst = self._makeOne(f) + result = inst.prepare() + self.assertEqual(result, False) + self.assertEqual(inst.remain, 0) + + def test_prepare_not_seekable_closeable(self): + f = KindaFilelike(b"abc", close=1) + inst = self._makeOne(f) + result = inst.prepare() + self.assertEqual(result, False) + self.assertEqual(inst.remain, 0) + self.assertTrue(hasattr(inst, "close")) + + def test_prepare_seekable_closeable(self): + f = Filelike(b"abc", close=1, tellresults=[0, 10]) + inst = self._makeOne(f) + result = inst.prepare() + self.assertEqual(result, 10) + self.assertEqual(inst.remain, 10) + self.assertEqual(inst.file.seeked, 0) + self.assertTrue(hasattr(inst, "close")) + + def test_get_numbytes_neg_one(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(-1) + self.assertEqual(result, b"ab") + self.assertEqual(inst.remain, 2) + self.assertEqual(f.tell(), 0) + + def test_get_numbytes_gt_remain(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(3) + self.assertEqual(result, b"ab") + self.assertEqual(inst.remain, 2) + self.assertEqual(f.tell(), 0) + + def test_get_numbytes_lt_remain(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(1) + self.assertEqual(result, b"a") + self.assertEqual(inst.remain, 2) + self.assertEqual(f.tell(), 0) + + def test_get_numbytes_gt_remain_withskip(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(3, skip=True) + self.assertEqual(result, b"ab") + self.assertEqual(inst.remain, 0) + self.assertEqual(f.tell(), 2) + + def test_get_numbytes_lt_remain_withskip(self): + f = io.BytesIO(b"abcdef") + inst = self._makeOne(f) + inst.remain = 2 + result = inst.get(1, skip=True) + self.assertEqual(result, b"a") + self.assertEqual(inst.remain, 1) + self.assertEqual(f.tell(), 1) + + def test___iter__(self): + data = b"a" * 10000 + f = io.BytesIO(data) + inst = self._makeOne(f) + r = b"" + for val in inst: + r += val + self.assertEqual(r, data) + + def test_append(self): + inst = self._makeOne(None) + self.assertRaises(NotImplementedError, inst.append, "a") + + +class TestOverflowableBuffer(unittest.TestCase): + def _makeOne(self, overflow=10): + from waitress.buffers import OverflowableBuffer + + buf = OverflowableBuffer(overflow) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() + + def test___len__buf_is_None(self): + inst = self._makeOne() + self.assertEqual(len(inst), 0) + + def test___len__buf_is_not_None(self): + inst = self._makeOne() + inst.buf = b"abc" + self.assertEqual(len(inst), 3) + self.buffers_to_close.remove(inst) + + def test___nonzero__(self): + inst = self._makeOne() + inst.buf = b"abc" + self.assertEqual(bool(inst), True) + inst.buf = b"" + self.assertEqual(bool(inst), False) + self.buffers_to_close.remove(inst) + + def test___nonzero___on_int_overflow_buffer(self): + inst = self._makeOne() + + class int_overflow_buf(bytes): + def __len__(self): + # maxint + 1 + return 0x7FFFFFFFFFFFFFFF + 1 + + inst.buf = int_overflow_buf() + self.assertEqual(bool(inst), True) + inst.buf = b"" + self.assertEqual(bool(inst), False) + self.buffers_to_close.remove(inst) + + def test__create_buffer_large(self): + from waitress.buffers import TempfileBasedBuffer + + inst = self._makeOne() + inst.strbuf = b"x" * 11 + inst._create_buffer() + self.assertEqual(inst.buf.__class__, TempfileBasedBuffer) + self.assertEqual(inst.buf.get(100), b"x" * 11) + self.assertEqual(inst.strbuf, b"") + + def test__create_buffer_small(self): + from waitress.buffers import BytesIOBasedBuffer + + inst = self._makeOne() + inst.strbuf = b"x" * 5 + inst._create_buffer() + self.assertEqual(inst.buf.__class__, BytesIOBasedBuffer) + self.assertEqual(inst.buf.get(100), b"x" * 5) + self.assertEqual(inst.strbuf, b"") + + def test_append_with_len_more_than_max_int(self): + from waitress.compat import MAXINT + + inst = self._makeOne() + inst.overflowed = True + buf = DummyBuffer(length=MAXINT) + inst.buf = buf + result = inst.append(b"x") + # we don't want this to throw an OverflowError on Python 2 (see + # https://github.com/Pylons/waitress/issues/47) + self.assertEqual(result, None) + self.buffers_to_close.remove(inst) + + def test_append_buf_None_not_longer_than_srtbuf_limit(self): + inst = self._makeOne() + inst.strbuf = b"x" * 5 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"xxxxxhello") + + def test_append_buf_None_longer_than_strbuf_limit(self): + inst = self._makeOne(10000) + inst.strbuf = b"x" * 8192 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"") + self.assertEqual(len(inst.buf), 8197) + + def test_append_overflow(self): + inst = self._makeOne(10) + inst.strbuf = b"x" * 8192 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"") + self.assertEqual(len(inst.buf), 8197) + + def test_append_sz_gt_overflow(self): + from waitress.buffers import BytesIOBasedBuffer + + f = io.BytesIO(b"data") + inst = self._makeOne(f) + buf = BytesIOBasedBuffer() + inst.buf = buf + inst.overflow = 2 + inst.append(b"data2") + self.assertEqual(f.getvalue(), b"data") + self.assertTrue(inst.overflowed) + self.assertNotEqual(inst.buf, buf) + + def test_get_buf_None_skip_False(self): + inst = self._makeOne() + inst.strbuf = b"x" * 5 + r = inst.get(5) + self.assertEqual(r, b"xxxxx") + + def test_get_buf_None_skip_True(self): + inst = self._makeOne() + inst.strbuf = b"x" * 5 + r = inst.get(5, skip=True) + self.assertFalse(inst.buf is None) + self.assertEqual(r, b"xxxxx") + + def test_skip_buf_None(self): + inst = self._makeOne() + inst.strbuf = b"data" + inst.skip(4) + self.assertEqual(inst.strbuf, b"") + self.assertNotEqual(inst.buf, None) + + def test_skip_buf_None_allow_prune_True(self): + inst = self._makeOne() + inst.strbuf = b"data" + inst.skip(4, True) + self.assertEqual(inst.strbuf, b"") + self.assertEqual(inst.buf, None) + + def test_prune_buf_None(self): + inst = self._makeOne() + inst.prune() + self.assertEqual(inst.strbuf, b"") + + def test_prune_with_buf(self): + inst = self._makeOne() + + class Buf(object): + def prune(self): + self.pruned = True + + inst.buf = Buf() + inst.prune() + self.assertEqual(inst.buf.pruned, True) + self.buffers_to_close.remove(inst) + + def test_prune_with_buf_overflow(self): + inst = self._makeOne() + + class DummyBuffer(io.BytesIO): + def getfile(self): + return self + + def prune(self): + return True + + def __len__(self): + return 5 + + def close(self): + pass + + buf = DummyBuffer(b"data") + inst.buf = buf + inst.overflowed = True + inst.overflow = 10 + inst.prune() + self.assertNotEqual(inst.buf, buf) + + def test_prune_with_buflen_more_than_max_int(self): + from waitress.compat import MAXINT + + inst = self._makeOne() + inst.overflowed = True + buf = DummyBuffer(length=MAXINT + 1) + inst.buf = buf + result = inst.prune() + # we don't want this to throw an OverflowError on Python 2 (see + # https://github.com/Pylons/waitress/issues/47) + self.assertEqual(result, None) + + def test_getfile_buf_None(self): + inst = self._makeOne() + f = inst.getfile() + self.assertTrue(hasattr(f, "read")) + + def test_getfile_buf_not_None(self): + inst = self._makeOne() + buf = io.BytesIO() + buf.getfile = lambda *x: buf + inst.buf = buf + f = inst.getfile() + self.assertEqual(f, buf) + + def test_close_nobuf(self): + inst = self._makeOne() + inst.buf = None + self.assertEqual(inst.close(), None) # doesnt raise + self.buffers_to_close.remove(inst) + + def test_close_withbuf(self): + class Buffer(object): + def close(self): + self.closed = True + + buf = Buffer() + inst = self._makeOne() + inst.buf = buf + inst.close() + self.assertTrue(buf.closed) + self.buffers_to_close.remove(inst) + + +class KindaFilelike(object): + def __init__(self, bytes, close=None, tellresults=None): + self.bytes = bytes + self.tellresults = tellresults + if close is not None: + self.close = lambda: close + + +class Filelike(KindaFilelike): + def seek(self, v, whence=0): + self.seeked = v + + def tell(self): + v = self.tellresults.pop(0) + return v + + +class DummyBuffer(object): + def __init__(self, length=0): + self.length = length + + def __len__(self): + return self.length + + def append(self, s): + self.length = self.length + len(s) + + def prune(self): + pass + + def close(self): + pass diff --git a/libs/waitress/tests/test_channel.py b/libs/waitress/tests/test_channel.py new file mode 100644 index 000000000..14ef5a0ec --- /dev/null +++ b/libs/waitress/tests/test_channel.py @@ -0,0 +1,882 @@ +import unittest +import io + + +class TestHTTPChannel(unittest.TestCase): + def _makeOne(self, sock, addr, adj, map=None): + from waitress.channel import HTTPChannel + + server = DummyServer() + return HTTPChannel(server, sock, addr, adj=adj, map=map) + + def _makeOneWithMap(self, adj=None): + if adj is None: + adj = DummyAdjustments() + sock = DummySock() + map = {} + inst = self._makeOne(sock, "127.0.0.1", adj, map=map) + inst.outbuf_lock = DummyLock() + return inst, sock, map + + def test_ctor(self): + inst, _, map = self._makeOneWithMap() + self.assertEqual(inst.addr, "127.0.0.1") + self.assertEqual(inst.sendbuf_len, 2048) + self.assertEqual(map[100], inst) + + def test_total_outbufs_len_an_outbuf_size_gt_sys_maxint(self): + from waitress.compat import MAXINT + + inst, _, map = self._makeOneWithMap() + + class DummyBuffer(object): + chunks = [] + + def append(self, data): + self.chunks.append(data) + + class DummyData(object): + def __len__(self): + return MAXINT + + inst.total_outbufs_len = 1 + inst.outbufs = [DummyBuffer()] + inst.write_soon(DummyData()) + # we are testing that this method does not raise an OverflowError + # (see https://github.com/Pylons/waitress/issues/47) + self.assertEqual(inst.total_outbufs_len, MAXINT + 1) + + def test_writable_something_in_outbuf(self): + inst, sock, map = self._makeOneWithMap() + inst.total_outbufs_len = 3 + self.assertTrue(inst.writable()) + + def test_writable_nothing_in_outbuf(self): + inst, sock, map = self._makeOneWithMap() + self.assertFalse(inst.writable()) + + def test_writable_nothing_in_outbuf_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.will_close = True + self.assertTrue(inst.writable()) + + def test_handle_write_not_connected(self): + inst, sock, map = self._makeOneWithMap() + inst.connected = False + self.assertFalse(inst.handle_write()) + + def test_handle_write_with_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = True + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + + def test_handle_write_no_request_with_outbuf(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertNotEqual(inst.last_activity, 0) + self.assertEqual(sock.sent, b"abc") + + def test_handle_write_outbuf_raises_socketerror(self): + import socket + + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + outbuf = DummyBuffer(b"abc", socket.error) + inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) + inst.last_activity = 0 + inst.logger = DummyLogger() + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + self.assertEqual(sock.sent, b"") + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(outbuf.closed) + + def test_handle_write_outbuf_raises_othererror(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + outbuf = DummyBuffer(b"abc", IOError) + inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) + inst.last_activity = 0 + inst.logger = DummyLogger() + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + self.assertEqual(sock.sent, b"") + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(outbuf.closed) + + def test_handle_write_no_requests_no_outbuf_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + outbuf = DummyBuffer(b"") + inst.outbufs = [outbuf] + inst.will_close = True + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.connected, False) + self.assertEqual(sock.closed, True) + self.assertEqual(inst.last_activity, 0) + self.assertTrue(outbuf.closed) + + def test_handle_write_no_requests_outbuf_gt_send_bytes(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 2 + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertEqual(sock.sent, b"abc") + + def test_handle_write_close_when_flushed(self): + inst, sock, map = self._makeOneWithMap() + outbuf = DummyBuffer(b"abc") + inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) + inst.will_close = False + inst.close_when_flushed = True + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, True) + self.assertEqual(inst.close_when_flushed, False) + self.assertEqual(sock.sent, b"abc") + self.assertTrue(outbuf.closed) + + def test_readable_no_requests_not_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.will_close = False + self.assertEqual(inst.readable(), True) + + def test_readable_no_requests_will_close(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.will_close = True + self.assertEqual(inst.readable(), False) + + def test_readable_with_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = True + self.assertEqual(inst.readable(), False) + + def test_handle_read_no_error(self): + inst, sock, map = self._makeOneWithMap() + inst.will_close = False + inst.recv = lambda *arg: b"abc" + inst.last_activity = 0 + L = [] + inst.received = lambda x: L.append(x) + result = inst.handle_read() + self.assertEqual(result, None) + self.assertNotEqual(inst.last_activity, 0) + self.assertEqual(L, [b"abc"]) + + def test_handle_read_error(self): + import socket + + inst, sock, map = self._makeOneWithMap() + inst.will_close = False + + def recv(b): + raise socket.error + + inst.recv = recv + inst.last_activity = 0 + inst.logger = DummyLogger() + result = inst.handle_read() + self.assertEqual(result, None) + self.assertEqual(inst.last_activity, 0) + self.assertEqual(len(inst.logger.exceptions), 1) + + def test_write_soon_empty_byte(self): + inst, sock, map = self._makeOneWithMap() + wrote = inst.write_soon(b"") + self.assertEqual(wrote, 0) + self.assertEqual(len(inst.outbufs[0]), 0) + + def test_write_soon_nonempty_byte(self): + inst, sock, map = self._makeOneWithMap() + wrote = inst.write_soon(b"a") + self.assertEqual(wrote, 1) + self.assertEqual(len(inst.outbufs[0]), 1) + + def test_write_soon_filewrapper(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + wrapper = ReadOnlyFileBasedBuffer(f, 8192) + wrapper.prepare() + inst, sock, map = self._makeOneWithMap() + outbufs = inst.outbufs + orig_outbuf = outbufs[0] + wrote = inst.write_soon(wrapper) + self.assertEqual(wrote, 3) + self.assertEqual(len(outbufs), 3) + self.assertEqual(outbufs[0], orig_outbuf) + self.assertEqual(outbufs[1], wrapper) + self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer") + + def test_write_soon_disconnected(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.connected = False + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) + + def test_write_soon_disconnected_while_over_watermark(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + + def dummy_flush(): + inst.connected = False + + inst._flush_outbufs_below_high_watermark = dummy_flush + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) + + def test_write_soon_rotates_outbuf_on_overflow(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.outbuf_high_watermark = 3 + inst.current_outbuf_count = 4 + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + self.assertEqual(len(inst.outbufs), 2) + self.assertEqual(inst.outbufs[0].get(), b"") + self.assertEqual(inst.outbufs[1].get(), b"xyz") + + def test_write_soon_waits_on_backpressure(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.outbuf_high_watermark = 3 + inst.total_outbufs_len = 4 + inst.current_outbuf_count = 4 + + class Lock(DummyLock): + def wait(self): + inst.total_outbufs_len = 0 + super(Lock, self).wait() + + inst.outbuf_lock = Lock() + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + self.assertEqual(len(inst.outbufs), 2) + self.assertEqual(inst.outbufs[0].get(), b"") + self.assertEqual(inst.outbufs[1].get(), b"xyz") + self.assertTrue(inst.outbuf_lock.waited) + + def test_handle_write_notify_after_flush(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 1 + inst.adj.outbuf_high_watermark = 5 + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertTrue(inst.outbuf_lock.notified) + self.assertEqual(sock.sent, b"abc") + + def test_handle_write_no_notify_after_flush(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 1 + inst.adj.outbuf_high_watermark = 2 + sock.send = lambda x: False + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertFalse(inst.outbuf_lock.notified) + self.assertEqual(sock.sent, b"") + + def test__flush_some_empty_outbuf(self): + inst, sock, map = self._makeOneWithMap() + result = inst._flush_some() + self.assertEqual(result, False) + + def test__flush_some_full_outbuf_socket_returns_nonzero(self): + inst, sock, map = self._makeOneWithMap() + inst.outbufs[0].append(b"abc") + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + result = inst._flush_some() + self.assertEqual(result, True) + + def test__flush_some_full_outbuf_socket_returns_zero(self): + inst, sock, map = self._makeOneWithMap() + sock.send = lambda x: False + inst.outbufs[0].append(b"abc") + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + result = inst._flush_some() + self.assertEqual(result, False) + + def test_flush_some_multiple_buffers_first_empty(self): + inst, sock, map = self._makeOneWithMap() + sock.send = lambda x: len(x) + buffer = DummyBuffer(b"abc") + inst.outbufs.append(buffer) + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + result = inst._flush_some() + self.assertEqual(result, True) + self.assertEqual(buffer.skipped, 3) + self.assertEqual(inst.outbufs, [buffer]) + + def test_flush_some_multiple_buffers_close_raises(self): + inst, sock, map = self._makeOneWithMap() + sock.send = lambda x: len(x) + buffer = DummyBuffer(b"abc") + inst.outbufs.append(buffer) + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) + inst.logger = DummyLogger() + + def doraise(): + raise NotImplementedError + + inst.outbufs[0].close = doraise + result = inst._flush_some() + self.assertEqual(result, True) + self.assertEqual(buffer.skipped, 3) + self.assertEqual(inst.outbufs, [buffer]) + self.assertEqual(len(inst.logger.exceptions), 1) + + def test__flush_some_outbuf_len_gt_sys_maxint(self): + from waitress.compat import MAXINT + + inst, sock, map = self._makeOneWithMap() + + class DummyHugeOutbuffer(object): + def __init__(self): + self.length = MAXINT + 1 + + def __len__(self): + return self.length + + def get(self, numbytes): + self.length = 0 + return b"123" + + buf = DummyHugeOutbuffer() + inst.outbufs = [buf] + inst.send = lambda *arg: 0 + result = inst._flush_some() + # we are testing that _flush_some doesn't raise an OverflowError + # when one of its outbufs has a __len__ that returns gt sys.maxint + self.assertEqual(result, False) + + def test_handle_close(self): + inst, sock, map = self._makeOneWithMap() + inst.handle_close() + self.assertEqual(inst.connected, False) + self.assertEqual(sock.closed, True) + + def test_handle_close_outbuf_raises_on_close(self): + inst, sock, map = self._makeOneWithMap() + + def doraise(): + raise NotImplementedError + + inst.outbufs[0].close = doraise + inst.logger = DummyLogger() + inst.handle_close() + self.assertEqual(inst.connected, False) + self.assertEqual(sock.closed, True) + self.assertEqual(len(inst.logger.exceptions), 1) + + def test_add_channel(self): + inst, sock, map = self._makeOneWithMap() + fileno = inst._fileno + inst.add_channel(map) + self.assertEqual(map[fileno], inst) + self.assertEqual(inst.server.active_channels[fileno], inst) + + def test_del_channel(self): + inst, sock, map = self._makeOneWithMap() + fileno = inst._fileno + inst.server.active_channels[fileno] = True + inst.del_channel(map) + self.assertEqual(map.get(fileno), None) + self.assertEqual(inst.server.active_channels.get(fileno), None) + + def test_received(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.server.tasks, [inst]) + self.assertTrue(inst.requests) + + def test_received_no_chunk(self): + inst, sock, map = self._makeOneWithMap() + self.assertEqual(inst.received(b""), False) + + def test_received_preq_not_completed(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = False + preq.empty = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.requests, ()) + self.assertEqual(inst.server.tasks, []) + + def test_received_preq_completed_empty(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = True + preq.empty = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, None) + self.assertEqual(inst.server.tasks, []) + + def test_received_preq_error(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = True + preq.error = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, None) + self.assertEqual(len(inst.server.tasks), 1) + self.assertTrue(inst.requests) + + def test_received_preq_completed_connection_close(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.completed = True + preq.empty = True + preq.connection_close = True + inst.received(b"GET / HTTP/1.1\r\n\r\n" + b"a" * 50000) + self.assertEqual(inst.request, None) + self.assertEqual(inst.server.tasks, []) + + def test_received_headers_finished_expect_continue_false(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.expect_continue = False + preq.headers_finished = True + preq.completed = False + preq.empty = False + preq.retval = 1 + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, preq) + self.assertEqual(inst.server.tasks, []) + self.assertEqual(inst.outbufs[0].get(100), b"") + + def test_received_headers_finished_expect_continue_true(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.expect_continue = True + preq.headers_finished = True + preq.completed = False + preq.empty = False + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, preq) + self.assertEqual(inst.server.tasks, []) + self.assertEqual(sock.sent, b"HTTP/1.1 100 Continue\r\n\r\n") + self.assertEqual(inst.sent_continue, True) + self.assertEqual(preq.completed, False) + + def test_received_headers_finished_expect_continue_true_sent_true(self): + inst, sock, map = self._makeOneWithMap() + inst.server = DummyServer() + preq = DummyParser() + inst.request = preq + preq.expect_continue = True + preq.headers_finished = True + preq.completed = False + preq.empty = False + inst.sent_continue = True + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.request, preq) + self.assertEqual(inst.server.tasks, []) + self.assertEqual(sock.sent, b"") + self.assertEqual(inst.sent_continue, True) + self.assertEqual(preq.completed, False) + + def test_service_no_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + + def test_service_with_one_request(self): + inst, sock, map = self._makeOneWithMap() + request = DummyRequest() + inst.task_class = DummyTaskClass() + inst.requests = [request] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request.serviced) + self.assertTrue(request.closed) + + def test_service_with_one_error_request(self): + inst, sock, map = self._makeOneWithMap() + request = DummyRequest() + request.error = DummyError() + inst.error_task_class = DummyTaskClass() + inst.requests = [request] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request.serviced) + self.assertTrue(request.closed) + + def test_service_with_multiple_requests(self): + inst, sock, map = self._makeOneWithMap() + request1 = DummyRequest() + request2 = DummyRequest() + inst.task_class = DummyTaskClass() + inst.requests = [request1, request2] + inst.service() + self.assertEqual(inst.requests, []) + self.assertTrue(request1.serviced) + self.assertTrue(request2.serviced) + self.assertTrue(request1.closed) + self.assertTrue(request2.closed) + + def test_service_with_request_raises(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, True) + self.assertTrue(request.closed) + + def test_service_with_requests_raises_already_wrote_header(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertTrue(inst.close_when_flushed) + self.assertEqual(inst.error_task_class.serviced, False) + self.assertTrue(request.closed) + + def test_service_with_requests_raises_didnt_write_header_expose_tbs(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = True + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertFalse(inst.will_close) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertEqual(inst.error_task_class.serviced, True) + self.assertTrue(request.closed) + + def test_service_with_requests_raises_didnt_write_header(self): + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ValueError) + inst.task_class.wrote_header = False + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertTrue(inst.close_when_flushed) + self.assertTrue(request.closed) + + def test_service_with_request_raises_disconnect(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ClientDisconnected) + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.infos), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, False) + self.assertTrue(request.closed) + + def test_service_with_request_error_raises_disconnect(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + err_request = DummyRequest() + inst.requests = [request] + inst.parser_class = lambda x: err_request + inst.task_class = DummyTaskClass(RuntimeError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass(ClientDisconnected) + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertTrue(err_request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertEqual(len(inst.logger.infos), 0) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.task_class.serviced, True) + self.assertEqual(inst.error_task_class.serviced, True) + self.assertTrue(request.closed) + + def test_cancel_no_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = () + inst.cancel() + self.assertEqual(inst.requests, []) + + def test_cancel_with_requests(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [None] + inst.cancel() + self.assertEqual(inst.requests, []) + + +class DummySock(object): + blocking = False + closed = False + + def __init__(self): + self.sent = b"" + + def setblocking(self, *arg): + self.blocking = True + + def fileno(self): + return 100 + + def getpeername(self): + return "127.0.0.1" + + def getsockopt(self, level, option): + return 2048 + + def close(self): + self.closed = True + + def send(self, data): + self.sent += data + return len(data) + + +class DummyLock(object): + notified = False + + def __init__(self, acquirable=True): + self.acquirable = acquirable + + def acquire(self, val): + self.val = val + self.acquired = True + return self.acquirable + + def release(self): + self.released = True + + def notify(self): + self.notified = True + + def wait(self): + self.waited = True + + def __exit__(self, type, val, traceback): + self.acquire(True) + + def __enter__(self): + pass + + +class DummyBuffer(object): + closed = False + + def __init__(self, data, toraise=None): + self.data = data + self.toraise = toraise + + def get(self, *arg): + if self.toraise: + raise self.toraise + data = self.data + self.data = b"" + return data + + def skip(self, num, x): + self.skipped = num + + def __len__(self): + return len(self.data) + + def close(self): + self.closed = True + + +class DummyAdjustments(object): + outbuf_overflow = 1048576 + outbuf_high_watermark = 1048576 + inbuf_overflow = 512000 + cleanup_interval = 900 + url_scheme = "http" + channel_timeout = 300 + log_socket_errors = True + recv_bytes = 8192 + send_bytes = 1 + expose_tracebacks = True + ident = "waitress" + max_request_header_size = 10000 + + +class DummyServer(object): + trigger_pulled = False + adj = DummyAdjustments() + + def __init__(self): + self.tasks = [] + self.active_channels = {} + + def add_task(self, task): + self.tasks.append(task) + + def pull_trigger(self): + self.trigger_pulled = True + + +class DummyParser(object): + version = 1 + data = None + completed = True + empty = False + headers_finished = False + expect_continue = False + retval = None + error = None + connection_close = False + + def received(self, data): + self.data = data + if self.retval is not None: + return self.retval + return len(data) + + +class DummyRequest(object): + error = None + path = "/" + version = "1.0" + closed = False + + def __init__(self): + self.headers = {} + + def close(self): + self.closed = True + + +class DummyLogger(object): + def __init__(self): + self.exceptions = [] + self.infos = [] + self.warnings = [] + + def info(self, msg): + self.infos.append(msg) + + def exception(self, msg): + self.exceptions.append(msg) + + +class DummyError(object): + code = "431" + reason = "Bleh" + body = "My body" + + +class DummyTaskClass(object): + wrote_header = True + close_on_finish = False + serviced = False + + def __init__(self, toraise=None): + self.toraise = toraise + + def __call__(self, channel, request): + self.request = request + return self + + def service(self): + self.serviced = True + self.request.serviced = True + if self.toraise: + raise self.toraise diff --git a/libs/waitress/tests/test_compat.py b/libs/waitress/tests/test_compat.py new file mode 100644 index 000000000..37c219303 --- /dev/null +++ b/libs/waitress/tests/test_compat.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class Test_unquote_bytes_to_wsgi(unittest.TestCase): + def _callFUT(self, v): + from waitress.compat import unquote_bytes_to_wsgi + + return unquote_bytes_to_wsgi(v) + + def test_highorder(self): + from waitress.compat import PY3 + + val = b"/a%C5%9B" + result = self._callFUT(val) + if PY3: # pragma: no cover + # PEP 3333 urlunquoted-latin1-decoded-bytes + self.assertEqual(result, "/aÅ\x9b") + else: # pragma: no cover + # sanity + self.assertEqual(result, b"/a\xc5\x9b") diff --git a/libs/waitress/tests/test_functional.py b/libs/waitress/tests/test_functional.py new file mode 100644 index 000000000..8f4b262fe --- /dev/null +++ b/libs/waitress/tests/test_functional.py @@ -0,0 +1,1667 @@ +import errno +import logging +import multiprocessing +import os +import signal +import socket +import string +import subprocess +import sys +import time +import unittest +from waitress import server +from waitress.compat import httplib, tobytes +from waitress.utilities import cleanup_unix_socket + +dn = os.path.dirname +here = dn(__file__) + + +class NullHandler(logging.Handler): # pragma: no cover + """A logging handler that swallows all emitted messages. + """ + + def emit(self, record): + pass + + +def start_server(app, svr, queue, **kwargs): # pragma: no cover + """Run a fixture application. + """ + logging.getLogger("waitress").addHandler(NullHandler()) + try_register_coverage() + svr(app, queue, **kwargs).run() + + +def try_register_coverage(): # pragma: no cover + # Hack around multiprocessing exiting early and not triggering coverage's + # atexit handler by always registering a signal handler + + if "COVERAGE_PROCESS_START" in os.environ: + def sigterm(*args): + sys.exit(0) + + signal.signal(signal.SIGTERM, sigterm) + + +class FixtureTcpWSGIServer(server.TcpWSGIServer): + """A version of TcpWSGIServer that relays back what it's bound to. + """ + + family = socket.AF_INET # Testing + + def __init__(self, application, queue, **kw): # pragma: no cover + # Coverage doesn't see this as it's ran in a separate process. + kw["port"] = 0 # Bind to any available port. + super(FixtureTcpWSGIServer, self).__init__(application, **kw) + host, port = self.socket.getsockname() + if os.name == "nt": + host = "127.0.0.1" + queue.put((host, port)) + + +class SubprocessTests(object): + + # For nose: all tests may be ran in separate processes. + _multiprocess_can_split_ = True + + exe = sys.executable + + server = None + + def start_subprocess(self, target, **kw): + # Spawn a server process. + self.queue = multiprocessing.Queue() + + if "COVERAGE_RCFILE" in os.environ: + os.environ["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"] + + self.proc = multiprocessing.Process( + target=start_server, args=(target, self.server, self.queue), kwargs=kw, + ) + self.proc.start() + + if self.proc.exitcode is not None: # pragma: no cover + raise RuntimeError("%s didn't start" % str(target)) + # Get the socket the server is listening on. + self.bound_to = self.queue.get(timeout=5) + self.sock = self.create_socket() + + def stop_subprocess(self): + if self.proc.exitcode is None: + self.proc.terminate() + self.sock.close() + # This give us one FD back ... + self.queue.close() + self.proc.join() + + def assertline(self, line, status, reason, version): + v, s, r = (x.strip() for x in line.split(None, 2)) + self.assertEqual(s, tobytes(status)) + self.assertEqual(r, tobytes(reason)) + self.assertEqual(v, tobytes(version)) + + def create_socket(self): + return socket.socket(self.server.family, socket.SOCK_STREAM) + + def connect(self): + self.sock.connect(self.bound_to) + + def make_http_connection(self): + raise NotImplementedError # pragma: no cover + + def send_check_error(self, to_send): + self.sock.send(to_send) + + +class TcpTests(SubprocessTests): + + server = FixtureTcpWSGIServer + + def make_http_connection(self): + return httplib.HTTPConnection(*self.bound_to) + + +class SleepyThreadTests(TcpTests, unittest.TestCase): + # test that sleepy thread doesnt block other requests + + def setUp(self): + from waitress.tests.fixtureapps import sleepy + + self.start_subprocess(sleepy.app) + + def tearDown(self): + self.stop_subprocess() + + def test_it(self): + getline = os.path.join(here, "fixtureapps", "getline.py") + cmds = ( + [self.exe, getline, "http://%s:%d/sleepy" % self.bound_to], + [self.exe, getline, "http://%s:%d/" % self.bound_to], + ) + r, w = os.pipe() + procs = [] + for cmd in cmds: + procs.append(subprocess.Popen(cmd, stdout=w)) + time.sleep(3) + for proc in procs: + if proc.returncode is not None: # pragma: no cover + proc.terminate() + proc.wait() + # the notsleepy response should always be first returned (it sleeps + # for 2 seconds, then returns; the notsleepy response should be + # processed in the meantime) + result = os.read(r, 10000) + os.close(r) + os.close(w) + self.assertEqual(result, b"notsleepy returnedsleepy returned") + + +class EchoTests(object): + def setUp(self): + from waitress.tests.fixtureapps import echo + + self.start_subprocess( + echo.app, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-for", "x-forwarded-proto"}, + clear_untrusted_proxy_headers=True, + ) + + def tearDown(self): + self.stop_subprocess() + + def _read_echo(self, fp): + from waitress.tests.fixtureapps import echo + + line, headers, body = read_http(fp) + return line, headers, echo.parse_response(body) + + def test_date_and_server(self): + to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + + def test_bad_host_header(self): + # https://corte.si/posts/code/pathod/pythonservers/index.html + to_send = "GET / HTTP/1.0\r\n Host: 0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "400", "Bad Request", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + + def test_send_with_body(self): + to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" + to_send += "hello" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "5") + self.assertEqual(echo.body, b"hello") + + def test_send_empty_body(self): + to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "0") + self.assertEqual(echo.body, b"") + + def test_multiple_requests_with_body(self): + orig_sock = self.sock + for x in range(3): + self.sock = self.create_socket() + self.test_send_with_body() + self.sock.close() + self.sock = orig_sock + + def test_multiple_requests_without_body(self): + orig_sock = self.sock + for x in range(3): + self.sock = self.create_socket() + self.test_send_empty_body() + self.sock.close() + self.sock = orig_sock + + def test_without_crlf(self): + data = "Echo\r\nthis\r\nplease" + s = tobytes( + "GET / HTTP/1.0\r\n" + "Connection: close\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(int(echo.content_length), len(data)) + self.assertEqual(len(echo.body), len(data)) + self.assertEqual(echo.body, tobytes(data)) + + def test_large_body(self): + # 1024 characters. + body = "This string has 32 characters.\r\n" * 32 + s = tobytes( + "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(body), body) + ) + self.connect() + self.sock.send(s) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "1024") + self.assertEqual(echo.body, tobytes(body)) + + def test_many_clients(self): + conns = [] + for n in range(50): + h = self.make_http_connection() + h.request("GET", "/", headers={"Accept": "text/plain"}) + conns.append(h) + responses = [] + for h in conns: + response = h.getresponse() + self.assertEqual(response.status, 200) + responses.append(response) + for response in responses: + response.read() + for h in conns: + h.close() + + def test_chunking_request_without_content(self): + header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + self.connect() + self.sock.send(header) + self.sock.send(b"0\r\n\r\n") + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(echo.body, b"") + self.assertEqual(echo.content_length, "0") + self.assertFalse("transfer-encoding" in headers) + + def test_chunking_request_with_content(self): + control_line = b"20;\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + expected = s * 12 + header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + self.connect() + self.sock.send(header) + fp = self.sock.makefile("rb", 0) + for n in range(12): + self.sock.send(control_line) + self.sock.send(s) + self.sock.send(b"\r\n") # End the chunk + self.sock.send(b"0\r\n\r\n") + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(echo.body, expected) + self.assertEqual(echo.content_length, str(len(expected))) + self.assertFalse("transfer-encoding" in headers) + + def test_broken_chunked_encoding(self): + control_line = "20;\r\n" # 20 hex = 32 dec + s = "This string has 32 characters.\r\n" + to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + "\r\n" + # garbage in input + to_send += "garbage\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # receiver caught garbage and turned it into a 400 + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + self.assertEqual(headers["content-type"], "text/plain") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_broken_chunked_encoding_missing_chunk_end(self): + control_line = "20;\r\n" # 20 hex = 32 dec + s = "This string has 32 characters.\r\n" + to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + # garbage in input + to_send += "garbage" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # receiver caught garbage and turned it into a 400 + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(b"Chunk not properly terminated" in response_body) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + self.assertEqual(headers["content-type"], "text/plain") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_keepalive_http_10(self): + # Handling of Keep-Alive within HTTP 1.0 + data = "Default: Don't keep me alive" + s = tobytes( + "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + connection = response.getheader("Connection", "") + # We sent no Connection: Keep-Alive header + # Connection: close (or no header) is default. + self.assertTrue(connection != "Keep-Alive") + + def test_keepalive_http10_explicit(self): + # If header Connection: Keep-Alive is explicitly sent, + # we want to keept the connection open, we also need to return + # the corresponding header + data = "Keep me alive" + s = tobytes( + "GET / HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + connection = response.getheader("Connection", "") + self.assertEqual(connection, "Keep-Alive") + + def test_keepalive_http_11(self): + # Handling of Keep-Alive within HTTP 1.1 + + # All connections are kept alive, unless stated otherwise + data = "Default: Keep me alive" + s = tobytes( + "GET / HTTP/1.1\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertTrue(response.getheader("connection") != "close") + + def test_keepalive_http11_explicit(self): + # Explicitly set keep-alive + data = "Default: Keep me alive" + s = tobytes( + "GET / HTTP/1.1\r\n" + "Connection: keep-alive\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertTrue(response.getheader("connection") != "close") + + def test_keepalive_http11_connclose(self): + # specifying Connection: close explicitly + data = "Don't keep me alive" + s = tobytes( + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertEqual(response.getheader("connection"), "close") + + def test_proxy_headers(self): + to_send = ( + "GET / HTTP/1.0\r\n" + "Content-Length: 0\r\n" + "Host: www.google.com:8080\r\n" + "X-Forwarded-For: 192.168.1.1\r\n" + "X-Forwarded-Proto: https\r\n" + "X-Forwarded-Port: 5000\r\n\r\n" + ) + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + self.assertIsNone(echo.headers.get("X_FORWARDED_PORT")) + self.assertEqual(echo.headers["HOST"], "www.google.com:8080") + self.assertEqual(echo.scheme, "https") + self.assertEqual(echo.remote_addr, "192.168.1.1") + self.assertEqual(echo.remote_host, "192.168.1.1") + + +class PipeliningTests(object): + def setUp(self): + from waitress.tests.fixtureapps import echo + + self.start_subprocess(echo.app_body_only) + + def tearDown(self): + self.stop_subprocess() + + def test_pipelining(self): + s = ( + "GET / HTTP/1.0\r\n" + "Connection: %s\r\n" + "Content-Length: %d\r\n" + "\r\n" + "%s" + ) + to_send = b"" + count = 25 + for n in range(count): + body = "Response #%d\r\n" % (n + 1) + if n + 1 < count: + conn = "keep-alive" + else: + conn = "close" + to_send += tobytes(s % (conn, len(body), body)) + + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + for n in range(count): + expect_body = tobytes("Response #%d\r\n" % (n + 1)) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + length = int(headers.get("content-length")) or None + response_body = fp.read(length) + self.assertEqual(int(status), 200) + self.assertEqual(length, len(response_body)) + self.assertEqual(response_body, expect_body) + + +class ExpectContinueTests(object): + def setUp(self): + from waitress.tests.fixtureapps import echo + + self.start_subprocess(echo.app_body_only) + + def tearDown(self): + self.stop_subprocess() + + def test_expect_continue(self): + # specifying Connection: close explicitly + data = "I have expectations" + to_send = tobytes( + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "Content-Length: %d\r\n" + "Expect: 100-continue\r\n" + "\r\n" + "%s" % (len(data), data) + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # continue status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + self.assertEqual(int(status), 100) + self.assertEqual(reason, b"Continue") + self.assertEqual(version, b"HTTP/1.1") + fp.readline() # blank line + line = fp.readline() # next status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + length = int(headers.get("content-length")) or None + response_body = fp.read(length) + self.assertEqual(int(status), 200) + self.assertEqual(length, len(response_body)) + self.assertEqual(response_body, tobytes(data)) + + +class BadContentLengthTests(object): + def setUp(self): + from waitress.tests.fixtureapps import badcl + + self.start_subprocess(badcl.app) + + def tearDown(self): + self.stop_subprocess() + + def test_short_body(self): + # check to see if server closes connection when body is too short + # for cl header + to_send = tobytes( + "GET /short_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + self.assertNotEqual(content_length, len(response_body)) + self.assertEqual(len(response_body), content_length - 1) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote closed connection (despite keepalive header); not sure why + # first send succeeds + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_long_body(self): + # check server doesnt close connection when body is too short + # for cl header + to_send = tobytes( + "GET /long_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) or None + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, tobytes("abcdefgh")) + # remote does not close connection (keepalive header) + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) or None + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + + +class NoContentLengthTests(object): + def setUp(self): + from waitress.tests.fixtureapps import nocl + + self.start_subprocess(nocl.app) + + def tearDown(self): + self.stop_subprocess() + + def test_http10_generator(self): + body = string.ascii_letters + to_send = ( + "GET / HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("content-length"), None) + self.assertEqual(headers.get("connection"), "close") + self.assertEqual(response_body, tobytes(body)) + # remote closed connection (despite keepalive header), because + # generators cannot have a content-length divined + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http10_list(self): + body = string.ascii_letters + to_send = ( + "GET /list HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers["content-length"], str(len(body))) + self.assertEqual(headers.get("connection"), "Keep-Alive") + self.assertEqual(response_body, tobytes(body)) + # remote keeps connection open because it divined the content length + # from a length-1 list + self.sock.send(to_send) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_http10_listlentwo(self): + body = string.ascii_letters + to_send = ( + "GET /list_lentwo HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("content-length"), None) + self.assertEqual(headers.get("connection"), "close") + self.assertEqual(response_body, tobytes(body)) + # remote closed connection (despite keepalive header), because + # lists of length > 1 cannot have their content length divined + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http11_generator(self): + body = string.ascii_letters + to_send = "GET / HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + expected = b"" + for chunk in chunks(body, 10): + expected += tobytes( + "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk) + ) + expected += b"0\r\n\r\n" + self.assertEqual(response_body, expected) + # connection is always closed at the end of a chunked response + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http11_list(self): + body = string.ascii_letters + to_send = "GET /list HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(headers["content-length"], str(len(body))) + self.assertEqual(response_body, tobytes(body)) + # remote keeps connection open because it divined the content length + # from a length-1 list + self.sock.send(to_send) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + + def test_http11_listlentwo(self): + body = string.ascii_letters + to_send = "GET /list_lentwo HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body) + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + expected = b"" + for chunk in (body[0], body[1:]): + expected += tobytes( + "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk) + ) + expected += b"0\r\n\r\n" + self.assertEqual(response_body, expected) + # connection is always closed at the end of a chunked response + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class WriteCallbackTests(object): + def setUp(self): + from waitress.tests.fixtureapps import writecb + + self.start_subprocess(writecb.app) + + def tearDown(self): + self.stop_subprocess() + + def test_short_body(self): + # check to see if server closes connection when body is too short + # for cl header + to_send = tobytes( + "GET /short_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (5) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, 9) + self.assertNotEqual(cl, len(response_body)) + self.assertEqual(len(response_body), cl - 1) + self.assertEqual(response_body, tobytes("abcdefgh")) + # remote closed connection (despite keepalive header) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_long_body(self): + # check server doesnt close connection when body is too long + # for cl header + to_send = tobytes( + "GET /long_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + content_length = int(headers.get("content-length")) or None + self.assertEqual(content_length, 9) + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote does not close connection (keepalive header) + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_equal_body(self): + # check server doesnt close connection when body is equal to + # cl header + to_send = tobytes( + "GET /equal_body HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + content_length = int(headers.get("content-length")) or None + self.assertEqual(content_length, 9) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote does not close connection (keepalive header) + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_no_content_length(self): + # wtf happens when there's no content-length + to_send = tobytes( + "GET /no_content_length HTTP/1.0\r\n" + "Connection: Keep-Alive\r\n" + "Content-Length: 0\r\n" + "\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line = fp.readline() # status line + line, headers, response_body = read_http(fp) + content_length = headers.get("content-length") + self.assertEqual(content_length, None) + self.assertEqual(response_body, tobytes("abcdefghi")) + # remote closed connection (despite keepalive header) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class TooLargeTests(object): + + toobig = 1050 + + def setUp(self): + from waitress.tests.fixtureapps import toolarge + + self.start_subprocess( + toolarge.app, max_request_header_size=1000, max_request_body_size=1000 + ) + + def tearDown(self): + self.stop_subprocess() + + def test_request_body_too_large_with_wrong_cl_http10(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # server trusts the content-length header; no pipelining, + # so request fulfilled, extra bytes are thrown away + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http10_keepalive(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http10(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # extra bytes are thrown away (no pipelining), connection closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http10_keepalive(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.0\r\nConnection: Keep-Alive\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (assumed zero) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + line, headers, response_body = read_http(fp) + # next response overruns because the extra data appears to be + # header data + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http11(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # second response is an error response + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http11_connclose(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\nConnection: close\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (5) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http11(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb") + # server trusts the content-length header (assumed 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # server assumes pipelined requests due to http/1.1, and the first + # request was assumed c-l 0 because it had no content-length header, + # so entire body looks like the header of the subsequent request + # second response is an error response + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http11_connclose(self): + body = "a" * self.toobig + to_send = "GET / HTTP/1.1\r\nConnection: close\r\n\r\n" + to_send += body + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # server trusts the content-length header (assumed 0) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_chunked_encoding(self): + control_line = "20;\r\n" # 20 hex = 32 dec + s = "This string has 32 characters.\r\n" + to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + repeat = control_line + s + to_send += repeat * ((self.toobig // len(repeat)) + 1) + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + # body bytes counter caught a max_request_body_size overrun + self.assertline(line, "413", "Request Entity Too Large", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertEqual(headers["content-type"], "text/plain") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class InternalServerErrorTests(object): + def setUp(self): + from waitress.tests.fixtureapps import error + + self.start_subprocess(error.app, expose_tracebacks=True) + + def tearDown(self): + self.stop_subprocess() + + def test_before_start_response_http_10(self): + to_send = "GET /before_start_response HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_before_start_response_http_11(self): + to_send = "GET /before_start_response HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_before_start_response_http_11_close(self): + to_send = tobytes( + "GET /before_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http10(self): + to_send = "GET /after_start_response HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http11(self): + to_send = "GET /after_start_response HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] + ) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http11_close(self): + to_send = tobytes( + "GET /after_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" + ) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_write_cb(self): + to_send = "GET /after_write_cb HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(response_body, b"") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_in_generator(self): + to_send = "GET /in_generator HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + self.connect() + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(response_body, b"") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class FileWrapperTests(object): + def setUp(self): + from waitress.tests.fixtureapps import filewrapper + + self.start_subprocess(filewrapper.app) + + def tearDown(self): + self.stop_subprocess() + + def test_filelike_http11(self): + to_send = "GET /filelike HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_filelike_nocl_http11(self): + to_send = "GET /filelike_nocl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_filelike_shortcl_http11(self): + to_send = "GET /filelike_shortcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, 1) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377" in response_body) + # connection has not been closed + + def test_filelike_longcl_http11(self): + to_send = "GET /filelike_longcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_http11(self): + to_send = "GET /notfilelike HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_iobase_http11(self): + to_send = "GET /notfilelike_iobase HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_nocl_http11(self): + to_send = "GET /notfilelike_nocl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed (no content-length) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_shortcl_http11(self): + to_send = "GET /notfilelike_shortcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, 1) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377" in response_body) + # connection has not been closed + + def test_notfilelike_longcl_http11(self): + to_send = "GET /notfilelike_longcl HTTP/1.1\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body) + 10) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_filelike_http10(self): + to_send = "GET /filelike HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_filelike_nocl_http10(self): + to_send = "GET /filelike_nocl HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_http10(self): + to_send = "GET /notfilelike HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_nocl_http10(self): + to_send = "GET /notfilelike_nocl HTTP/1.0\r\n\r\n" + to_send = tobytes(to_send) + + self.connect() + + self.sock.send(to_send) + fp = self.sock.makefile("rb", 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed (no content-length) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class TcpEchoTests(EchoTests, TcpTests, unittest.TestCase): + pass + + +class TcpPipeliningTests(PipeliningTests, TcpTests, unittest.TestCase): + pass + + +class TcpExpectContinueTests(ExpectContinueTests, TcpTests, unittest.TestCase): + pass + + +class TcpBadContentLengthTests(BadContentLengthTests, TcpTests, unittest.TestCase): + pass + + +class TcpNoContentLengthTests(NoContentLengthTests, TcpTests, unittest.TestCase): + pass + + +class TcpWriteCallbackTests(WriteCallbackTests, TcpTests, unittest.TestCase): + pass + + +class TcpTooLargeTests(TooLargeTests, TcpTests, unittest.TestCase): + pass + + +class TcpInternalServerErrorTests( + InternalServerErrorTests, TcpTests, unittest.TestCase +): + pass + + +class TcpFileWrapperTests(FileWrapperTests, TcpTests, unittest.TestCase): + pass + + +if hasattr(socket, "AF_UNIX"): + + class FixtureUnixWSGIServer(server.UnixWSGIServer): + """A version of UnixWSGIServer that relays back what it's bound to. + """ + + family = socket.AF_UNIX # Testing + + def __init__(self, application, queue, **kw): # pragma: no cover + # Coverage doesn't see this as it's ran in a separate process. + # To permit parallel testing, use a PID-dependent socket. + kw["unix_socket"] = "/tmp/waitress.test-%d.sock" % os.getpid() + super(FixtureUnixWSGIServer, self).__init__(application, **kw) + queue.put(self.socket.getsockname()) + + class UnixTests(SubprocessTests): + + server = FixtureUnixWSGIServer + + def make_http_connection(self): + return UnixHTTPConnection(self.bound_to) + + def stop_subprocess(self): + super(UnixTests, self).stop_subprocess() + cleanup_unix_socket(self.bound_to) + + def send_check_error(self, to_send): + # Unlike inet domain sockets, Unix domain sockets can trigger a + # 'Broken pipe' error when the socket it closed. + try: + self.sock.send(to_send) + except socket.error as exc: + self.assertEqual(get_errno(exc), errno.EPIPE) + + class UnixEchoTests(EchoTests, UnixTests, unittest.TestCase): + pass + + class UnixPipeliningTests(PipeliningTests, UnixTests, unittest.TestCase): + pass + + class UnixExpectContinueTests(ExpectContinueTests, UnixTests, unittest.TestCase): + pass + + class UnixBadContentLengthTests( + BadContentLengthTests, UnixTests, unittest.TestCase + ): + pass + + class UnixNoContentLengthTests(NoContentLengthTests, UnixTests, unittest.TestCase): + pass + + class UnixWriteCallbackTests(WriteCallbackTests, UnixTests, unittest.TestCase): + pass + + class UnixTooLargeTests(TooLargeTests, UnixTests, unittest.TestCase): + pass + + class UnixInternalServerErrorTests( + InternalServerErrorTests, UnixTests, unittest.TestCase + ): + pass + + class UnixFileWrapperTests(FileWrapperTests, UnixTests, unittest.TestCase): + pass + + +def parse_headers(fp): + """Parses only RFC2822 headers from a file pointer. + """ + headers = {} + while True: + line = fp.readline() + if line in (b"\r\n", b"\n", b""): + break + line = line.decode("iso-8859-1") + name, value = line.strip().split(":", 1) + headers[name.lower().strip()] = value.lower().strip() + return headers + + +class UnixHTTPConnection(httplib.HTTPConnection): + """Patched version of HTTPConnection that uses Unix domain sockets. + """ + + def __init__(self, path): + httplib.HTTPConnection.__init__(self, "localhost") + self.path = path + + def connect(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(self.path) + self.sock = sock + + +class ConnectionClosed(Exception): + pass + + +# stolen from gevent +def read_http(fp): # pragma: no cover + try: + response_line = fp.readline() + except socket.error as exc: + fp.close() + # errno 104 is ENOTRECOVERABLE, In WinSock 10054 is ECONNRESET + if get_errno(exc) in (errno.ECONNABORTED, errno.ECONNRESET, 104, 10054): + raise ConnectionClosed + raise + if not response_line: + raise ConnectionClosed + + header_lines = [] + while True: + line = fp.readline() + if line in (b"\r\n", b"\r\n", b""): + break + else: + header_lines.append(line) + headers = dict() + for x in header_lines: + x = x.strip() + if not x: + continue + key, value = x.split(b": ", 1) + key = key.decode("iso-8859-1").lower() + value = value.decode("iso-8859-1") + assert key not in headers, "%s header duplicated" % key + headers[key] = value + + if "content-length" in headers: + num = int(headers["content-length"]) + body = b"" + left = num + while left > 0: + data = fp.read(left) + if not data: + break + body += data + left -= len(data) + else: + # read until EOF + body = fp.read() + + return response_line, headers, body + + +# stolen from gevent +def get_errno(exc): # pragma: no cover + """ Get the error code out of socket.error objects. + socket.error in <2.5 does not have errno attribute + socket.error in 3.x does not allow indexing access + e.args[0] works for all. + There are cases when args[0] is not errno. + i.e. http://bugs.python.org/issue6471 + Maybe there are cases when errno is set, but it is not the first argument? + """ + try: + if exc.errno is not None: + return exc.errno + except AttributeError: + pass + try: + return exc.args[0] + except IndexError: + return None + + +def chunks(l, n): + """ Yield successive n-sized chunks from l. + """ + for i in range(0, len(l), n): + yield l[i : i + n] diff --git a/libs/waitress/tests/test_init.py b/libs/waitress/tests/test_init.py new file mode 100644 index 000000000..f9b91d762 --- /dev/null +++ b/libs/waitress/tests/test_init.py @@ -0,0 +1,51 @@ +import unittest + + +class Test_serve(unittest.TestCase): + def _callFUT(self, app, **kw): + from waitress import serve + + return serve(app, **kw) + + def test_it(self): + server = DummyServerFactory() + app = object() + result = self._callFUT(app, _server=server, _quiet=True) + self.assertEqual(server.app, app) + self.assertEqual(result, None) + self.assertEqual(server.ran, True) + + +class Test_serve_paste(unittest.TestCase): + def _callFUT(self, app, **kw): + from waitress import serve_paste + + return serve_paste(app, None, **kw) + + def test_it(self): + server = DummyServerFactory() + app = object() + result = self._callFUT(app, _server=server, _quiet=True) + self.assertEqual(server.app, app) + self.assertEqual(result, 0) + self.assertEqual(server.ran, True) + + +class DummyServerFactory(object): + ran = False + + def __call__(self, app, **kw): + self.adj = DummyAdj(kw) + self.app = app + self.kw = kw + return self + + def run(self): + self.ran = True + + +class DummyAdj(object): + verbose = False + + def __init__(self, kw): + self.__dict__.update(kw) diff --git a/libs/waitress/tests/test_parser.py b/libs/waitress/tests/test_parser.py new file mode 100644 index 000000000..91837c7fc --- /dev/null +++ b/libs/waitress/tests/test_parser.py @@ -0,0 +1,732 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""HTTP Request Parser tests +""" +import unittest + +from waitress.compat import text_, tobytes + + +class TestHTTPRequestParser(unittest.TestCase): + def setUp(self): + from waitress.parser import HTTPRequestParser + from waitress.adjustments import Adjustments + + my_adj = Adjustments() + self.parser = HTTPRequestParser(my_adj) + + def test_get_body_stream_None(self): + self.parser.body_recv = None + result = self.parser.get_body_stream() + self.assertEqual(result.getvalue(), b"") + + def test_get_body_stream_nonNone(self): + body_rcv = DummyBodyStream() + self.parser.body_rcv = body_rcv + result = self.parser.get_body_stream() + self.assertEqual(result, body_rcv) + + def test_received_get_no_headers(self): + data = b"HTTP/1.0 GET /foobar\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 24) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_bad_host_header(self): + from waitress.utilities import BadRequest + + data = b"HTTP/1.0 GET /foobar\r\n Host: foo\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 36) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.error.__class__, BadRequest) + + def test_received_bad_transfer_encoding(self): + from waitress.utilities import ServerNotImplemented + + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: foo\r\n" + b"\r\n" + b"1d;\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + result = self.parser.received(data) + self.assertEqual(result, 48) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.error.__class__, ServerNotImplemented) + + def test_received_nonsense_nothing(self): + data = b"\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 4) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_no_doublecr(self): + data = b"GET /foobar HTTP/8.4\r\n" + result = self.parser.received(data) + self.assertEqual(result, 22) + self.assertFalse(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_already_completed(self): + self.parser.completed = True + result = self.parser.received(b"a") + self.assertEqual(result, 0) + + def test_received_cl_too_large(self): + from waitress.utilities import RequestEntityTooLarge + + self.parser.adj.max_request_body_size = 2 + data = b"GET /foobar HTTP/8.4\r\nContent-Length: 10\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 44) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) + + def test_received_headers_too_large(self): + from waitress.utilities import RequestHeaderFieldsTooLarge + + self.parser.adj.max_request_header_size = 2 + data = b"GET /foobar HTTP/8.4\r\nX-Foo: 1\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 34) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestHeaderFieldsTooLarge)) + + def test_received_body_too_large(self): + from waitress.utilities import RequestEntityTooLarge + + self.parser.adj.max_request_body_size = 2 + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"1d;\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + + result = self.parser.received(data) + self.assertEqual(result, 62) + self.parser.received(data[result:]) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) + + def test_received_error_from_parser(self): + from waitress.utilities import BadRequest + + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"garbage\r\n" + ) + # header + result = self.parser.received(data) + # body + result = self.parser.received(data[result:]) + self.assertEqual(result, 9) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, BadRequest)) + + def test_received_chunked_completed_sets_content_length(self): + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"1d;\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + result = self.parser.received(data) + self.assertEqual(result, 62) + data = data[result:] + result = self.parser.received(data) + self.assertTrue(self.parser.completed) + self.assertTrue(self.parser.error is None) + self.assertEqual(self.parser.headers["CONTENT_LENGTH"], "29") + + def test_parse_header_gardenpath(self): + data = b"GET /foobar HTTP/8.4\r\nfoo: bar\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.first_line, b"GET /foobar HTTP/8.4") + self.assertEqual(self.parser.headers["FOO"], "bar") + + def test_parse_header_no_cr_in_headerplus(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4" + + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_bad_content_length(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\ncontent-length: abc\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_multiple_content_length(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\ncontent-length: 10\r\ncontent-length: 20\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_11_te_chunked(self): + # NB: test that capitalization of header value is unimportant + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: ChUnKed\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.body_rcv.__class__.__name__, "ChunkedReceiver") + + def test_parse_header_transfer_encoding_invalid(self): + from waitress.parser import TransferEncodingNotImplemented + + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_multiple(self): + from waitress.parser import TransferEncodingNotImplemented + + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\ntransfer-encoding: chunked\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_whitespace(self): + from waitress.parser import TransferEncodingNotImplemented + + data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding:\x85chunked\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_unicode(self): + from waitress.parser import TransferEncodingNotImplemented + + # This is the binary encoding for the UTF-8 character + # https://www.compart.com/en/unicode/U+212A "unicode character "K"" + # which if waitress were to accidentally do the wrong thing get + # lowercased to just the ascii "k" due to unicode collisions during + # transformation + data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding: chun\xe2\x84\xaaed\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_11_expect_continue(self): + data = b"GET /foobar HTTP/1.1\r\nexpect: 100-continue\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.expect_continue, True) + + def test_parse_header_connection_close(self): + data = b"GET /foobar HTTP/1.1\r\nConnection: close\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.connection_close, True) + + def test_close_with_body_rcv(self): + body_rcv = DummyBodyStream() + self.parser.body_rcv = body_rcv + self.parser.close() + self.assertTrue(body_rcv.closed) + + def test_close_with_no_body_rcv(self): + self.parser.body_rcv = None + self.parser.close() # doesn't raise + + def test_parse_header_lf_only(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\nfoo: bar" + + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_cr_only(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\rfoo: bar" + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_extra_lf_in_header(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\nfoo: \nbar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Bare CR or LF found in header line", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_extra_lf_in_first_line(self): + from waitress.parser import ParsingError + + data = b"GET /foobar\n HTTP/8.4\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Bare CR or LF found in HTTP message", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_whitespace(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/8.4\r\nfoo : bar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_whitespace_vtab(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo:\x0bbar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_no_colon(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nnotvalid\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_folding_spacing(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\n\t\x0bbaz\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_chars(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: \x0bbaz\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_empty(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nempty:\r\n" + self.parser.parse_header(data) + + self.assertIn("EMPTY", self.parser.headers) + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["EMPTY"], "") + self.assertEqual(self.parser.headers["FOO"], "bar") + + def test_parse_header_multiple_values(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever, more, please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_header_folded(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more, please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_header_folded_multiple(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more\r\nfoo: please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_extra_space(self): + # Tests errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: abrowser/0.001 (C O M M E N T)\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "abrowser/0.001 (C O M M E N T)") + + def test_parse_header_invalid_backtrack_bad(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\x10\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_short_values(self): + from waitress.parser import ParsingError + + data = b"GET /foobar HTTP/1.1\r\none: 1\r\ntwo: 22\r\n" + self.parser.parse_header(data) + + self.assertIn("ONE", self.parser.headers) + self.assertIn("TWO", self.parser.headers) + self.assertEqual(self.parser.headers["ONE"], "1") + self.assertEqual(self.parser.headers["TWO"], "22") + + +class Test_split_uri(unittest.TestCase): + def _callFUT(self, uri): + from waitress.parser import split_uri + + ( + self.proxy_scheme, + self.proxy_netloc, + self.path, + self.query, + self.fragment, + ) = split_uri(uri) + + def test_split_uri_unquoting_unneeded(self): + self._callFUT(b"http://localhost:8080/abc def") + self.assertEqual(self.path, "/abc def") + + def test_split_uri_unquoting_needed(self): + self._callFUT(b"http://localhost:8080/abc%20def") + self.assertEqual(self.path, "/abc def") + + def test_split_url_with_query(self): + self._callFUT(b"http://localhost:8080/abc?a=1&b=2") + self.assertEqual(self.path, "/abc") + self.assertEqual(self.query, "a=1&b=2") + + def test_split_url_with_query_empty(self): + self._callFUT(b"http://localhost:8080/abc?") + self.assertEqual(self.path, "/abc") + self.assertEqual(self.query, "") + + def test_split_url_with_fragment(self): + self._callFUT(b"http://localhost:8080/#foo") + self.assertEqual(self.path, "/") + self.assertEqual(self.fragment, "foo") + + def test_split_url_https(self): + self._callFUT(b"https://localhost:8080/") + self.assertEqual(self.path, "/") + self.assertEqual(self.proxy_scheme, "https") + self.assertEqual(self.proxy_netloc, "localhost:8080") + + def test_split_uri_unicode_error_raises_parsing_error(self): + # See https://github.com/Pylons/waitress/issues/64 + from waitress.parser import ParsingError + + # Either pass or throw a ParsingError, just don't throw another type of + # exception as that will cause the connection to close badly: + try: + self._callFUT(b"/\xd0") + except ParsingError: + pass + + def test_split_uri_path(self): + self._callFUT(b"//testing/whatever") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "") + self.assertEqual(self.fragment, "") + + def test_split_uri_path_query(self): + self._callFUT(b"//testing/whatever?a=1&b=2") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "a=1&b=2") + self.assertEqual(self.fragment, "") + + def test_split_uri_path_query_fragment(self): + self._callFUT(b"//testing/whatever?a=1&b=2#fragment") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "a=1&b=2") + self.assertEqual(self.fragment, "fragment") + + +class Test_get_header_lines(unittest.TestCase): + def _callFUT(self, data): + from waitress.parser import get_header_lines + + return get_header_lines(data) + + def test_get_header_lines(self): + result = self._callFUT(b"slam\r\nslim") + self.assertEqual(result, [b"slam", b"slim"]) + + def test_get_header_lines_folded(self): + # From RFC2616: + # HTTP/1.1 header field values can be folded onto multiple lines if the + # continuation line begins with a space or horizontal tab. All linear + # white space, including folding, has the same semantics as SP. A + # recipient MAY replace any linear white space with a single SP before + # interpreting the field value or forwarding the message downstream. + + # We are just preserving the whitespace that indicates folding. + result = self._callFUT(b"slim\r\n slam") + self.assertEqual(result, [b"slim slam"]) + + def test_get_header_lines_tabbed(self): + result = self._callFUT(b"slam\r\n\tslim") + self.assertEqual(result, [b"slam\tslim"]) + + def test_get_header_lines_malformed(self): + # https://corte.si/posts/code/pathod/pythonservers/index.html + from waitress.parser import ParsingError + + self.assertRaises(ParsingError, self._callFUT, b" Host: localhost\r\n\r\n") + + +class Test_crack_first_line(unittest.TestCase): + def _callFUT(self, line): + from waitress.parser import crack_first_line + + return crack_first_line(line) + + def test_crack_first_line_matchok(self): + result = self._callFUT(b"GET / HTTP/1.0") + self.assertEqual(result, (b"GET", b"/", b"1.0")) + + def test_crack_first_line_lowercase_method(self): + from waitress.parser import ParsingError + + self.assertRaises(ParsingError, self._callFUT, b"get / HTTP/1.0") + + def test_crack_first_line_nomatch(self): + result = self._callFUT(b"GET / bleh") + self.assertEqual(result, (b"", b"", b"")) + + result = self._callFUT(b"GET /info?txtAirPlay&txtRAOP RTSP/1.0") + self.assertEqual(result, (b"", b"", b"")) + + def test_crack_first_line_missing_version(self): + result = self._callFUT(b"GET /") + self.assertEqual(result, (b"GET", b"/", b"")) + + +class TestHTTPRequestParserIntegration(unittest.TestCase): + def setUp(self): + from waitress.parser import HTTPRequestParser + from waitress.adjustments import Adjustments + + my_adj = Adjustments() + self.parser = HTTPRequestParser(my_adj) + + def feed(self, data): + parser = self.parser + + for n in range(100): # make sure we never loop forever + consumed = parser.received(data) + data = data[consumed:] + + if parser.completed: + return + raise ValueError("Looping") # pragma: no cover + + def testSimpleGET(self): + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"FirstName: mickey\r\n" + b"lastname: Mouse\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + parser = self.parser + self.feed(data) + self.assertTrue(parser.completed) + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual( + parser.headers, + {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "6",}, + ) + self.assertEqual(parser.path, "/foobar") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.query, "") + self.assertEqual(parser.proxy_scheme, "") + self.assertEqual(parser.proxy_netloc, "") + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") + + def testComplexGET(self): + data = ( + b"GET /foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6 HTTP/8.4\r\n" + b"FirstName: mickey\r\n" + b"lastname: Mouse\r\n" + b"content-length: 10\r\n" + b"\r\n" + b"Hello mickey." + ) + parser = self.parser + self.feed(data) + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual( + parser.headers, + {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "10"}, + ) + # path should be utf-8 encoded + self.assertEqual( + tobytes(parser.path).decode("utf-8"), + text_(b"/foo/a++/\xc3\xa4=&a:int", "utf-8"), + ) + self.assertEqual( + parser.query, "d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6" + ) + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello mick") + + def testProxyGET(self): + data = ( + b"GET https://example.com:8080/foobar HTTP/8.4\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + parser = self.parser + self.feed(data) + self.assertTrue(parser.completed) + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual(parser.headers, {"CONTENT_LENGTH": "6"}) + self.assertEqual(parser.path, "/foobar") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.proxy_scheme, "https") + self.assertEqual(parser.proxy_netloc, "example.com:8080") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.query, "") + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") + + def testDuplicateHeaders(self): + # Ensure that headers with the same key get concatenated as per + # RFC2616. + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"x-forwarded-for: 10.11.12.13\r\n" + b"x-forwarded-for: unknown,127.0.0.1\r\n" + b"X-Forwarded_for: 255.255.255.255\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + self.feed(data) + self.assertTrue(self.parser.completed) + self.assertEqual( + self.parser.headers, + { + "CONTENT_LENGTH": "6", + "X_FORWARDED_FOR": "10.11.12.13, unknown,127.0.0.1", + }, + ) + + def testSpoofedHeadersDropped(self): + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"x-auth_user: bob\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + self.feed(data) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, {"CONTENT_LENGTH": "6",}) + + +class DummyBodyStream(object): + def getfile(self): + return self + + def getbuf(self): + return self + + def close(self): + self.closed = True diff --git a/libs/waitress/tests/test_proxy_headers.py b/libs/waitress/tests/test_proxy_headers.py new file mode 100644 index 000000000..15b4a0828 --- /dev/null +++ b/libs/waitress/tests/test_proxy_headers.py @@ -0,0 +1,724 @@ +import unittest + +from waitress.compat import tobytes + + +class TestProxyHeadersMiddleware(unittest.TestCase): + def _makeOne(self, app, **kw): + from waitress.proxy_headers import proxy_headers_middleware + + return proxy_headers_middleware(app, **kw) + + def _callFUT(self, app, **kw): + response = DummyResponse() + environ = DummyEnviron(**kw) + + def start_response(status, response_headers): + response.status = status + response.headers = response_headers + + response.steps = list(app(environ, start_response)) + response.body = b"".join(tobytes(s) for s in response.steps) + return response + + def test_get_environment_values_w_scheme_override_untrusted(self): + inner = DummyApp() + app = self._makeOne(inner) + response = self._callFUT( + app, headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",} + ) + self.assertEqual(response.status, "200 OK") + self.assertEqual(inner.environ["wsgi.url_scheme"], "http") + + def test_get_environment_values_w_scheme_override_trusted(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"x-forwarded-proto"}, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 8080], + headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",}, + ) + + environ = inner.environ + self.assertEqual(response.status, "200 OK") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["REMOTE_ADDR"], "192.168.1.1") + self.assertEqual(environ["HTTP_X_FOO"], "BAR") + + def test_get_environment_values_w_bogus_scheme_override(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"x-forwarded-proto"}, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FOO": "BAR", + "X_FORWARDED_PROTO": "http://p02n3e.com?url=http", + }, + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) + + def test_get_environment_warning_other_proxy_headers(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + log_untrusted=True, + logger=logger, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FORWARDED_FOR": "[2001:db8::1]", + "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + self.assertEqual(len(logger.logged), 1) + + environ = inner.environ + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_get_environment_contains_all_headers_including_untrusted(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=False, + ) + headers_orig = { + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + } + response = self._callFUT( + app, addr=["192.168.1.1", 80], headers=headers_orig.copy(), + ) + self.assertEqual(response.status, "200 OK") + environ = inner.environ + for k, expected in headers_orig.items(): + result = environ["HTTP_%s" % k] + self.assertEqual(result, expected) + + def test_get_environment_contains_only_trusted_headers(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=True, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["HTTP_X_FORWARDED_BY"], "Waitress") + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) + self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) + + def test_get_environment_clears_headers_if_untrusted_proxy(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=True, + ) + response = self._callFUT( + app, + addr=["192.168.1.255", 80], + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertNotIn("HTTP_X_FORWARDED_BY", environ) + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) + self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) + + def test_parse_proxy_headers_forwarded_for(self): + inner = DummyApp() + app = self._makeOne( + inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": "192.0.2.1"}) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "192.0.2.1") + + def test_parse_proxy_headers_forwarded_for_v6_missing_brackets(self): + inner = DummyApp() + app = self._makeOne( + inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": "2001:db8::0"}) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::0") + + def test_parse_proxy_headers_forwared_for_multiple(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT( + app, headers={"X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1"} + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + + def test_parse_forwarded_multiple_proxies_trust_only_two(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + "For=192.0.2.1;host=fake.com, " + "For=198.51.100.2;host=example.com:8080, " + "For=203.0.113.1" + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_forwarded_multiple_proxies(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + 'for="[2001:db8::1]:3821";host="example.com:8443";proto="https", ' + 'for=192.0.2.1;host="example.internal:8080"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") + self.assertEqual(environ["REMOTE_PORT"], "3821") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8443") + self.assertEqual(environ["SERVER_PORT"], "8443") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_multiple_proxies_minimal(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + 'for="[2001:db8::1]";proto="https", ' + 'for=192.0.2.1;host="example.org"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") + self.assertEqual(environ["SERVER_NAME"], "example.org") + self.assertEqual(environ["HTTP_HOST"], "example.org") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_proxy_headers_forwarded_host_with_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com:8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_without_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com") + self.assertEqual(environ["SERVER_PORT"], "80") + + def test_parse_proxy_headers_forwarded_host_with_forwarded_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com, example.org", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port_limit_one_trusted( + self, + ): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com, example.org", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "203.0.113.1") + self.assertEqual(environ["SERVER_NAME"], "example.org") + self.assertEqual(environ["HTTP_HOST"], "example.org:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_forwarded(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": "For=198.51.100.2:5858;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["REMOTE_PORT"], "5858") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_empty_pair(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, headers={"FORWARDED": "For=198.51.100.2;;proto=https;by=_unused",} + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + + def test_parse_forwarded_pair_token_whitespace(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, headers={"FORWARDED": "For=198.51.100.2; proto =https",} + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_pair_value_whitespace(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, headers={"FORWARDED": 'For= "198.51.100.2"; proto =https',} + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_pair_no_equals(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT(app, headers={"FORWARDED": "For"}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_warning_unknown_token(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + logger=logger, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + "For=198.51.100.2;host=example.com:8080;proto=https;" + 'unknown="yolo"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + self.assertEqual(len(logger.logged), 1) + self.assertIn("Unknown Forwarded token", logger.logged[0]) + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_no_valid_proxy_headers(self): + inner = DummyApp() + app = self._makeOne(inner, trusted_proxy="*", trusted_proxy_count=1,) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["HTTP_HOST"], "192.168.1.1:80") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "http") + + def test_parse_multiple_x_forwarded_proto(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-proto"}, + logger=logger, + ) + response = self._callFUT(app, headers={"X_FORWARDED_PROTO": "http, https",}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) + + def test_parse_multiple_x_forwarded_port(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-port"}, + logger=logger, + ) + response = self._callFUT(app, headers={"X_FORWARDED_PORT": "443, 80",}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Port" malformed', response.body) + + def test_parse_forwarded_port_wrong_proto_port_80(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-port", + "x-forwarded-host", + "x-forwarded-proto", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PORT": "80", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:80") + self.assertEqual(environ["SERVER_PORT"], "80") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_port_wrong_proto_port_443(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-port", + "x-forwarded-host", + "x-forwarded-proto", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PORT": "443", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:443") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["wsgi.url_scheme"], "http") + + def test_parse_forwarded_for_bad_quote(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": '"foo'}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-For" malformed', response.body) + + def test_parse_forwarded_host_bad_quote(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-host"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_HOST": '"foo'}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Host" malformed', response.body) + + +class DummyLogger(object): + def __init__(self): + self.logged = [] + + def warning(self, msg, *args): + self.logged.append(msg % args) + + +class DummyApp(object): + def __call__(self, environ, start_response): + self.environ = environ + start_response("200 OK", [("Content-Type", "text/plain")]) + yield "hello" + + +class DummyResponse(object): + status = None + headers = None + body = None + + +def DummyEnviron( + addr=("127.0.0.1", 8080), scheme="http", server="localhost", headers=None, +): + environ = { + "REMOTE_ADDR": addr[0], + "REMOTE_HOST": addr[0], + "REMOTE_PORT": addr[1], + "SERVER_PORT": str(addr[1]), + "SERVER_NAME": server, + "wsgi.url_scheme": scheme, + "HTTP_HOST": "192.168.1.1:80", + } + if headers: + environ.update( + { + "HTTP_" + key.upper().replace("-", "_"): value + for key, value in headers.items() + } + ) + return environ diff --git a/libs/waitress/tests/test_receiver.py b/libs/waitress/tests/test_receiver.py new file mode 100644 index 000000000..b4910bba8 --- /dev/null +++ b/libs/waitress/tests/test_receiver.py @@ -0,0 +1,242 @@ +import unittest + + +class TestFixedStreamReceiver(unittest.TestCase): + def _makeOne(self, cl, buf): + from waitress.receiver import FixedStreamReceiver + + return FixedStreamReceiver(cl, buf) + + def test_received_remain_lt_1(self): + buf = DummyBuffer() + inst = self._makeOne(0, buf) + result = inst.received("a") + self.assertEqual(result, 0) + self.assertEqual(inst.completed, True) + + def test_received_remain_lte_datalen(self): + buf = DummyBuffer() + inst = self._makeOne(1, buf) + result = inst.received("aa") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, True) + self.assertEqual(inst.completed, 1) + self.assertEqual(inst.remain, 0) + self.assertEqual(buf.data, ["a"]) + + def test_received_remain_gt_datalen(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + result = inst.received("aa") + self.assertEqual(result, 2) + self.assertEqual(inst.completed, False) + self.assertEqual(inst.remain, 8) + self.assertEqual(buf.data, ["aa"]) + + def test_getfile(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + self.assertEqual(inst.getfile(), buf) + + def test_getbuf(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + self.assertEqual(inst.getbuf(), buf) + + def test___len__(self): + buf = DummyBuffer(["1", "2"]) + inst = self._makeOne(10, buf) + self.assertEqual(inst.__len__(), 2) + + +class TestChunkedReceiver(unittest.TestCase): + def _makeOne(self, buf): + from waitress.receiver import ChunkedReceiver + + return ChunkedReceiver(buf) + + def test_alreadycompleted(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.completed = True + result = inst.received(b"a") + self.assertEqual(result, 0) + self.assertEqual(inst.completed, True) + + def test_received_remain_gt_zero(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.chunk_remainder = 100 + result = inst.received(b"a") + self.assertEqual(inst.chunk_remainder, 99) + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_control_line_notfinished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"a") + self.assertEqual(inst.control_line, b"a") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_control_line_finished_garbage_in_input(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"garbage\r\n") + self.assertEqual(result, 9) + self.assertTrue(inst.error) + + def test_received_control_line_finished_all_chunks_not_received(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"a;discard\r\n") + self.assertEqual(inst.control_line, b"") + self.assertEqual(inst.chunk_remainder, 10) + self.assertEqual(inst.all_chunks_received, False) + self.assertEqual(result, 11) + self.assertEqual(inst.completed, False) + + def test_received_control_line_finished_all_chunks_received(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"0;discard\r\n") + self.assertEqual(inst.control_line, b"") + self.assertEqual(inst.all_chunks_received, True) + self.assertEqual(result, 11) + self.assertEqual(inst.completed, False) + + def test_received_trailer_startswith_crlf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"\r\n") + self.assertEqual(result, 2) + self.assertEqual(inst.completed, True) + + def test_received_trailer_startswith_lf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"\n") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_trailer_not_finished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"a") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_trailer_finished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"abc\r\n\r\n") + self.assertEqual(inst.trailer, b"abc\r\n\r\n") + self.assertEqual(result, 7) + self.assertEqual(inst.completed, True) + + def test_getfile(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + self.assertEqual(inst.getfile(), buf) + + def test_getbuf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + self.assertEqual(inst.getbuf(), buf) + + def test___len__(self): + buf = DummyBuffer(["1", "2"]) + inst = self._makeOne(buf) + self.assertEqual(inst.__len__(), 2) + + def test_received_chunk_is_properly_terminated(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4\r\nWiki\r\n" + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, False) + self.assertEqual(buf.data[0], b"Wiki") + + def test_received_chunk_not_properly_terminated(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4\r\nWikibadchunk\r\n" + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, False) + self.assertEqual(buf.data[0], b"Wiki") + self.assertEqual(inst.error.__class__, BadRequest) + + def test_received_multiple_chunks(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = ( + b"4\r\n" + b"Wiki\r\n" + b"5\r\n" + b"pedia\r\n" + b"E\r\n" + b" in\r\n" + b"\r\n" + b"chunks.\r\n" + b"0\r\n" + b"\r\n" + ) + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, True) + self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") + self.assertEqual(inst.error, None) + + def test_received_multiple_chunks_split(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data1 = b"4\r\nWiki\r" + result = inst.received(data1) + self.assertEqual(result, len(data1)) + + data2 = ( + b"\n5\r\n" + b"pedia\r\n" + b"E\r\n" + b" in\r\n" + b"\r\n" + b"chunks.\r\n" + b"0\r\n" + b"\r\n" + ) + + result = inst.received(data2) + self.assertEqual(result, len(data2)) + + self.assertEqual(inst.completed, True) + self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") + self.assertEqual(inst.error, None) + + +class DummyBuffer(object): + def __init__(self, data=None): + if data is None: + data = [] + self.data = data + + def append(self, s): + self.data.append(s) + + def getfile(self): + return self + + def __len__(self): + return len(self.data) diff --git a/libs/waitress/tests/test_regression.py b/libs/waitress/tests/test_regression.py new file mode 100644 index 000000000..3c4c6c202 --- /dev/null +++ b/libs/waitress/tests/test_regression.py @@ -0,0 +1,147 @@ +############################################################################## +# +# Copyright (c) 2005 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Tests for waitress.channel maintenance logic +""" +import doctest + + +class FakeSocket: # pragma: no cover + data = "" + setblocking = lambda *_: None + close = lambda *_: None + + def __init__(self, no): + self.no = no + + def fileno(self): + return self.no + + def getpeername(self): + return ("localhost", self.no) + + def send(self, data): + self.data += data + return len(data) + + def recv(self, data): + return "data" + + +def zombies_test(): + """Regression test for HTTPChannel.maintenance method + + Bug: This method checks for channels that have been "inactive" for a + configured time. The bug was that last_activity is set at creation time + but never updated during async channel activity (reads and writes), so + any channel older than the configured timeout will be closed when a new + channel is created, regardless of activity. + + >>> import time + >>> import waitress.adjustments + >>> config = waitress.adjustments.Adjustments() + + >>> from waitress.server import HTTPServer + >>> class TestServer(HTTPServer): + ... def bind(self, (ip, port)): + ... print "Listening on %s:%d" % (ip or '*', port) + >>> sb = TestServer('127.0.0.1', 80, start=False, verbose=True) + Listening on 127.0.0.1:80 + + First we confirm the correct behavior, where a channel with no activity + for the timeout duration gets closed. + + >>> from waitress.channel import HTTPChannel + >>> socket = FakeSocket(42) + >>> channel = HTTPChannel(sb, socket, ('localhost', 42)) + + >>> channel.connected + True + + >>> channel.last_activity -= int(config.channel_timeout) + 1 + + >>> channel.next_channel_cleanup[0] = channel.creation_time - int( + ... config.cleanup_interval) - 1 + + >>> socket2 = FakeSocket(7) + >>> channel2 = HTTPChannel(sb, socket2, ('localhost', 7)) + + >>> channel.connected + False + + Write Activity + -------------- + + Now we make sure that if there is activity the channel doesn't get closed + incorrectly. + + >>> channel2.connected + True + + >>> channel2.last_activity -= int(config.channel_timeout) + 1 + + >>> channel2.handle_write() + + >>> channel2.next_channel_cleanup[0] = channel2.creation_time - int( + ... config.cleanup_interval) - 1 + + >>> socket3 = FakeSocket(3) + >>> channel3 = HTTPChannel(sb, socket3, ('localhost', 3)) + + >>> channel2.connected + True + + Read Activity + -------------- + + We should test to see that read activity will update a channel as well. + + >>> channel3.connected + True + + >>> channel3.last_activity -= int(config.channel_timeout) + 1 + + >>> import waitress.parser + >>> channel3.parser_class = ( + ... waitress.parser.HTTPRequestParser) + >>> channel3.handle_read() + + >>> channel3.next_channel_cleanup[0] = channel3.creation_time - int( + ... config.cleanup_interval) - 1 + + >>> socket4 = FakeSocket(4) + >>> channel4 = HTTPChannel(sb, socket4, ('localhost', 4)) + + >>> channel3.connected + True + + Main loop window + ---------------- + + There is also a corner case we'll do a shallow test for where a + channel can be closed waiting for the main loop. + + >>> channel4.last_activity -= 1 + + >>> last_active = channel4.last_activity + + >>> channel4.set_async() + + >>> channel4.last_activity != last_active + True + +""" + + +def test_suite(): + return doctest.DocTestSuite() diff --git a/libs/waitress/tests/test_runner.py b/libs/waitress/tests/test_runner.py new file mode 100644 index 000000000..127757e15 --- /dev/null +++ b/libs/waitress/tests/test_runner.py @@ -0,0 +1,191 @@ +import contextlib +import os +import sys + +if sys.version_info[:2] == (2, 6): # pragma: no cover + import unittest2 as unittest +else: # pragma: no cover + import unittest + +from waitress import runner + + +class Test_match(unittest.TestCase): + def test_empty(self): + self.assertRaisesRegexp( + ValueError, "^Malformed application ''$", runner.match, "" + ) + + def test_module_only(self): + self.assertRaisesRegexp( + ValueError, r"^Malformed application 'foo\.bar'$", runner.match, "foo.bar" + ) + + def test_bad_module(self): + self.assertRaisesRegexp( + ValueError, + r"^Malformed application 'foo#bar:barney'$", + runner.match, + "foo#bar:barney", + ) + + def test_module_obj(self): + self.assertTupleEqual( + runner.match("foo.bar:fred.barney"), ("foo.bar", "fred.barney") + ) + + +class Test_resolve(unittest.TestCase): + def test_bad_module(self): + self.assertRaises( + ImportError, runner.resolve, "nonexistent", "nonexistent_function" + ) + + def test_nonexistent_function(self): + self.assertRaisesRegexp( + AttributeError, + r"has no attribute 'nonexistent_function'", + runner.resolve, + "os.path", + "nonexistent_function", + ) + + def test_simple_happy_path(self): + from os.path import exists + + self.assertIs(runner.resolve("os.path", "exists"), exists) + + def test_complex_happy_path(self): + # Ensure we can recursively resolve object attributes if necessary. + self.assertEquals(runner.resolve("os.path", "exists.__name__"), "exists") + + +class Test_run(unittest.TestCase): + def match_output(self, argv, code, regex): + argv = ["waitress-serve"] + argv + with capture() as captured: + self.assertEqual(runner.run(argv=argv), code) + self.assertRegexpMatches(captured.getvalue(), regex) + captured.close() + + def test_bad(self): + self.match_output(["--bad-opt"], 1, "^Error: option --bad-opt not recognized") + + def test_help(self): + self.match_output(["--help"], 0, "^Usage:\n\n waitress-serve") + + def test_no_app(self): + self.match_output([], 1, "^Error: Specify one application only") + + def test_multiple_apps_app(self): + self.match_output(["a:a", "b:b"], 1, "^Error: Specify one application only") + + def test_bad_apps_app(self): + self.match_output(["a"], 1, "^Error: Malformed application 'a'") + + def test_bad_app_module(self): + self.match_output(["nonexistent:a"], 1, "^Error: Bad module 'nonexistent'") + + self.match_output( + ["nonexistent:a"], + 1, + ( + r"There was an exception \((ImportError|ModuleNotFoundError)\) " + "importing your module.\n\nIt had these arguments: \n" + "1. No module named '?nonexistent'?" + ), + ) + + def test_cwd_added_to_path(self): + def null_serve(app, **kw): + pass + + sys_path = sys.path + current_dir = os.getcwd() + try: + os.chdir(os.path.dirname(__file__)) + argv = [ + "waitress-serve", + "fixtureapps.runner:app", + ] + self.assertEqual(runner.run(argv=argv, _serve=null_serve), 0) + finally: + sys.path = sys_path + os.chdir(current_dir) + + def test_bad_app_object(self): + self.match_output( + ["waitress.tests.fixtureapps.runner:a"], 1, "^Error: Bad object name 'a'" + ) + + def test_simple_call(self): + import waitress.tests.fixtureapps.runner as _apps + + def check_server(app, **kw): + self.assertIs(app, _apps.app) + self.assertDictEqual(kw, {"port": "80"}) + + argv = [ + "waitress-serve", + "--port=80", + "waitress.tests.fixtureapps.runner:app", + ] + self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) + + def test_returned_app(self): + import waitress.tests.fixtureapps.runner as _apps + + def check_server(app, **kw): + self.assertIs(app, _apps.app) + self.assertDictEqual(kw, {"port": "80"}) + + argv = [ + "waitress-serve", + "--port=80", + "--call", + "waitress.tests.fixtureapps.runner:returns_app", + ] + self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) + + +class Test_helper(unittest.TestCase): + def test_exception_logging(self): + from waitress.runner import show_exception + + regex = ( + r"There was an exception \(ImportError\) importing your module." + r"\n\nIt had these arguments: \n1. My reason" + ) + + with capture() as captured: + try: + raise ImportError("My reason") + except ImportError: + self.assertEqual(show_exception(sys.stderr), None) + self.assertRegexpMatches(captured.getvalue(), regex) + captured.close() + + regex = ( + r"There was an exception \(ImportError\) importing your module." + r"\n\nIt had no arguments." + ) + + with capture() as captured: + try: + raise ImportError + except ImportError: + self.assertEqual(show_exception(sys.stderr), None) + self.assertRegexpMatches(captured.getvalue(), regex) + captured.close() + + +def capture(): + from waitress.compat import NativeIO + + fd = NativeIO() + sys.stdout = fd + sys.stderr = fd + yield fd + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ diff --git a/libs/waitress/tests/test_server.py b/libs/waitress/tests/test_server.py new file mode 100644 index 000000000..9134fb8c1 --- /dev/null +++ b/libs/waitress/tests/test_server.py @@ -0,0 +1,533 @@ +import errno +import socket +import unittest + +dummy_app = object() + + +class TestWSGIServer(unittest.TestCase): + def _makeOne( + self, + application=dummy_app, + host="127.0.0.1", + port=0, + _dispatcher=None, + adj=None, + map=None, + _start=True, + _sock=None, + _server=None, + ): + from waitress.server import create_server + + self.inst = create_server( + application, + host=host, + port=port, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + ) + return self.inst + + def _makeOneWithMap( + self, adj=None, _start=True, host="127.0.0.1", port=0, app=dummy_app + ): + sock = DummySock() + task_dispatcher = DummyTaskDispatcher() + map = {} + return self._makeOne( + app, + host=host, + port=port, + map=map, + _sock=sock, + _dispatcher=task_dispatcher, + _start=_start, + ) + + def _makeOneWithMulti( + self, adj=None, _start=True, app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0" + ): + sock = DummySock() + task_dispatcher = DummyTaskDispatcher() + map = {} + from waitress.server import create_server + + self.inst = create_server( + app, + listen=listen, + map=map, + _dispatcher=task_dispatcher, + _start=_start, + _sock=sock, + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + if self.inst is not None: + self.inst.close() + + def test_ctor_app_is_None(self): + self.inst = None + self.assertRaises(ValueError, self._makeOneWithMap, app=None) + + def test_ctor_start_true(self): + inst = self._makeOneWithMap(_start=True) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + + def test_ctor_makes_dispatcher(self): + inst = self._makeOne(_start=False, map={}) + self.assertEqual( + inst.task_dispatcher.__class__.__name__, "ThreadedTaskDispatcher" + ) + + def test_ctor_start_false(self): + inst = self._makeOneWithMap(_start=False) + self.assertEqual(inst.accepting, False) + + def test_get_server_name_empty(self): + inst = self._makeOneWithMap(_start=False) + self.assertRaises(ValueError, inst.get_server_name, "") + + def test_get_server_name_with_ip(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("127.0.0.1") + self.assertTrue(result) + + def test_get_server_name_with_hostname(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("fred.flintstone.com") + self.assertEqual(result, "fred.flintstone.com") + + def test_get_server_name_0000(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("0.0.0.0") + self.assertTrue(len(result) != 0) + + def test_get_server_name_double_colon(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("::") + self.assertTrue(len(result) != 0) + + def test_get_server_name_ipv6(self): + inst = self._makeOneWithMap(_start=False) + result = inst.get_server_name("2001:DB8::ffff") + self.assertEqual("[2001:DB8::ffff]", result) + + def test_get_server_multi(self): + inst = self._makeOneWithMulti() + self.assertEqual(inst.__class__.__name__, "MultiSocketServer") + + def test_run(self): + inst = self._makeOneWithMap(_start=False) + inst.asyncore = DummyAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + inst.run() + self.assertTrue(inst.task_dispatcher.was_shutdown) + + def test_run_base_server(self): + inst = self._makeOneWithMulti(_start=False) + inst.asyncore = DummyAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + inst.run() + self.assertTrue(inst.task_dispatcher.was_shutdown) + + def test_pull_trigger(self): + inst = self._makeOneWithMap(_start=False) + inst.trigger.close() + inst.trigger = DummyTrigger() + inst.pull_trigger() + self.assertEqual(inst.trigger.pulled, True) + + def test_add_task(self): + task = DummyTask() + inst = self._makeOneWithMap() + inst.add_task(task) + self.assertEqual(inst.task_dispatcher.tasks, [task]) + self.assertFalse(task.serviced) + + def test_readable_not_accepting(self): + inst = self._makeOneWithMap() + inst.accepting = False + self.assertFalse(inst.readable()) + + def test_readable_maplen_gt_connection_limit(self): + inst = self._makeOneWithMap() + inst.accepting = True + inst.adj = DummyAdj + inst._map = {"a": 1, "b": 2} + self.assertFalse(inst.readable()) + + def test_readable_maplen_lt_connection_limit(self): + inst = self._makeOneWithMap() + inst.accepting = True + inst.adj = DummyAdj + inst._map = {} + self.assertTrue(inst.readable()) + + def test_readable_maintenance_false(self): + import time + + inst = self._makeOneWithMap() + then = time.time() + 1000 + inst.next_channel_cleanup = then + L = [] + inst.maintenance = lambda t: L.append(t) + inst.readable() + self.assertEqual(L, []) + self.assertEqual(inst.next_channel_cleanup, then) + + def test_readable_maintenance_true(self): + inst = self._makeOneWithMap() + inst.next_channel_cleanup = 0 + L = [] + inst.maintenance = lambda t: L.append(t) + inst.readable() + self.assertEqual(len(L), 1) + self.assertNotEqual(inst.next_channel_cleanup, 0) + + def test_writable(self): + inst = self._makeOneWithMap() + self.assertFalse(inst.writable()) + + def test_handle_read(self): + inst = self._makeOneWithMap() + self.assertEqual(inst.handle_read(), None) + + def test_handle_connect(self): + inst = self._makeOneWithMap() + self.assertEqual(inst.handle_connect(), None) + + def test_handle_accept_wouldblock_socket_error(self): + inst = self._makeOneWithMap() + ewouldblock = socket.error(errno.EWOULDBLOCK) + inst.socket = DummySock(toraise=ewouldblock) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, False) + + def test_handle_accept_other_socket_error(self): + inst = self._makeOneWithMap() + eaborted = socket.error(errno.ECONNABORTED) + inst.socket = DummySock(toraise=eaborted) + inst.adj = DummyAdj + + def foo(): + raise socket.error + + inst.accept = foo + inst.logger = DummyLogger() + inst.handle_accept() + self.assertEqual(inst.socket.accepted, False) + self.assertEqual(len(inst.logger.logged), 1) + + def test_handle_accept_noerror(self): + inst = self._makeOneWithMap() + innersock = DummySock() + inst.socket = DummySock(acceptresult=(innersock, None)) + inst.adj = DummyAdj + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + + def test_maintenance(self): + inst = self._makeOneWithMap() + + class DummyChannel(object): + requests = [] + + zombie = DummyChannel() + zombie.last_activity = 0 + zombie.running_tasks = False + inst.active_channels[100] = zombie + inst.maintenance(10000) + self.assertEqual(zombie.will_close, True) + + def test_backward_compatibility(self): + from waitress.server import WSGIServer, TcpWSGIServer + from waitress.adjustments import Adjustments + + self.assertTrue(WSGIServer is TcpWSGIServer) + self.inst = WSGIServer(None, _start=False, port=1234) + # Ensure the adjustment was actually applied. + self.assertNotEqual(Adjustments.port, 1234) + self.assertEqual(self.inst.adj.port, 1234) + + def test_create_with_one_tcp_socket(self): + from waitress.server import TcpWSGIServer + + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, TcpWSGIServer)) + + def test_create_with_multiple_tcp_sockets(self): + from waitress.server import MultiSocketServer + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ] + sockets[0].bind(("127.0.0.1", 0)) + sockets[1].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, MultiSocketServer)) + self.assertEqual(len(inst.effective_listen), 2) + + def test_create_with_one_socket_should_not_bind_socket(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + sockets[0].bind_called = False + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertEqual(inst.socket.bound, ("127.0.0.1", 80)) + self.assertFalse(inst.socket.bind_called) + + def test_create_with_one_socket_handle_accept_noerror(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + inst = self._makeWithSockets(sockets=sockets) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.adj = DummyAdj + inst.handle_accept() + self.assertEqual(sockets[0].accepted, True) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + + +if hasattr(socket, "AF_UNIX"): + + class TestUnixWSGIServer(unittest.TestCase): + unix_socket = "/tmp/waitress.test.sock" + + def _makeOne(self, _start=True, _sock=None): + from waitress.server import create_server + + self.inst = create_server( + dummy_app, + map={}, + _start=_start, + _sock=_sock, + _dispatcher=DummyTaskDispatcher(), + unix_socket=self.unix_socket, + unix_socket_perms="600", + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + self.inst.close() + + def _makeDummy(self, *args, **kwargs): + sock = DummySock(*args, **kwargs) + sock.family = socket.AF_UNIX + return sock + + def test_unix(self): + inst = self._makeOne(_start=False) + self.assertEqual(inst.socket.family, socket.AF_UNIX) + self.assertEqual(inst.socket.getsockname(), self.unix_socket) + + def test_handle_accept(self): + # Working on the assumption that we only have to test the happy path + # for Unix domain sockets as the other paths should've been covered + # by inet sockets. + client = self._makeDummy() + listen = self._makeDummy(acceptresult=(client, None)) + inst = self._makeOne(_sock=listen) + self.assertEqual(inst.accepting, True) + self.assertEqual(inst.socket.listened, 1024) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.handle_accept() + self.assertEqual(inst.socket.accepted, True) + self.assertEqual(client.opts, []) + self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) + + def test_creates_new_sockinfo(self): + from waitress.server import UnixWSGIServer + + self.inst = UnixWSGIServer( + dummy_app, unix_socket=self.unix_socket, unix_socket_perms="600" + ) + + self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) + + def test_create_with_unix_socket(self): + from waitress.server import ( + MultiSocketServer, + BaseWSGIServer, + TcpWSGIServer, + UnixWSGIServer, + ) + + sockets = [ + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + ] + inst = self._makeWithSockets(sockets=sockets, _start=False) + self.assertTrue(isinstance(inst, MultiSocketServer)) + server = list( + filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) + ) + self.assertTrue(isinstance(server[0], UnixWSGIServer)) + self.assertTrue(isinstance(server[1], UnixWSGIServer)) + + +class DummySock(socket.socket): + accepted = False + blocking = False + family = socket.AF_INET + type = socket.SOCK_STREAM + proto = 0 + + def __init__(self, toraise=None, acceptresult=(None, None)): + self.toraise = toraise + self.acceptresult = acceptresult + self.bound = None + self.opts = [] + self.bind_called = False + + def bind(self, addr): + self.bind_called = True + self.bound = addr + + def accept(self): + if self.toraise: + raise self.toraise + self.accepted = True + return self.acceptresult + + def setblocking(self, x): + self.blocking = True + + def fileno(self): + return 10 + + def getpeername(self): + return "127.0.0.1" + + def setsockopt(self, *arg): + self.opts.append(arg) + + def getsockopt(self, *arg): + return 1 + + def listen(self, num): + self.listened = num + + def getsockname(self): + return self.bound + + def close(self): + pass + + +class DummyTaskDispatcher(object): + def __init__(self): + self.tasks = [] + + def add_task(self, task): + self.tasks.append(task) + + def shutdown(self): + self.was_shutdown = True + + +class DummyTask(object): + serviced = False + start_response_called = False + wrote_header = False + status = "200 OK" + + def __init__(self): + self.response_headers = {} + self.written = "" + + def service(self): # pragma: no cover + self.serviced = True + + +class DummyAdj: + connection_limit = 1 + log_socket_errors = True + socket_options = [("level", "optname", "value")] + cleanup_interval = 900 + channel_timeout = 300 + + +class DummyAsyncore(object): + def loop(self, timeout=30.0, use_poll=False, map=None, count=None): + raise SystemExit + + +class DummyTrigger(object): + def pull_trigger(self): + self.pulled = True + + def close(self): + pass + + +class DummyLogger(object): + def __init__(self): + self.logged = [] + + def warning(self, msg, **kw): + self.logged.append(msg) diff --git a/libs/waitress/tests/test_task.py b/libs/waitress/tests/test_task.py new file mode 100644 index 000000000..1a86245ab --- /dev/null +++ b/libs/waitress/tests/test_task.py @@ -0,0 +1,1001 @@ +import unittest +import io + + +class TestThreadedTaskDispatcher(unittest.TestCase): + def _makeOne(self): + from waitress.task import ThreadedTaskDispatcher + + return ThreadedTaskDispatcher() + + def test_handler_thread_task_raises(self): + inst = self._makeOne() + inst.threads.add(0) + inst.logger = DummyLogger() + + class BadDummyTask(DummyTask): + def service(self): + super(BadDummyTask, self).service() + inst.stop_count += 1 + raise Exception + + task = BadDummyTask() + inst.logger = DummyLogger() + inst.queue.append(task) + inst.active_count += 1 + inst.handler_thread(0) + self.assertEqual(inst.stop_count, 0) + self.assertEqual(inst.active_count, 0) + self.assertEqual(inst.threads, set()) + self.assertEqual(len(inst.logger.logged), 1) + + def test_set_thread_count_increase(self): + inst = self._makeOne() + L = [] + inst.start_new_thread = lambda *x: L.append(x) + inst.set_thread_count(1) + self.assertEqual(L, [(inst.handler_thread, (0,))]) + + def test_set_thread_count_increase_with_existing(self): + inst = self._makeOne() + L = [] + inst.threads = {0} + inst.start_new_thread = lambda *x: L.append(x) + inst.set_thread_count(2) + self.assertEqual(L, [(inst.handler_thread, (1,))]) + + def test_set_thread_count_decrease(self): + inst = self._makeOne() + inst.threads = {0, 1} + inst.set_thread_count(1) + self.assertEqual(inst.stop_count, 1) + + def test_set_thread_count_same(self): + inst = self._makeOne() + L = [] + inst.start_new_thread = lambda *x: L.append(x) + inst.threads = {0} + inst.set_thread_count(1) + self.assertEqual(L, []) + + def test_add_task_with_idle_threads(self): + task = DummyTask() + inst = self._makeOne() + inst.threads.add(0) + inst.queue_logger = DummyLogger() + inst.add_task(task) + self.assertEqual(len(inst.queue), 1) + self.assertEqual(len(inst.queue_logger.logged), 0) + + def test_add_task_with_all_busy_threads(self): + task = DummyTask() + inst = self._makeOne() + inst.queue_logger = DummyLogger() + inst.add_task(task) + self.assertEqual(len(inst.queue_logger.logged), 1) + inst.add_task(task) + self.assertEqual(len(inst.queue_logger.logged), 2) + + def test_shutdown_one_thread(self): + inst = self._makeOne() + inst.threads.add(0) + inst.logger = DummyLogger() + task = DummyTask() + inst.queue.append(task) + self.assertEqual(inst.shutdown(timeout=0.01), True) + self.assertEqual( + inst.logger.logged, + ["1 thread(s) still running", "Canceling 1 pending task(s)",], + ) + self.assertEqual(task.cancelled, True) + + def test_shutdown_no_threads(self): + inst = self._makeOne() + self.assertEqual(inst.shutdown(timeout=0.01), True) + + def test_shutdown_no_cancel_pending(self): + inst = self._makeOne() + self.assertEqual(inst.shutdown(cancel_pending=False, timeout=0.01), False) + + +class TestTask(unittest.TestCase): + def _makeOne(self, channel=None, request=None): + if channel is None: + channel = DummyChannel() + if request is None: + request = DummyParser() + from waitress.task import Task + + return Task(channel, request) + + def test_ctor_version_not_in_known(self): + request = DummyParser() + request.version = "8.4" + inst = self._makeOne(request=request) + self.assertEqual(inst.version, "1.0") + + def test_build_response_header_bad_http_version(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "8.4" + self.assertRaises(AssertionError, inst.build_response_header) + + def test_build_response_header_v10_keepalive_no_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.request.headers["CONNECTION"] = "keep-alive" + inst.version = "1.0" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v10_keepalive_with_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.request.headers["CONNECTION"] = "keep-alive" + inst.response_headers = [("Content-Length", "10")] + inst.version = "1.0" + inst.content_length = 0 + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: Keep-Alive") + self.assertEqual(lines[2], b"Content-Length: 10") + self.assertTrue(lines[3].startswith(b"Date:")) + self.assertEqual(lines[4], b"Server: waitress") + self.assertEqual(inst.close_on_finish, False) + + def test_build_response_header_v11_connection_closed_by_client(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "close" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertTrue(("Connection", "close") in inst.response_headers) + self.assertEqual(inst.close_on_finish, True) + + def test_build_response_header_v11_connection_keepalive_by_client(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.request.headers["CONNECTION"] = "keep-alive" + inst.version = "1.1" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertTrue(("Connection", "close") in inst.response_headers) + self.assertEqual(inst.close_on_finish, True) + + def test_build_response_header_v11_200_no_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_204_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx or 204. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "204 No Content" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 204 No Content") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_1xx_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx or 204. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "100 Continue" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 100 Continue") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_304_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx, 204 or 304. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "304 Not Modified" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 304 Not Modified") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_via_added(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.0" + inst.response_headers = [("Server", "abc")] + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 5) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: abc") + self.assertEqual(lines[4], b"Via: waitress") + + def test_build_response_header_date_exists(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.0" + inst.response_headers = [("Date", "date")] + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + + def test_build_response_header_preexisting_content_length(self): + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.content_length = 100 + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Content-Length: 100") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + + def test_remove_content_length_header(self): + inst = self._makeOne() + inst.response_headers = [("Content-Length", "70")] + inst.remove_content_length_header() + self.assertEqual(inst.response_headers, []) + + def test_remove_content_length_header_with_other(self): + inst = self._makeOne() + inst.response_headers = [ + ("Content-Length", "70"), + ("Content-Type", "text/html"), + ] + inst.remove_content_length_header() + self.assertEqual(inst.response_headers, [("Content-Type", "text/html")]) + + def test_start(self): + inst = self._makeOne() + inst.start() + self.assertTrue(inst.start_time) + + def test_finish_didnt_write_header(self): + inst = self._makeOne() + inst.wrote_header = False + inst.complete = True + inst.finish() + self.assertTrue(inst.channel.written) + + def test_finish_wrote_header(self): + inst = self._makeOne() + inst.wrote_header = True + inst.finish() + self.assertFalse(inst.channel.written) + + def test_finish_chunked_response(self): + inst = self._makeOne() + inst.wrote_header = True + inst.chunked_response = True + inst.finish() + self.assertEqual(inst.channel.written, b"0\r\n\r\n") + + def test_write_wrote_header(self): + inst = self._makeOne() + inst.wrote_header = True + inst.complete = True + inst.content_length = 3 + inst.write(b"abc") + self.assertEqual(inst.channel.written, b"abc") + + def test_write_header_not_written(self): + inst = self._makeOne() + inst.wrote_header = False + inst.complete = True + inst.write(b"abc") + self.assertTrue(inst.channel.written) + self.assertEqual(inst.wrote_header, True) + + def test_write_start_response_uncalled(self): + inst = self._makeOne() + self.assertRaises(RuntimeError, inst.write, b"") + + def test_write_chunked_response(self): + inst = self._makeOne() + inst.wrote_header = True + inst.chunked_response = True + inst.complete = True + inst.write(b"abc") + self.assertEqual(inst.channel.written, b"3\r\nabc\r\n") + + def test_write_preexisting_content_length(self): + inst = self._makeOne() + inst.wrote_header = True + inst.complete = True + inst.content_length = 1 + inst.logger = DummyLogger() + inst.write(b"abc") + self.assertTrue(inst.channel.written) + self.assertEqual(inst.logged_write_excess, True) + self.assertEqual(len(inst.logger.logged), 1) + + +class TestWSGITask(unittest.TestCase): + def _makeOne(self, channel=None, request=None): + if channel is None: + channel = DummyChannel() + if request is None: + request = DummyParser() + from waitress.task import WSGITask + + return WSGITask(channel, request) + + def test_service(self): + inst = self._makeOne() + + def execute(): + inst.executed = True + + inst.execute = execute + inst.complete = True + inst.service() + self.assertTrue(inst.start_time) + self.assertTrue(inst.close_on_finish) + self.assertTrue(inst.channel.written) + self.assertEqual(inst.executed, True) + + def test_service_server_raises_socket_error(self): + import socket + + inst = self._makeOne() + + def execute(): + raise socket.error + + inst.execute = execute + self.assertRaises(socket.error, inst.service) + self.assertTrue(inst.start_time) + self.assertTrue(inst.close_on_finish) + self.assertFalse(inst.channel.written) + + def test_execute_app_calls_start_response_twice_wo_exc_info(self): + def app(environ, start_response): + start_response("200 OK", []) + start_response("200 OK", []) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_app_calls_start_response_w_exc_info_complete(self): + def app(environ, start_response): + start_response("200 OK", [], [ValueError, ValueError(), None]) + return [b"a"] + + inst = self._makeOne() + inst.complete = True + inst.channel.server.application = app + inst.execute() + self.assertTrue(inst.complete) + self.assertEqual(inst.status, "200 OK") + self.assertTrue(inst.channel.written) + + def test_execute_app_calls_start_response_w_excinf_headers_unwritten(self): + def app(environ, start_response): + start_response("200 OK", [], [ValueError, None, None]) + return [b"a"] + + inst = self._makeOne() + inst.wrote_header = False + inst.channel.server.application = app + inst.response_headers = [("a", "b")] + inst.execute() + self.assertTrue(inst.complete) + self.assertEqual(inst.status, "200 OK") + self.assertTrue(inst.channel.written) + self.assertFalse(("a", "b") in inst.response_headers) + + def test_execute_app_calls_start_response_w_excinf_headers_written(self): + def app(environ, start_response): + start_response("200 OK", [], [ValueError, ValueError(), None]) + + inst = self._makeOne() + inst.complete = True + inst.wrote_header = True + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_header_key(self): + def app(environ, start_response): + start_response("200 OK", [(None, "a")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_bad_header_value(self): + def app(environ, start_response): + start_response("200 OK", [("a", None)]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_hopbyhop_header(self): + def app(environ, start_response): + start_response("200 OK", [("Connection", "close")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_bad_header_value_control_characters(self): + def app(environ, start_response): + start_response("200 OK", [("a", "\n")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_header_name_control_characters(self): + def app(environ, start_response): + start_response("200 OK", [("a\r", "value")]) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_status_control_characters(self): + def app(environ, start_response): + start_response("200 OK\r", []) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_preserve_header_value_order(self): + def app(environ, start_response): + write = start_response("200 OK", [("C", "b"), ("A", "b"), ("A", "a")]) + write(b"abc") + return [] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertTrue(b"A: b\r\nA: a\r\nC: b\r\n" in inst.channel.written) + + def test_execute_bad_status_value(self): + def app(environ, start_response): + start_response(None, []) + + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(AssertionError, inst.execute) + + def test_execute_with_content_length_header(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "1")]) + return [b"a"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.content_length, 1) + + def test_execute_app_calls_write(self): + def app(environ, start_response): + write = start_response("200 OK", [("Content-Length", "3")]) + write(b"abc") + return [] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.channel.written[-3:], b"abc") + + def test_execute_app_returns_len1_chunk_without_cl(self): + def app(environ, start_response): + start_response("200 OK", []) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.content_length, 3) + + def test_execute_app_returns_empty_chunk_as_first(self): + def app(environ, start_response): + start_response("200 OK", []) + return ["", b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(inst.content_length, None) + + def test_execute_app_returns_too_many_bytes(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "1")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_returns_too_few_bytes(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return [b"a"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_do_not_warn_on_head(self): + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return [b""] + + inst = self._makeOne() + inst.request.command = "HEAD" + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertEqual(len(inst.logger.logged), 0) + + def test_execute_app_without_body_204_logged(self): + def app(environ, start_response): + start_response("204 No Content", [("Content-Length", "3")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertNotIn(b"abc", inst.channel.written) + self.assertNotIn(b"Content-Length", inst.channel.written) + self.assertNotIn(b"Transfer-Encoding", inst.channel.written) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_without_body_304_logged(self): + def app(environ, start_response): + start_response("304 Not Modified", [("Content-Length", "3")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertNotIn(b"abc", inst.channel.written) + self.assertNotIn(b"Content-Length", inst.channel.written) + self.assertNotIn(b"Transfer-Encoding", inst.channel.written) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_returns_closeable(self): + class closeable(list): + def close(self): + self.closed = True + + foo = closeable([b"abc"]) + + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return foo + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertEqual(foo.closed, True) + + def test_execute_app_returns_filewrapper_prepare_returns_True(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + app_iter = ReadOnlyFileBasedBuffer(f, 8192) + + def app(environ, start_response): + start_response("200 OK", [("Content-Length", "3")]) + return app_iter + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertTrue(inst.channel.written) # header + self.assertEqual(inst.channel.otherdata, [app_iter]) + + def test_execute_app_returns_filewrapper_prepare_returns_True_nocl(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + app_iter = ReadOnlyFileBasedBuffer(f, 8192) + + def app(environ, start_response): + start_response("200 OK", []) + return app_iter + + inst = self._makeOne() + inst.channel.server.application = app + inst.execute() + self.assertTrue(inst.channel.written) # header + self.assertEqual(inst.channel.otherdata, [app_iter]) + self.assertEqual(inst.content_length, 3) + + def test_execute_app_returns_filewrapper_prepare_returns_True_badcl(self): + from waitress.buffers import ReadOnlyFileBasedBuffer + + f = io.BytesIO(b"abc") + app_iter = ReadOnlyFileBasedBuffer(f, 8192) + + def app(environ, start_response): + start_response("200 OK", []) + return app_iter + + inst = self._makeOne() + inst.channel.server.application = app + inst.content_length = 10 + inst.response_headers = [("Content-Length", "10")] + inst.execute() + self.assertTrue(inst.channel.written) # header + self.assertEqual(inst.channel.otherdata, [app_iter]) + self.assertEqual(inst.content_length, 3) + self.assertEqual(dict(inst.response_headers)["Content-Length"], "3") + + def test_get_environment_already_cached(self): + inst = self._makeOne() + inst.environ = object() + self.assertEqual(inst.get_environment(), inst.environ) + + def test_get_environment_path_startswith_more_than_one_slash(self): + inst = self._makeOne() + request = DummyParser() + request.path = "///abc" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "/abc") + + def test_get_environment_path_empty(self): + inst = self._makeOne() + request = DummyParser() + request.path = "" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "") + + def test_get_environment_no_query(self): + inst = self._makeOne() + request = DummyParser() + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["QUERY_STRING"], "") + + def test_get_environment_with_query(self): + inst = self._makeOne() + request = DummyParser() + request.query = "abc" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["QUERY_STRING"], "abc") + + def test_get_environ_with_url_prefix_miss(self): + inst = self._makeOne() + inst.channel.server.adj.url_prefix = "/foo" + request = DummyParser() + request.path = "/bar" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "/bar") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") + + def test_get_environ_with_url_prefix_hit(self): + inst = self._makeOne() + inst.channel.server.adj.url_prefix = "/foo" + request = DummyParser() + request.path = "/foo/fuz" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "/fuz") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") + + def test_get_environ_with_url_prefix_empty_path(self): + inst = self._makeOne() + inst.channel.server.adj.url_prefix = "/foo" + request = DummyParser() + request.path = "/foo" + inst.request = request + environ = inst.get_environment() + self.assertEqual(environ["PATH_INFO"], "") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") + + def test_get_environment_values(self): + import sys + + inst = self._makeOne() + request = DummyParser() + request.headers = { + "CONTENT_TYPE": "abc", + "CONTENT_LENGTH": "10", + "X_FOO": "BAR", + "CONNECTION": "close", + } + request.query = "abc" + inst.request = request + environ = inst.get_environment() + + # nail the keys of environ + self.assertEqual( + sorted(environ.keys()), + [ + "CONTENT_LENGTH", + "CONTENT_TYPE", + "HTTP_CONNECTION", + "HTTP_X_FOO", + "PATH_INFO", + "QUERY_STRING", + "REMOTE_ADDR", + "REMOTE_HOST", + "REMOTE_PORT", + "REQUEST_METHOD", + "SCRIPT_NAME", + "SERVER_NAME", + "SERVER_PORT", + "SERVER_PROTOCOL", + "SERVER_SOFTWARE", + "wsgi.errors", + "wsgi.file_wrapper", + "wsgi.input", + "wsgi.input_terminated", + "wsgi.multiprocess", + "wsgi.multithread", + "wsgi.run_once", + "wsgi.url_scheme", + "wsgi.version", + ], + ) + + self.assertEqual(environ["REQUEST_METHOD"], "GET") + self.assertEqual(environ["SERVER_PORT"], "80") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["SERVER_SOFTWARE"], "waitress") + self.assertEqual(environ["SERVER_PROTOCOL"], "HTTP/1.0") + self.assertEqual(environ["SCRIPT_NAME"], "") + self.assertEqual(environ["HTTP_CONNECTION"], "close") + self.assertEqual(environ["PATH_INFO"], "/") + self.assertEqual(environ["QUERY_STRING"], "abc") + self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") + self.assertEqual(environ["REMOTE_HOST"], "127.0.0.1") + self.assertEqual(environ["REMOTE_PORT"], "39830") + self.assertEqual(environ["CONTENT_TYPE"], "abc") + self.assertEqual(environ["CONTENT_LENGTH"], "10") + self.assertEqual(environ["HTTP_X_FOO"], "BAR") + self.assertEqual(environ["wsgi.version"], (1, 0)) + self.assertEqual(environ["wsgi.url_scheme"], "http") + self.assertEqual(environ["wsgi.errors"], sys.stderr) + self.assertEqual(environ["wsgi.multithread"], True) + self.assertEqual(environ["wsgi.multiprocess"], False) + self.assertEqual(environ["wsgi.run_once"], False) + self.assertEqual(environ["wsgi.input"], "stream") + self.assertEqual(environ["wsgi.input_terminated"], True) + self.assertEqual(inst.environ, environ) + + +class TestErrorTask(unittest.TestCase): + def _makeOne(self, channel=None, request=None): + if channel is None: + channel = DummyChannel() + if request is None: + request = DummyParser() + request.error = self._makeDummyError() + from waitress.task import ErrorTask + + return ErrorTask(channel, request) + + def _makeDummyError(self): + from waitress.utilities import Error + + e = Error("body") + e.code = 432 + e.reason = "Too Ugly" + return e + + def test_execute_http_10(self): + inst = self._makeOne() + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.0 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + def test_execute_http_11(self): + inst = self._makeOne() + inst.version = "1.1" + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + def test_execute_http_11_close(self): + inst = self._makeOne() + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "close" + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + def test_execute_http_11_keep_forces_close(self): + inst = self._makeOne() + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "keep-alive" + inst.execute() + lines = filter_lines(inst.channel.written) + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + +class DummyTask(object): + serviced = False + cancelled = False + + def service(self): + self.serviced = True + + def cancel(self): + self.cancelled = True + + +class DummyAdj(object): + log_socket_errors = True + ident = "waitress" + host = "127.0.0.1" + port = 80 + url_prefix = "" + + +class DummyServer(object): + server_name = "localhost" + effective_port = 80 + + def __init__(self): + self.adj = DummyAdj() + + +class DummyChannel(object): + closed_when_done = False + adj = DummyAdj() + creation_time = 0 + addr = ("127.0.0.1", 39830) + + def __init__(self, server=None): + if server is None: + server = DummyServer() + self.server = server + self.written = b"" + self.otherdata = [] + + def write_soon(self, data): + if isinstance(data, bytes): + self.written += data + else: + self.otherdata.append(data) + return len(data) + + +class DummyParser(object): + version = "1.0" + command = "GET" + path = "/" + query = "" + url_scheme = "http" + expect_continue = False + headers_finished = False + + def __init__(self): + self.headers = {} + + def get_body_stream(self): + return "stream" + + +def filter_lines(s): + return list(filter(None, s.split(b"\r\n"))) + + +class DummyLogger(object): + def __init__(self): + self.logged = [] + + def warning(self, msg, *args): + self.logged.append(msg % args) + + def exception(self, msg, *args): + self.logged.append(msg % args) diff --git a/libs/waitress/tests/test_trigger.py b/libs/waitress/tests/test_trigger.py new file mode 100644 index 000000000..af740f68d --- /dev/null +++ b/libs/waitress/tests/test_trigger.py @@ -0,0 +1,111 @@ +import unittest +import os +import sys + +if not sys.platform.startswith("win"): + + class Test_trigger(unittest.TestCase): + def _makeOne(self, map): + from waitress.trigger import trigger + + self.inst = trigger(map) + return self.inst + + def tearDown(self): + self.inst.close() # prevent __del__ warning from file_dispatcher + + def test__close(self): + map = {} + inst = self._makeOne(map) + fd1, fd2 = inst._fds + inst.close() + self.assertRaises(OSError, os.read, fd1, 1) + self.assertRaises(OSError, os.read, fd2, 1) + + def test__physical_pull(self): + map = {} + inst = self._makeOne(map) + inst._physical_pull() + r = os.read(inst._fds[0], 1) + self.assertEqual(r, b"x") + + def test_readable(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.readable(), True) + + def test_writable(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.writable(), False) + + def test_handle_connect(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.handle_connect(), None) + + def test_close(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.close(), None) + self.assertEqual(inst._closed, True) + + def test_handle_close(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.handle_close(), None) + self.assertEqual(inst._closed, True) + + def test_pull_trigger_nothunk(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.pull_trigger(), None) + r = os.read(inst._fds[0], 1) + self.assertEqual(r, b"x") + + def test_pull_trigger_thunk(self): + map = {} + inst = self._makeOne(map) + self.assertEqual(inst.pull_trigger(True), None) + self.assertEqual(len(inst.thunks), 1) + r = os.read(inst._fds[0], 1) + self.assertEqual(r, b"x") + + def test_handle_read_socket_error(self): + map = {} + inst = self._makeOne(map) + result = inst.handle_read() + self.assertEqual(result, None) + + def test_handle_read_no_socket_error(self): + map = {} + inst = self._makeOne(map) + inst.pull_trigger() + result = inst.handle_read() + self.assertEqual(result, None) + + def test_handle_read_thunk(self): + map = {} + inst = self._makeOne(map) + inst.pull_trigger() + L = [] + inst.thunks = [lambda: L.append(True)] + result = inst.handle_read() + self.assertEqual(result, None) + self.assertEqual(L, [True]) + self.assertEqual(inst.thunks, []) + + def test_handle_read_thunk_error(self): + map = {} + inst = self._makeOne(map) + + def errorthunk(): + raise ValueError + + inst.pull_trigger(errorthunk) + L = [] + inst.log_info = lambda *arg: L.append(arg) + result = inst.handle_read() + self.assertEqual(result, None) + self.assertEqual(len(L), 1) + self.assertEqual(inst.thunks, []) diff --git a/libs/waitress/tests/test_utilities.py b/libs/waitress/tests/test_utilities.py new file mode 100644 index 000000000..15cd24f5a --- /dev/null +++ b/libs/waitress/tests/test_utilities.py @@ -0,0 +1,140 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import unittest + + +class Test_parse_http_date(unittest.TestCase): + def _callFUT(self, v): + from waitress.utilities import parse_http_date + + return parse_http_date(v) + + def test_rfc850(self): + val = "Tuesday, 08-Feb-94 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, 760716929) + + def test_rfc822(self): + val = "Sun, 08 Feb 1994 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, 760716929) + + def test_neither(self): + val = "" + result = self._callFUT(val) + self.assertEqual(result, 0) + + +class Test_build_http_date(unittest.TestCase): + def test_rountdrip(self): + from waitress.utilities import build_http_date, parse_http_date + from time import time + + t = int(time()) + self.assertEqual(t, parse_http_date(build_http_date(t))) + + +class Test_unpack_rfc850(unittest.TestCase): + def _callFUT(self, val): + from waitress.utilities import unpack_rfc850, rfc850_reg + + return unpack_rfc850(rfc850_reg.match(val.lower())) + + def test_it(self): + val = "Tuesday, 08-Feb-94 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) + + +class Test_unpack_rfc_822(unittest.TestCase): + def _callFUT(self, val): + from waitress.utilities import unpack_rfc822, rfc822_reg + + return unpack_rfc822(rfc822_reg.match(val.lower())) + + def test_it(self): + val = "Sun, 08 Feb 1994 14:15:29 GMT" + result = self._callFUT(val) + self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) + + +class Test_find_double_newline(unittest.TestCase): + def _callFUT(self, val): + from waitress.utilities import find_double_newline + + return find_double_newline(val) + + def test_empty(self): + self.assertEqual(self._callFUT(b""), -1) + + def test_one_linefeed(self): + self.assertEqual(self._callFUT(b"\n"), -1) + + def test_double_linefeed(self): + self.assertEqual(self._callFUT(b"\n\n"), -1) + + def test_one_crlf(self): + self.assertEqual(self._callFUT(b"\r\n"), -1) + + def test_double_crfl(self): + self.assertEqual(self._callFUT(b"\r\n\r\n"), 4) + + def test_mixed(self): + self.assertEqual(self._callFUT(b"\n\n00\r\n\r\n"), 8) + + +class TestBadRequest(unittest.TestCase): + def _makeOne(self): + from waitress.utilities import BadRequest + + return BadRequest(1) + + def test_it(self): + inst = self._makeOne() + self.assertEqual(inst.body, 1) + + +class Test_undquote(unittest.TestCase): + def _callFUT(self, value): + from waitress.utilities import undquote + + return undquote(value) + + def test_empty(self): + self.assertEqual(self._callFUT(""), "") + + def test_quoted(self): + self.assertEqual(self._callFUT('"test"'), "test") + + def test_unquoted(self): + self.assertEqual(self._callFUT("test"), "test") + + def test_quoted_backslash_quote(self): + self.assertEqual(self._callFUT('"\\""'), '"') + + def test_quoted_htab(self): + self.assertEqual(self._callFUT('"\t"'), "\t") + + def test_quoted_backslash_htab(self): + self.assertEqual(self._callFUT('"\\\t"'), "\t") + + def test_quoted_backslash_invalid(self): + self.assertRaises(ValueError, self._callFUT, '"\\"') + + def test_invalid_quoting(self): + self.assertRaises(ValueError, self._callFUT, '"test') + + def test_invalid_quoting_single_quote(self): + self.assertRaises(ValueError, self._callFUT, '"') diff --git a/libs/waitress/tests/test_wasyncore.py b/libs/waitress/tests/test_wasyncore.py new file mode 100644 index 000000000..9c235092f --- /dev/null +++ b/libs/waitress/tests/test_wasyncore.py @@ -0,0 +1,1761 @@ +from waitress import wasyncore as asyncore +from waitress import compat +import contextlib +import functools +import gc +import unittest +import select +import os +import socket +import sys +import time +import errno +import re +import struct +import threading +import warnings + +from io import BytesIO + +TIMEOUT = 3 +HAS_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") +HOST = "localhost" +HOSTv4 = "127.0.0.1" +HOSTv6 = "::1" + +# Filename used for testing +if os.name == "java": # pragma: no cover + # Jython disallows @ in module names + TESTFN = "$test" +else: + TESTFN = "@test" + +TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid()) + + +class DummyLogger(object): # pragma: no cover + def __init__(self): + self.messages = [] + + def log(self, severity, message): + self.messages.append((severity, message)) + + +class WarningsRecorder(object): # pragma: no cover + """Convenience wrapper for the warnings list returned on + entry to the warnings.catch_warnings() context manager. + """ + + def __init__(self, warnings_list): + self._warnings = warnings_list + self._last = 0 + + @property + def warnings(self): + return self._warnings[self._last :] + + def reset(self): + self._last = len(self._warnings) + + +def _filterwarnings(filters, quiet=False): # pragma: no cover + """Catch the warnings, then check if all the expected + warnings have been raised and re-raise unexpected warnings. + If 'quiet' is True, only re-raise the unexpected warnings. + """ + # Clear the warning registry of the calling module + # in order to re-raise the warnings. + frame = sys._getframe(2) + registry = frame.f_globals.get("__warningregistry__") + if registry: + registry.clear() + with warnings.catch_warnings(record=True) as w: + # Set filter "always" to record all warnings. Because + # test_warnings swap the module, we need to look up in + # the sys.modules dictionary. + sys.modules["warnings"].simplefilter("always") + yield WarningsRecorder(w) + # Filter the recorded warnings + reraise = list(w) + missing = [] + for msg, cat in filters: + seen = False + for w in reraise[:]: + warning = w.message + # Filter out the matching messages + if re.match(msg, str(warning), re.I) and issubclass(warning.__class__, cat): + seen = True + reraise.remove(w) + if not seen and not quiet: + # This filter caught nothing + missing.append((msg, cat.__name__)) + if reraise: + raise AssertionError("unhandled warning %s" % reraise[0]) + if missing: + raise AssertionError("filter (%r, %s) did not catch any warning" % missing[0]) + + +def check_warnings(*filters, **kwargs): # pragma: no cover + """Context manager to silence warnings. + + Accept 2-tuples as positional arguments: + ("message regexp", WarningCategory) + + Optional argument: + - if 'quiet' is True, it does not fail if a filter catches nothing + (default True without argument, + default False if some filters are defined) + + Without argument, it defaults to: + check_warnings(("", Warning), quiet=True) + """ + quiet = kwargs.get("quiet") + if not filters: + filters = (("", Warning),) + # Preserve backward compatibility + if quiet is None: + quiet = True + return _filterwarnings(filters, quiet) + + +def gc_collect(): # pragma: no cover + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + if sys.platform.startswith("java"): + time.sleep(0.1) + gc.collect() + gc.collect() + + +def threading_setup(): # pragma: no cover + return (compat.thread._count(), None) + + +def threading_cleanup(*original_values): # pragma: no cover + global environment_altered + + _MAX_COUNT = 100 + + for count in range(_MAX_COUNT): + values = (compat.thread._count(), None) + if values == original_values: + break + + if not count: + # Display a warning at the first iteration + environment_altered = True + sys.stderr.write( + "Warning -- threading_cleanup() failed to cleanup " + "%s threads" % (values[0] - original_values[0]) + ) + sys.stderr.flush() + + values = None + + time.sleep(0.01) + gc_collect() + + +def reap_threads(func): # pragma: no cover + """Use this function when threads are being used. This will + ensure that the threads are cleaned up even when the test fails. + """ + + @functools.wraps(func) + def decorator(*args): + key = threading_setup() + try: + return func(*args) + finally: + threading_cleanup(*key) + + return decorator + + +def join_thread(thread, timeout=30.0): # pragma: no cover + """Join a thread. Raise an AssertionError if the thread is still alive + after timeout seconds. + """ + thread.join(timeout) + if thread.is_alive(): + msg = "failed to join the thread in %.1f seconds" % timeout + raise AssertionError(msg) + + +def bind_port(sock, host=HOST): # pragma: no cover + """Bind the socket to a free port and return the port number. Relies on + ephemeral ports in order to ensure we are using an unbound port. This is + important as many tests may be running simultaneously, especially in a + buildbot environment. This method raises an exception if the sock.family + is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + or SO_REUSEPORT set on it. Tests should *never* set these socket options + for TCP/IP sockets. The only case for setting these options is testing + multicasting via multiple UDP sockets. + + Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. + on Windows), it will be set on the socket. This will prevent anyone else + from bind()'ing to our host/port for the duration of the test. + """ + + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if hasattr(socket, "SO_REUSEADDR"): + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: + raise RuntimeError( + "tests should never set the SO_REUSEADDR " + "socket option on TCP/IP sockets!" + ) + if hasattr(socket, "SO_REUSEPORT"): + try: + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: + raise RuntimeError( + "tests should never set the SO_REUSEPORT " + "socket option on TCP/IP sockets!" + ) + except OSError: + # Python's socket module was compiled using modern headers + # thus defining SO_REUSEPORT but this process is running + # under an older kernel that does not support SO_REUSEPORT. + pass + if hasattr(socket, "SO_EXCLUSIVEADDRUSE"): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + + sock.bind((host, 0)) + port = sock.getsockname()[1] + return port + + +def closewrapper(sock): # pragma: no cover + try: + yield sock + finally: + sock.close() + + +class dummysocket: # pragma: no cover + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def fileno(self): + return 42 + + def setblocking(self, yesno): + self.isblocking = yesno + + def getpeername(self): + return "peername" + + +class dummychannel: # pragma: no cover + def __init__(self): + self.socket = dummysocket() + + def close(self): + self.socket.close() + + +class exitingdummy: # pragma: no cover + def __init__(self): + pass + + def handle_read_event(self): + raise asyncore.ExitNow() + + handle_write_event = handle_read_event + handle_close = handle_read_event + handle_expt_event = handle_read_event + + +class crashingdummy: + def __init__(self): + self.error_handled = False + + def handle_read_event(self): + raise Exception() + + handle_write_event = handle_read_event + handle_close = handle_read_event + handle_expt_event = handle_read_event + + def handle_error(self): + self.error_handled = True + + +# used when testing senders; just collects what it gets until newline is sent +def capture_server(evt, buf, serv): # pragma no cover + try: + serv.listen(0) + conn, addr = serv.accept() + except socket.timeout: + pass + else: + n = 200 + start = time.time() + while n > 0 and time.time() - start < 3.0: + r, w, e = select.select([conn], [], [], 0.1) + if r: + n -= 1 + data = conn.recv(10) + # keep everything except for the newline terminator + buf.write(data.replace(b"\n", b"")) + if b"\n" in data: + break + time.sleep(0.01) + + conn.close() + finally: + serv.close() + evt.set() + + +def bind_unix_socket(sock, addr): # pragma: no cover + """Bind a unix socket, raising SkipTest if PermissionError is raised.""" + assert sock.family == socket.AF_UNIX + try: + sock.bind(addr) + except PermissionError: + sock.close() + raise unittest.SkipTest("cannot bind AF_UNIX sockets") + + +def bind_af_aware(sock, addr): + """Helper function to bind a socket according to its family.""" + if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX: + # Make sure the path doesn't exist. + unlink(addr) + bind_unix_socket(sock, addr) + else: + sock.bind(addr) + + +if sys.platform.startswith("win"): # pragma: no cover + + def _waitfor(func, pathname, waitall=False): + # Perform the operation + func(pathname) + # Now setup the wait loop + if waitall: + dirname = pathname + else: + dirname, name = os.path.split(pathname) + dirname = dirname or "." + # Check for `pathname` to be removed from the filesystem. + # The exponential backoff of the timeout amounts to a total + # of ~1 second after which the deletion is probably an error + # anyway. + # Testing on an [email protected] shows that usually only 1 iteration is + # required when contention occurs. + timeout = 0.001 + while timeout < 1.0: + # Note we are only testing for the existence of the file(s) in + # the contents of the directory regardless of any security or + # access rights. If we have made it this far, we have sufficient + # permissions to do that much using Python's equivalent of the + # Windows API FindFirstFile. + # Other Windows APIs can fail or give incorrect results when + # dealing with files that are pending deletion. + L = os.listdir(dirname) + if not (L if waitall else name in L): + return + # Increase the timeout and try again + time.sleep(timeout) + timeout *= 2 + warnings.warn( + "tests may fail, delete still pending for " + pathname, + RuntimeWarning, + stacklevel=4, + ) + + def _unlink(filename): + _waitfor(os.unlink, filename) + + +else: + _unlink = os.unlink + + +def unlink(filename): + try: + _unlink(filename) + except OSError: + pass + + +def _is_ipv6_enabled(): # pragma: no cover + """Check whether IPv6 is enabled on this host.""" + if compat.HAS_IPV6: + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(("::1", 0)) + return True + except socket.error: + pass + finally: + if sock: + sock.close() + return False + + +IPV6_ENABLED = _is_ipv6_enabled() + + +class HelperFunctionTests(unittest.TestCase): + def test_readwriteexc(self): + # Check exception handling behavior of read, write and _exception + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore read/write/_exception calls + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.read, tr1) + self.assertRaises(asyncore.ExitNow, asyncore.write, tr1) + self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + asyncore.read(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore.write(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore._exception(tr2) + self.assertEqual(tr2.error_handled, True) + + # asyncore.readwrite uses constants in the select module that + # are not present in Windows systems (see this thread: + # http://mail.python.org/pipermail/python-list/2001-October/109973.html) + # These constants should be present as long as poll is available + + @unittest.skipUnless(hasattr(select, "poll"), "select.poll required") + def test_readwrite(self): + # Check that correct methods are called by readwrite() + + attributes = ("read", "expt", "write", "closed", "error_handled") + + expected = ( + (select.POLLIN, "read"), + (select.POLLPRI, "expt"), + (select.POLLOUT, "write"), + (select.POLLERR, "closed"), + (select.POLLHUP, "closed"), + (select.POLLNVAL, "closed"), + ) + + class testobj: + def __init__(self): + self.read = False + self.write = False + self.closed = False + self.expt = False + self.error_handled = False + + def handle_read_event(self): + self.read = True + + def handle_write_event(self): + self.write = True + + def handle_close(self): + self.closed = True + + def handle_expt_event(self): + self.expt = True + + # def handle_error(self): + # self.error_handled = True + + for flag, expectedattr in expected: + tobj = testobj() + self.assertEqual(getattr(tobj, expectedattr), False) + asyncore.readwrite(tobj, flag) + + # Only the attribute modified by the routine we expect to be + # called should be True. + for attr in attributes: + self.assertEqual(getattr(tobj, attr), attr == expectedattr) + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore readwrite call + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + self.assertEqual(tr2.error_handled, False) + asyncore.readwrite(tr2, flag) + self.assertEqual(tr2.error_handled, True) + + def test_closeall(self): + self.closeall_check(False) + + def test_closeall_default(self): + self.closeall_check(True) + + def closeall_check(self, usedefault): + # Check that close_all() closes everything in a given map + + l = [] + testmap = {} + for i in range(10): + c = dummychannel() + l.append(c) + self.assertEqual(c.socket.closed, False) + testmap[i] = c + + if usedefault: + socketmap = asyncore.socket_map + try: + asyncore.socket_map = testmap + asyncore.close_all() + finally: + testmap, asyncore.socket_map = asyncore.socket_map, socketmap + else: + asyncore.close_all(testmap) + + self.assertEqual(len(testmap), 0) + + for c in l: + self.assertEqual(c.socket.closed, True) + + def test_compact_traceback(self): + try: + raise Exception("I don't like spam!") + except: + real_t, real_v, real_tb = sys.exc_info() + r = asyncore.compact_traceback() + + (f, function, line), t, v, info = r + self.assertEqual(os.path.split(f)[-1], "test_wasyncore.py") + self.assertEqual(function, "test_compact_traceback") + self.assertEqual(t, real_t) + self.assertEqual(v, real_v) + self.assertEqual(info, "[%s|%s|%s]" % (f, function, line)) + + +class DispatcherTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + asyncore.close_all() + + def test_basic(self): + d = asyncore.dispatcher() + self.assertEqual(d.readable(), True) + self.assertEqual(d.writable(), True) + + def test_repr(self): + d = asyncore.dispatcher() + self.assertEqual(repr(d), "<waitress.wasyncore.dispatcher at %#x>" % id(d)) + + def test_log_info(self): + import logging + + inst = asyncore.dispatcher(map={}) + logger = DummyLogger() + inst.logger = logger + inst.log_info("message", "warning") + self.assertEqual(logger.messages, [(logging.WARN, "message")]) + + def test_log(self): + import logging + + inst = asyncore.dispatcher() + logger = DummyLogger() + inst.logger = logger + inst.log("message") + self.assertEqual(logger.messages, [(logging.DEBUG, "message")]) + + def test_unhandled(self): + import logging + + inst = asyncore.dispatcher() + logger = DummyLogger() + inst.logger = logger + + inst.handle_expt() + inst.handle_read() + inst.handle_write() + inst.handle_connect() + + expected = [ + (logging.WARN, "unhandled incoming priority event"), + (logging.WARN, "unhandled read event"), + (logging.WARN, "unhandled write event"), + (logging.WARN, "unhandled connect event"), + ] + self.assertEqual(logger.messages, expected) + + def test_strerror(self): + # refers to bug #8573 + err = asyncore._strerror(errno.EPERM) + if hasattr(os, "strerror"): + self.assertEqual(err, os.strerror(errno.EPERM)) + err = asyncore._strerror(-1) + self.assertTrue(err != "") + + +class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover + def readable(self): + return False + + def handle_connect(self): + pass + + +class DispatcherWithSendTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + asyncore.close_all() + + @reap_threads + def test_send(self): + evt = threading.Event() + sock = socket.socket() + sock.settimeout(3) + port = bind_port(sock) + + cap = BytesIO() + args = (evt, cap, sock) + t = threading.Thread(target=capture_server, args=args) + t.start() + try: + # wait a little longer for the server to initialize (it sometimes + # refuses connections on slow machines without this wait) + time.sleep(0.2) + + data = b"Suppose there isn't a 16-ton weight?" + d = dispatcherwithsend_noread() + d.create_socket() + d.connect((HOST, port)) + + # give time for socket to connect + time.sleep(0.1) + + d.send(data) + d.send(data) + d.send(b"\n") + + n = 1000 + while d.out_buffer and n > 0: # pragma: no cover + asyncore.poll() + n -= 1 + + evt.wait() + + self.assertEqual(cap.getvalue(), data * 2) + finally: + join_thread(t, timeout=TIMEOUT) + + + hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required" +) +class FileWrapperTest(unittest.TestCase): + def setUp(self): + self.d = b"It's not dead, it's sleeping!" + with open(TESTFN, "wb") as file: + file.write(self.d) + + def tearDown(self): + unlink(TESTFN) + + def test_recv(self): + fd = os.open(TESTFN, os.O_RDONLY) + w = asyncore.file_wrapper(fd) + os.close(fd) + + self.assertNotEqual(w.fd, fd) + self.assertNotEqual(w.fileno(), fd) + self.assertEqual(w.recv(13), b"It's not dead") + self.assertEqual(w.read(6), b", it's") + w.close() + self.assertRaises(OSError, w.read, 1) + + def test_send(self): + d1 = b"Come again?" + d2 = b"I want to buy some cheese." + fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND) + w = asyncore.file_wrapper(fd) + os.close(fd) + + w.write(d1) + w.send(d2) + w.close() + with open(TESTFN, "rb") as file: + self.assertEqual(file.read(), self.d + d1 + d2) + + @unittest.skipUnless( + hasattr(asyncore, "file_dispatcher"), "asyncore.file_dispatcher required" + ) + def test_dispatcher(self): + fd = os.open(TESTFN, os.O_RDONLY) + data = [] + + class FileDispatcher(asyncore.file_dispatcher): + def handle_read(self): + data.append(self.recv(29)) + + FileDispatcher(fd) + os.close(fd) + asyncore.loop(timeout=0.01, use_poll=True, count=2) + self.assertEqual(b"".join(data), self.d) + + def test_resource_warning(self): + # Issue #11453 + got_warning = False + while got_warning is False: + # we try until we get the outcome we want because this + # test is not deterministic (gc_collect() may not + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + + os.close(fd) + + try: + with check_warnings(("", compat.ResourceWarning)): + f = None + gc_collect() + except AssertionError: # pragma: no cover + pass + else: + got_warning = True + + def test_close_twice(self): + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + os.close(fd) + + os.close(f.fd) # file_wrapper dupped fd + with self.assertRaises(OSError): + f.close() + + self.assertEqual(f.fd, -1) + # calling close twice should not fail + f.close() + + +class BaseTestHandler(asyncore.dispatcher): # pragma: no cover + def __init__(self, sock=None): + asyncore.dispatcher.__init__(self, sock) + self.flag = False + + def handle_accept(self): + raise Exception("handle_accept not supposed to be called") + + def handle_accepted(self): + raise Exception("handle_accepted not supposed to be called") + + def handle_connect(self): + raise Exception("handle_connect not supposed to be called") + + def handle_expt(self): + raise Exception("handle_expt not supposed to be called") + + def handle_close(self): + raise Exception("handle_close not supposed to be called") + + def handle_error(self): + raise + + +class BaseServer(asyncore.dispatcher): + """A server which listens on an address and dispatches the + connection to a handler. + """ + + def __init__(self, family, addr, handler=BaseTestHandler): + asyncore.dispatcher.__init__(self) + self.create_socket(family) + self.set_reuse_addr() + bind_af_aware(self.socket, addr) + self.listen(5) + self.handler = handler + + @property + def address(self): + return self.socket.getsockname() + + def handle_accepted(self, sock, addr): + self.handler(sock) + + def handle_error(self): # pragma: no cover + raise + + +class BaseClient(BaseTestHandler): + def __init__(self, family, address): + BaseTestHandler.__init__(self) + self.create_socket(family) + self.connect(address) + + def handle_connect(self): + pass + + +class BaseTestAPI: + def tearDown(self): + asyncore.close_all(ignore_all=True) + + def loop_waiting_for_flag(self, instance, timeout=5): # pragma: no cover + timeout = float(timeout) / 100 + count = 100 + while asyncore.socket_map and count > 0: + asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll) + if instance.flag: + return + count -= 1 + time.sleep(timeout) + self.fail("flag not set") + + def test_handle_connect(self): + # make sure handle_connect is called on connect() + + class TestClient(BaseClient): + def handle_connect(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_accept(self): + # make sure handle_accept() is called when a client connects + + class TestListener(BaseTestHandler): + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, server.address) + self.loop_waiting_for_flag(server) + + def test_handle_accepted(self): + # make sure handle_accepted() is called when a client connects + + class TestListener(BaseTestHandler): + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + asyncore.dispatcher.handle_accept(self) + + def handle_accepted(self, sock, addr): + sock.close() + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, server.address) + self.loop_waiting_for_flag(server) + + def test_handle_read(self): + # make sure handle_read is called on data received + + class TestClient(BaseClient): + def handle_read(self): + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.send(b"x" * 1024) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_write(self): + # make sure handle_write is called + + class TestClient(BaseClient): + def handle_write(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_close(self): + # make sure handle_close is called when the other end closes + # the connection + + class TestClient(BaseClient): + def handle_read(self): + # in order to make handle_close be called we are supposed + # to make at least one recv() call + self.recv(1024) + + def handle_close(self): + self.flag = True + self.close() + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.close() + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_close_after_conn_broken(self): + # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and + # #11265). + + data = b"\0" * 128 + + class TestClient(BaseClient): + def handle_write(self): + self.send(data) + + def handle_close(self): + self.flag = True + self.close() + + def handle_expt(self): # pragma: no cover + # needs to exist for MacOS testing + self.flag = True + self.close() + + class TestHandler(BaseTestHandler): + def handle_read(self): + self.recv(len(data)) + self.close() + + def writable(self): + return False + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + @unittest.skipIf( + sys.platform.startswith("sunos"), "OOB support is broken on Solaris" + ) + def test_handle_expt(self): + # Make sure handle_expt is called on OOB data received. + # Note: this might fail on some platforms as OOB data is + # tenuously supported and rarely used. + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + if sys.platform == "darwin" and self.use_poll: # pragma: no cover + self.skipTest("poll may fail on macOS; see issue #28087") + + class TestClient(BaseClient): + def handle_expt(self): + self.socket.recv(1024, socket.MSG_OOB) + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.socket.send(compat.tobytes(chr(244)), socket.MSG_OOB) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_error(self): + class TestClient(BaseClient): + def handle_write(self): + 1.0 / 0 + + def handle_error(self): + self.flag = True + try: + raise + except ZeroDivisionError: + pass + else: # pragma: no cover + raise Exception("exception not raised") + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_connection_attributes(self): + server = BaseServer(self.family, self.addr) + client = BaseClient(self.family, server.address) + + # we start disconnected + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + # this can't be taken for granted across all platforms + # self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # execute some loops so that client connects to server + asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100) + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertTrue(client.connected) + self.assertFalse(client.accepting) + + # disconnect the client + client.close() + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # stop serving + server.close() + self.assertFalse(server.connected) + self.assertFalse(server.accepting) + + def test_create_socket(self): + s = asyncore.dispatcher() + s.create_socket(self.family) + # self.assertEqual(s.socket.type, socket.SOCK_STREAM) + self.assertEqual(s.socket.family, self.family) + self.assertEqual(s.socket.gettimeout(), 0) + # self.assertFalse(s.socket.get_inheritable()) + + def test_bind(self): + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + s1 = asyncore.dispatcher() + s1.create_socket(self.family) + s1.bind(self.addr) + s1.listen(5) + port = s1.socket.getsockname()[1] + + s2 = asyncore.dispatcher() + s2.create_socket(self.family) + # EADDRINUSE indicates the socket was correctly bound + self.assertRaises(socket.error, s2.bind, (self.addr[0], port)) + + def test_set_reuse_addr(self): # pragma: no cover + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + with closewrapper(socket.socket(self.family)) as sock: + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except OSError: + unittest.skip("SO_REUSEADDR not supported on this platform") + else: + # if SO_REUSEADDR succeeded for sock we expect asyncore + # to do the same + s = asyncore.dispatcher(socket.socket(self.family)) + self.assertFalse( + s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) + ) + s.socket.close() + s.create_socket(self.family) + s.set_reuse_addr() + self.assertTrue( + s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) + ) + + @reap_threads + def test_quick_connect(self): # pragma: no cover + # see: http://bugs.python.org/issue10340 + if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())): + self.skipTest("test specific to AF_INET and AF_INET6") + + server = BaseServer(self.family, self.addr) + # run the thread 500 ms: the socket should be connected in 200 ms + t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=5)) + t.start() + try: + sock = socket.socket(self.family, socket.SOCK_STREAM) + with closewrapper(sock) as s: + s.settimeout(0.2) + s.setsockopt( + socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0) + ) + + try: + s.connect(server.address) + except OSError: + pass + finally: + join_thread(t, timeout=TIMEOUT) + + +class TestAPI_UseIPv4Sockets(BaseTestAPI): + family = socket.AF_INET + addr = (HOST, 0) + + [email protected](IPV6_ENABLED, "IPv6 support required") +class TestAPI_UseIPv6Sockets(BaseTestAPI): + family = socket.AF_INET6 + addr = (HOSTv6, 0) + + [email protected](HAS_UNIX_SOCKETS, "Unix sockets required") +class TestAPI_UseUnixSockets(BaseTestAPI): + if HAS_UNIX_SOCKETS: + family = socket.AF_UNIX + addr = TESTFN + + def tearDown(self): + unlink(self.addr) + BaseTestAPI.tearDown(self) + + +class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase): + use_poll = False + + [email protected](hasattr(select, "poll"), "select.poll required") +class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase): + use_poll = True + + +class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase): + use_poll = False + + [email protected](hasattr(select, "poll"), "select.poll required") +class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase): + use_poll = True + + +class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase): + use_poll = False + + [email protected](hasattr(select, "poll"), "select.poll required") +class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase): + use_poll = True + + +class Test__strerror(unittest.TestCase): + def _callFUT(self, err): + from waitress.wasyncore import _strerror + + return _strerror(err) + + def test_gardenpath(self): + self.assertEqual(self._callFUT(1), "Operation not permitted") + + def test_unknown(self): + self.assertEqual(self._callFUT("wut"), "Unknown error wut") + + +class Test_read(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import read + + return read(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + + +class Test_write(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import write + + return write(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertTrue(inst.error_handled) + + +class Test__exception(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import _exception + + return _exception(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertTrue(inst.error_handled) + + [email protected](hasattr(select, "poll"), "select.poll required") +class Test_readwrite(unittest.TestCase): + def _callFUT(self, obj, flags): + from waitress.wasyncore import readwrite + + return readwrite(obj, flags) + + def test_handle_read_event(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_handle_write_event(self): + flags = 0 + flags |= select.POLLOUT + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.write_event_handled) + + def test_handle_expt_event(self): + flags = 0 + flags |= select.POLLPRI + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.expt_event_handled) + + def test_handle_close(self): + flags = 0 + flags |= select.POLLHUP + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.close_handled) + + def test_socketerror_not_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.EALREADY, "EALREADY")) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + + def test_socketerror_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.ECONNRESET, "ECONNRESET")) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.close_handled) + + def test_exception_in_reraised(self): + from waitress import wasyncore + + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(wasyncore.ExitNow) + self.assertRaises(wasyncore.ExitNow, self._callFUT, inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_exception_not_in_reraised(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(ValueError) + self._callFUT(inst, flags) + self.assertTrue(inst.error_handled) + + +class Test_poll(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll + + return poll(timeout, map) + + def test_nothing_writable_nothing_readable_but_map_not_empty(self): + # i read the mock.patch docs. nerp. + dummy_time = DummyTime() + map = {0: DummyDispatcher()} + try: + from waitress import wasyncore + + old_time = wasyncore.time + wasyncore.time = dummy_time + result = self._callFUT(map=map) + finally: + wasyncore.time = old_time + self.assertEqual(result, None) + self.assertEqual(dummy_time.sleepvals, [0.0]) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EINTR)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + result = self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(result, None) + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EBADF)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + + +class Test_poll2(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll2 + + return poll2(timeout, map) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. + pollster = DummyPollster(exc=select.error(errno.EINTR)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + pollster = DummyPollster(exc=select.error(errno.EBADF)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + + +class Test_dispatcher(unittest.TestCase): + def _makeOne(self, sock=None, map=None): + from waitress.wasyncore import dispatcher + + return dispatcher(sock=sock, map=map) + + def test_unexpected_getpeername_exc(self): + sock = dummysocket() + + def getpeername(): + raise socket.error(errno.EBADF) + + map = {} + sock.getpeername = getpeername + self.assertRaises(socket.error, self._makeOne, sock=sock, map=map) + self.assertEqual(map, {}) + + def test___repr__accepting(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.accepting = True + inst.addr = ("localhost", 8080) + result = repr(inst) + expected = "<waitress.wasyncore.dispatcher listening localhost:8080 at" + self.assertEqual(result[: len(expected)], expected) + + def test___repr__connected(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.accepting = False + inst.connected = True + inst.addr = ("localhost", 8080) + result = repr(inst) + expected = "<waitress.wasyncore.dispatcher connected localhost:8080 at" + self.assertEqual(result[: len(expected)], expected) + + def test_set_reuse_addr_with_socketerror(self): + sock = dummysocket() + map = {} + + def setsockopt(*arg, **kw): + sock.errored = True + raise socket.error + + sock.setsockopt = setsockopt + sock.getsockopt = lambda *arg: 0 + inst = self._makeOne(sock=sock, map=map) + inst.set_reuse_addr() + self.assertTrue(sock.errored) + + def test_connect_raise_socket_error(self): + sock = dummysocket() + map = {} + sock.connect_ex = lambda *arg: 1 + inst = self._makeOne(sock=sock, map=map) + self.assertRaises(socket.error, inst.connect, 0) + + def test_accept_raise_TypeError(self): + sock = dummysocket() + map = {} + + def accept(*arg, **kw): + raise TypeError + + sock.accept = accept + inst = self._makeOne(sock=sock, map=map) + result = inst.accept() + self.assertEqual(result, None) + + def test_accept_raise_unexpected_socketerror(self): + sock = dummysocket() + map = {} + + def accept(*arg, **kw): + raise socket.error(122) + + sock.accept = accept + inst = self._makeOne(sock=sock, map=map) + self.assertRaises(socket.error, inst.accept) + + def test_send_raise_EWOULDBLOCK(self): + sock = dummysocket() + map = {} + + def send(*arg, **kw): + raise socket.error(errno.EWOULDBLOCK) + + sock.send = send + inst = self._makeOne(sock=sock, map=map) + result = inst.send("a") + self.assertEqual(result, 0) + + def test_send_raise_unexpected_socketerror(self): + sock = dummysocket() + map = {} + + def send(*arg, **kw): + raise socket.error(122) + + sock.send = send + inst = self._makeOne(sock=sock, map=map) + self.assertRaises(socket.error, inst.send, "a") + + def test_recv_raises_disconnect(self): + sock = dummysocket() + map = {} + + def recv(*arg, **kw): + raise socket.error(errno.ECONNRESET) + + def handle_close(): + inst.close_handled = True + + sock.recv = recv + inst = self._makeOne(sock=sock, map=map) + inst.handle_close = handle_close + result = inst.recv(1) + self.assertEqual(result, b"") + self.assertTrue(inst.close_handled) + + def test_close_raises_unknown_socket_error(self): + sock = dummysocket() + map = {} + + def close(): + raise socket.error(122) + + sock.close = close + inst = self._makeOne(sock=sock, map=map) + inst.del_channel = lambda: None + self.assertRaises(socket.error, inst.close) + + def test_handle_read_event_not_accepting_not_connected_connecting(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + + def handle_connect_event(): + inst.connect_event_handled = True + + def handle_read(): + inst.read_handled = True + + inst.handle_connect_event = handle_connect_event + inst.handle_read = handle_read + inst.accepting = False + inst.connected = False + inst.connecting = True + inst.handle_read_event() + self.assertTrue(inst.connect_event_handled) + self.assertTrue(inst.read_handled) + + def test_handle_connect_event_getsockopt_returns_error(self): + sock = dummysocket() + sock.getsockopt = lambda *arg: 122 + map = {} + inst = self._makeOne(sock=sock, map=map) + self.assertRaises(socket.error, inst.handle_connect_event) + + def test_handle_expt_event_getsockopt_returns_error(self): + sock = dummysocket() + sock.getsockopt = lambda *arg: 122 + map = {} + inst = self._makeOne(sock=sock, map=map) + + def handle_close(): + inst.close_handled = True + + inst.handle_close = handle_close + inst.handle_expt_event() + self.assertTrue(inst.close_handled) + + def test_handle_write_event_while_accepting(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.accepting = True + result = inst.handle_write_event() + self.assertEqual(result, None) + + def test_handle_error_gardenpath(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + + def handle_close(): + inst.close_handled = True + + def compact_traceback(*arg, **kw): + return None, None, None, None + + def log_info(self, *arg): + inst.logged_info = arg + + inst.handle_close = handle_close + inst.compact_traceback = compact_traceback + inst.log_info = log_info + inst.handle_error() + self.assertTrue(inst.close_handled) + self.assertEqual(inst.logged_info, ("error",)) + + def test_handle_close(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + + def log_info(self, *arg): + inst.logged_info = arg + + def close(): + inst._closed = True + + inst.log_info = log_info + inst.close = close + inst.handle_close() + self.assertTrue(inst._closed) + + def test_handle_accepted(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.handle_accepted(sock, "1") + self.assertTrue(sock.closed) + + +class Test_dispatcher_with_send(unittest.TestCase): + def _makeOne(self, sock=None, map=None): + from waitress.wasyncore import dispatcher_with_send + + return dispatcher_with_send(sock=sock, map=map) + + def test_writable(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.out_buffer = b"123" + inst.connected = True + self.assertTrue(inst.writable()) + + +class Test_close_all(unittest.TestCase): + def _callFUT(self, map=None, ignore_all=False): + from waitress.wasyncore import close_all + + return close_all(map, ignore_all) + + def test_socketerror_on_close_ebadf(self): + disp = DummyDispatcher(exc=socket.error(errno.EBADF)) + map = {0: disp} + self._callFUT(map) + self.assertEqual(map, {}) + + def test_socketerror_on_close_non_ebadf(self): + disp = DummyDispatcher(exc=socket.error(errno.EAGAIN)) + map = {0: disp} + self.assertRaises(socket.error, self._callFUT, map) + + def test_reraised_exc_on_close(self): + disp = DummyDispatcher(exc=KeyboardInterrupt) + map = {0: disp} + self.assertRaises(KeyboardInterrupt, self._callFUT, map) + + def test_unknown_exc_on_close(self): + disp = DummyDispatcher(exc=RuntimeError) + map = {0: disp} + self.assertRaises(RuntimeError, self._callFUT, map) + + +class DummyDispatcher(object): + read_event_handled = False + write_event_handled = False + expt_event_handled = False + error_handled = False + close_handled = False + accepting = False + + def __init__(self, exc=None): + self.exc = exc + + def handle_read_event(self): + self.read_event_handled = True + if self.exc is not None: + raise self.exc + + def handle_write_event(self): + self.write_event_handled = True + if self.exc is not None: + raise self.exc + + def handle_expt_event(self): + self.expt_event_handled = True + if self.exc is not None: + raise self.exc + + def handle_error(self): + self.error_handled = True + + def handle_close(self): + self.close_handled = True + + def readable(self): + return False + + def writable(self): + return False + + def close(self): + if self.exc is not None: + raise self.exc + + +class DummyTime(object): + def __init__(self): + self.sleepvals = [] + + def sleep(self, val): + self.sleepvals.append(val) + + +class DummySelect(object): + error = select.error + + def __init__(self, exc=None, pollster=None): + self.selected = [] + self.pollster = pollster + self.exc = exc + + def select(self, *arg): + self.selected.append(arg) + if self.exc is not None: + raise self.exc + + def poll(self): + return self.pollster + + +class DummyPollster(object): + def __init__(self, exc=None): + self.polled = [] + self.exc = exc + + def poll(self, timeout): + self.polled.append(timeout) + if self.exc is not None: + raise self.exc + else: # pragma: no cover + return [] diff --git a/libs/waitress/trigger.py b/libs/waitress/trigger.py new file mode 100644 index 000000000..6a57c1275 --- /dev/null +++ b/libs/waitress/trigger.py @@ -0,0 +1,203 @@ +############################################################################## +# +# Copyright (c) 2001-2005 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE +# +############################################################################## + +import os +import socket +import errno +import threading + +from . import wasyncore + +# Wake up a call to select() running in the main thread. +# +# This is useful in a context where you are using Medusa's I/O +# subsystem to deliver data, but the data is generated by another +# thread. Normally, if Medusa is in the middle of a call to +# select(), new output data generated by another thread will have +# to sit until the call to select() either times out or returns. +# If the trigger is 'pulled' by another thread, it should immediately +# generate a READ event on the trigger object, which will force the +# select() invocation to return. +# +# A common use for this facility: letting Medusa manage I/O for a +# large number of connections; but routing each request through a +# thread chosen from a fixed-size thread pool. When a thread is +# acquired, a transaction is performed, but output data is +# accumulated into buffers that will be emptied more efficiently +# by Medusa. [picture a server that can process database queries +# rapidly, but doesn't want to tie up threads waiting to send data +# to low-bandwidth connections] +# +# The other major feature provided by this class is the ability to +# move work back into the main thread: if you call pull_trigger() +# with a thunk argument, when select() wakes up and receives the +# event it will call your thunk from within that thread. The main +# purpose of this is to remove the need to wrap thread locks around +# Medusa's data structures, which normally do not need them. [To see +# why this is true, imagine this scenario: A thread tries to push some +# new data onto a channel's outgoing data queue at the same time that +# the main thread is trying to remove some] + + +class _triggerbase(object): + """OS-independent base class for OS-dependent trigger class.""" + + kind = None # subclass must set to "pipe" or "loopback"; used by repr + + def __init__(self): + self._closed = False + + # `lock` protects the `thunks` list from being traversed and + # appended to simultaneously. + self.lock = threading.Lock() + + # List of no-argument callbacks to invoke when the trigger is + # pulled. These run in the thread running the wasyncore mainloop, + # regardless of which thread pulls the trigger. + self.thunks = [] + + def readable(self): + return True + + def writable(self): + return False + + def handle_connect(self): + pass + + def handle_close(self): + self.close() + + # Override the wasyncore close() method, because it doesn't know about + # (so can't close) all the gimmicks we have open. Subclass must + # supply a _close() method to do platform-specific closing work. _close() + # will be called iff we're not already closed. + def close(self): + if not self._closed: + self._closed = True + self.del_channel() + self._close() # subclass does OS-specific stuff + + def pull_trigger(self, thunk=None): + if thunk: + with self.lock: + self.thunks.append(thunk) + self._physical_pull() + + def handle_read(self): + try: + self.recv(8192) + except (OSError, socket.error): + return + with self.lock: + for thunk in self.thunks: + try: + thunk() + except: + nil, t, v, tbinfo = wasyncore.compact_traceback() + self.log_info( + "exception in trigger thunk: (%s:%s %s)" % (t, v, tbinfo) + ) + self.thunks = [] + + +if os.name == "posix": + + class trigger(_triggerbase, wasyncore.file_dispatcher): + kind = "pipe" + + def __init__(self, map): + _triggerbase.__init__(self) + r, self.trigger = self._fds = os.pipe() + wasyncore.file_dispatcher.__init__(self, r, map=map) + + def _close(self): + for fd in self._fds: + os.close(fd) + self._fds = [] + wasyncore.file_dispatcher.close(self) + + def _physical_pull(self): + os.write(self.trigger, b"x") + + +else: # pragma: no cover + # Windows version; uses just sockets, because a pipe isn't select'able + # on Windows. + + class trigger(_triggerbase, wasyncore.dispatcher): + kind = "loopback" + + def __init__(self, map): + _triggerbase.__init__(self) + + # Get a pair of connected sockets. The trigger is the 'w' + # end of the pair, which is connected to 'r'. 'r' is put + # in the wasyncore socket map. "pulling the trigger" then + # means writing something on w, which will wake up r. + + w = socket.socket() + # Disable buffering -- pulling the trigger sends 1 byte, + # and we want that sent immediately, to wake up wasyncore's + # select() ASAP. + w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + count = 0 + while True: + count += 1 + # Bind to a local port; for efficiency, let the OS pick + # a free port for us. + # Unfortunately, stress tests showed that we may not + # be able to connect to that port ("Address already in + # use") despite that the OS picked it. This appears + # to be a race bug in the Windows socket implementation. + # So we loop until a connect() succeeds (almost always + # on the first try). See the long thread at + # http://mail.zope.org/pipermail/zope/2005-July/160433.html + # for hideous details. + a = socket.socket() + a.bind(("127.0.0.1", 0)) + connect_address = a.getsockname() # assigned (host, port) pair + a.listen(1) + try: + w.connect(connect_address) + break # success + except socket.error as detail: + if detail[0] != errno.WSAEADDRINUSE: + # "Address already in use" is the only error + # I've seen on two WinXP Pro SP2 boxes, under + # Pythons 2.3.5 and 2.4.1. + raise + # (10048, 'Address already in use') + # assert count <= 2 # never triggered in Tim's tests + if count >= 10: # I've never seen it go above 2 + a.close() + w.close() + raise RuntimeError("Cannot bind trigger!") + # Close `a` and try again. Note: I originally put a short + # sleep() here, but it didn't appear to help or hurt. + a.close() + + r, addr = a.accept() # r becomes wasyncore's (self.)socket + a.close() + self.trigger = w + wasyncore.dispatcher.__init__(self, r, map=map) + + def _close(self): + # self.socket is r, and self.trigger is w, from __init__ + self.socket.close() + self.trigger.close() + + def _physical_pull(self): + self.trigger.send(b"x") diff --git a/libs/waitress/utilities.py b/libs/waitress/utilities.py new file mode 100644 index 000000000..556bed20a --- /dev/null +++ b/libs/waitress/utilities.py @@ -0,0 +1,320 @@ +############################################################################## +# +# Copyright (c) 2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Utility functions +""" + +import calendar +import errno +import logging +import os +import re +import stat +import time + +from .rfc7230 import OBS_TEXT, VCHAR + +logger = logging.getLogger("waitress") +queue_logger = logging.getLogger("waitress.queue") + + +def find_double_newline(s): + """Returns the position just after a double newline in the given string.""" + pos = s.find(b"\r\n\r\n") + + if pos >= 0: + pos += 4 + + return pos + + +def concat(*args): + return "".join(args) + + +def join(seq, field=" "): + return field.join(seq) + + +def group(s): + return "(" + s + ")" + + +short_days = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"] +long_days = [ + "sunday", + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", +] + +short_day_reg = group(join(short_days, "|")) +long_day_reg = group(join(long_days, "|")) + +daymap = {} + +for i in range(7): + daymap[short_days[i]] = i + daymap[long_days[i]] = i + +hms_reg = join(3 * [group("[0-9][0-9]")], ":") + +months = [ + "jan", + "feb", + "mar", + "apr", + "may", + "jun", + "jul", + "aug", + "sep", + "oct", + "nov", + "dec", +] + +monmap = {} + +for i in range(12): + monmap[months[i]] = i + 1 + +months_reg = group(join(months, "|")) + +# From draft-ietf-http-v11-spec-07.txt/3.3.1 +# Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 +# Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036 +# Sun Nov 6 08:49:37 1994 ; ANSI C's asctime() format + +# rfc822 format +rfc822_date = join( + [ + concat(short_day_reg, ","), # day + group("[0-9][0-9]?"), # date + months_reg, # month + group("[0-9]+"), # year + hms_reg, # hour minute second + "gmt", + ], + " ", +) + +rfc822_reg = re.compile(rfc822_date) + + +def unpack_rfc822(m): + g = m.group + + return ( + int(g(4)), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second + 0, + 0, + 0, + ) + + +# rfc850 format +rfc850_date = join( + [ + concat(long_day_reg, ","), + join([group("[0-9][0-9]?"), months_reg, group("[0-9]+")], "-"), + hms_reg, + "gmt", + ], + " ", +) + +rfc850_reg = re.compile(rfc850_date) +# they actually unpack the same way +def unpack_rfc850(m): + g = m.group + yr = g(4) + + if len(yr) == 2: + yr = "19" + yr + + return ( + int(yr), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second + 0, + 0, + 0, + ) + + +# parsdate.parsedate - ~700/sec. +# parse_http_date - ~1333/sec. + +weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] +monthname = [ + None, + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", +] + + +def build_http_date(when): + year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when) + + return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + weekdayname[wd], + day, + monthname[month], + year, + hh, + mm, + ss, + ) + + +def parse_http_date(d): + d = d.lower() + m = rfc850_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc850(m))) + else: + m = rfc822_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc822(m))) + else: + return 0 + + return retval + + +# RFC 5234 Appendix B.1 "Core Rules": +# VCHAR = %x21-7E +# ; visible (printing) characters +vchar_re = VCHAR + +# RFC 7230 Section 3.2.6 "Field Value Components": +# quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE +# qdtext = HTAB / SP /%x21 / %x23-5B / %x5D-7E / obs-text +# obs-text = %x80-FF +# quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text ) +obs_text_re = OBS_TEXT + +# The '\\' between \x5b and \x5d is needed to escape \x5d (']') +qdtext_re = "[\t \x21\x23-\x5b\\\x5d-\x7e" + obs_text_re + "]" + +quoted_pair_re = r"\\" + "([\t " + vchar_re + obs_text_re + "])" +quoted_string_re = '"(?:(?:' + qdtext_re + ")|(?:" + quoted_pair_re + '))*"' + +quoted_string = re.compile(quoted_string_re) +quoted_pair = re.compile(quoted_pair_re) + + +def undquote(value): + if value.startswith('"') and value.endswith('"'): + # So it claims to be DQUOTE'ed, let's validate that + matches = quoted_string.match(value) + + if matches and matches.end() == len(value): + # Remove the DQUOTE's from the value + value = value[1:-1] + + # Remove all backslashes that are followed by a valid vchar or + # obs-text + value = quoted_pair.sub(r"\1", value) + + return value + elif not value.startswith('"') and not value.endswith('"'): + return value + + raise ValueError("Invalid quoting in value") + + +def cleanup_unix_socket(path): + try: + st = os.stat(path) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise # pragma: no cover + else: + if stat.S_ISSOCK(st.st_mode): + try: + os.remove(path) + except OSError: # pragma: no cover + # avoid race condition error during tests + pass + + +class Error(object): + code = 500 + reason = "Internal Server Error" + + def __init__(self, body): + self.body = body + + def to_response(self): + status = "%s %s" % (self.code, self.reason) + body = "%s\r\n\r\n%s" % (self.reason, self.body) + tag = "\r\n\r\n(generated by waitress)" + body = body + tag + headers = [("Content-Type", "text/plain")] + + return status, headers, body + + def wsgi_response(self, environ, start_response): + status, headers, body = self.to_response() + start_response(status, headers) + yield body + + +class BadRequest(Error): + code = 400 + reason = "Bad Request" + + +class RequestHeaderFieldsTooLarge(BadRequest): + code = 431 + reason = "Request Header Fields Too Large" + + +class RequestEntityTooLarge(BadRequest): + code = 413 + reason = "Request Entity Too Large" + + +class InternalServerError(Error): + code = 500 + reason = "Internal Server Error" + + +class ServerNotImplemented(Error): + code = 501 + reason = "Not Implemented" diff --git a/libs/waitress/wasyncore.py b/libs/waitress/wasyncore.py new file mode 100644 index 000000000..09bcafaa0 --- /dev/null +++ b/libs/waitress/wasyncore.py @@ -0,0 +1,693 @@ +# -*- Mode: Python -*- +# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp +# Author: Sam Rushing <[email protected]> + +# ====================================================================== +# Copyright 1996 by Sam Rushing +# +# All Rights Reserved +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby +# granted, provided that the above copyright notice appear in all +# copies and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of Sam +# Rushing not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN +# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# ====================================================================== + +"""Basic infrastructure for asynchronous socket service clients and servers. + +There are only two ways to have a program on a single processor do "more +than one thing at a time". Multi-threaded programming is the simplest and +most popular way to do it, but there is another very different technique, +that lets you have nearly all the advantages of multi-threading, without +actually using multiple threads. it's really only practical if your program +is largely I/O bound. If your program is CPU bound, then pre-emptive +scheduled threads are probably what you really need. Network servers are +rarely CPU-bound, however. + +If your operating system supports the select() system call in its I/O +library (and nearly all do), then you can use it to juggle multiple +communication channels at once; doing other work while your I/O is taking +place in the "background." Although this strategy can seem strange and +complex, especially at first, it is in many ways easier to understand and +control than multi-threaded programming. The module documented here solves +many of the difficult problems for you, making the task of building +sophisticated high-performance network servers and clients a snap. + +NB: this is a fork of asyncore from the stdlib that we've (the waitress +developers) named 'wasyncore' to ensure forward compatibility, as asyncore +in the stdlib will be dropped soon. It is neither a copy of the 2.7 asyncore +nor the 3.X asyncore; it is a version compatible with either 2.7 or 3.X. +""" + +from . import compat +from . import utilities + +import logging +import select +import socket +import sys +import time +import warnings + +import os +from errno import ( + EALREADY, + EINPROGRESS, + EWOULDBLOCK, + ECONNRESET, + EINVAL, + ENOTCONN, + ESHUTDOWN, + EISCONN, + EBADF, + ECONNABORTED, + EPIPE, + EAGAIN, + EINTR, + errorcode, +) + +_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF}) + +try: + socket_map +except NameError: + socket_map = {} + + +def _strerror(err): + try: + return os.strerror(err) + except (TypeError, ValueError, OverflowError, NameError): + return "Unknown error %s" % err + + +class ExitNow(Exception): + pass + + +_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit) + + +def read(obj): + try: + obj.handle_read_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def write(obj): + try: + obj.handle_write_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def _exception(obj): + try: + obj.handle_expt_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def readwrite(obj, flags): + try: + if flags & select.POLLIN: + obj.handle_read_event() + if flags & select.POLLOUT: + obj.handle_write_event() + if flags & select.POLLPRI: + obj.handle_expt_event() + if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): + obj.handle_close() + except socket.error as e: + if e.args[0] not in _DISCONNECTED: + obj.handle_error() + else: + obj.handle_close() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def poll(timeout=0.0, map=None): + if map is None: # pragma: no cover + map = socket_map + if map: + r = [] + w = [] + e = [] + for fd, obj in list(map.items()): # list() call FBO py3 + is_r = obj.readable() + is_w = obj.writable() + if is_r: + r.append(fd) + # accepting sockets should not be writable + if is_w and not obj.accepting: + w.append(fd) + if is_r or is_w: + e.append(fd) + if [] == r == w == e: + time.sleep(timeout) + return + + try: + r, w, e = select.select(r, w, e, timeout) + except select.error as err: + if err.args[0] != EINTR: + raise + else: + return + + for fd in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + read(obj) + + for fd in w: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + write(obj) + + for fd in e: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + _exception(obj) + + +def poll2(timeout=0.0, map=None): + # Use the poll() support added to the select module in Python 2.0 + if map is None: # pragma: no cover + map = socket_map + if timeout is not None: + # timeout is in milliseconds + timeout = int(timeout * 1000) + pollster = select.poll() + if map: + for fd, obj in list(map.items()): + flags = 0 + if obj.readable(): + flags |= select.POLLIN | select.POLLPRI + # accepting sockets should not be writable + if obj.writable() and not obj.accepting: + flags |= select.POLLOUT + if flags: + pollster.register(fd, flags) + + try: + r = pollster.poll(timeout) + except select.error as err: + if err.args[0] != EINTR: + raise + r = [] + + for fd, flags in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + readwrite(obj, flags) + + +poll3 = poll2 # Alias for backward compatibility + + +def loop(timeout=30.0, use_poll=False, map=None, count=None): + if map is None: # pragma: no cover + map = socket_map + + if use_poll and hasattr(select, "poll"): + poll_fun = poll2 + else: + poll_fun = poll + + if count is None: # pragma: no cover + while map: + poll_fun(timeout, map) + + else: + while map and count > 0: + poll_fun(timeout, map) + count = count - 1 + + +def compact_traceback(): + t, v, tb = sys.exc_info() + tbinfo = [] + if not tb: # pragma: no cover + raise AssertionError("traceback does not exist") + while tb: + tbinfo.append( + ( + tb.tb_frame.f_code.co_filename, + tb.tb_frame.f_code.co_name, + str(tb.tb_lineno), + ) + ) + tb = tb.tb_next + + # just to be safe + del tb + + file, function, line = tbinfo[-1] + info = " ".join(["[%s|%s|%s]" % x for x in tbinfo]) + return (file, function, line), t, v, info + + +class dispatcher: + + debug = False + connected = False + accepting = False + connecting = False + closing = False + addr = None + ignore_log_types = frozenset({"warning"}) + logger = utilities.logger + compact_traceback = staticmethod(compact_traceback) # for testing + + def __init__(self, sock=None, map=None): + if map is None: # pragma: no cover + self._map = socket_map + else: + self._map = map + + self._fileno = None + + if sock: + # Set to nonblocking just to make sure for cases where we + # get a socket from a blocking source. + sock.setblocking(0) + self.set_socket(sock, map) + self.connected = True + # The constructor no longer requires that the socket + # passed be connected. + try: + self.addr = sock.getpeername() + except socket.error as err: + if err.args[0] in (ENOTCONN, EINVAL): + # To handle the case where we got an unconnected + # socket. + self.connected = False + else: + # The socket is broken in some unknown way, alert + # the user and remove it from the map (to prevent + # polling of broken sockets). + self.del_channel(map) + raise + else: + self.socket = None + + def __repr__(self): + status = [self.__class__.__module__ + "." + compat.qualname(self.__class__)] + if self.accepting and self.addr: + status.append("listening") + elif self.connected: + status.append("connected") + if self.addr is not None: + try: + status.append("%s:%d" % self.addr) + except TypeError: # pragma: no cover + status.append(repr(self.addr)) + return "<%s at %#x>" % (" ".join(status), id(self)) + + __str__ = __repr__ + + def add_channel(self, map=None): + # self.log_info('adding channel %s' % self) + if map is None: + map = self._map + map[self._fileno] = self + + def del_channel(self, map=None): + fd = self._fileno + if map is None: + map = self._map + if fd in map: + # self.log_info('closing channel %d:%s' % (fd, self)) + del map[fd] + self._fileno = None + + def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM): + self.family_and_type = family, type + sock = socket.socket(family, type) + sock.setblocking(0) + self.set_socket(sock) + + def set_socket(self, sock, map=None): + self.socket = sock + self._fileno = sock.fileno() + self.add_channel(map) + + def set_reuse_addr(self): + # try to re-use a server port if possible + try: + self.socket.setsockopt( + socket.SOL_SOCKET, + socket.SO_REUSEADDR, + self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1, + ) + except socket.error: + pass + + # ================================================== + # predicates for select() + # these are used as filters for the lists of sockets + # to pass to select(). + # ================================================== + + def readable(self): + return True + + def writable(self): + return True + + # ================================================== + # socket object methods. + # ================================================== + + def listen(self, num): + self.accepting = True + if os.name == "nt" and num > 5: # pragma: no cover + num = 5 + return self.socket.listen(num) + + def bind(self, addr): + self.addr = addr + return self.socket.bind(addr) + + def connect(self, address): + self.connected = False + self.connecting = True + err = self.socket.connect_ex(address) + if ( + err in (EINPROGRESS, EALREADY, EWOULDBLOCK) + or err == EINVAL + and os.name == "nt" + ): # pragma: no cover + self.addr = address + return + if err in (0, EISCONN): + self.addr = address + self.handle_connect_event() + else: + raise socket.error(err, errorcode[err]) + + def accept(self): + # XXX can return either an address pair or None + try: + conn, addr = self.socket.accept() + except TypeError: + return None + except socket.error as why: + if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN): + return None + else: + raise + else: + return conn, addr + + def send(self, data): + try: + result = self.socket.send(data) + return result + except socket.error as why: + if why.args[0] == EWOULDBLOCK: + return 0 + elif why.args[0] in _DISCONNECTED: + self.handle_close() + return 0 + else: + raise + + def recv(self, buffer_size): + try: + data = self.socket.recv(buffer_size) + if not data: + # a closed connection is indicated by signaling + # a read condition, and having recv() return 0. + self.handle_close() + return b"" + else: + return data + except socket.error as why: + # winsock sometimes raises ENOTCONN + if why.args[0] in _DISCONNECTED: + self.handle_close() + return b"" + else: + raise + + def close(self): + self.connected = False + self.accepting = False + self.connecting = False + self.del_channel() + if self.socket is not None: + try: + self.socket.close() + except socket.error as why: + if why.args[0] not in (ENOTCONN, EBADF): + raise + + # log and log_info may be overridden to provide more sophisticated + # logging and warning methods. In general, log is for 'hit' logging + # and 'log_info' is for informational, warning and error logging. + + def log(self, message): + self.logger.log(logging.DEBUG, message) + + def log_info(self, message, type="info"): + severity = { + "info": logging.INFO, + "warning": logging.WARN, + "error": logging.ERROR, + } + self.logger.log(severity.get(type, logging.INFO), message) + + def handle_read_event(self): + if self.accepting: + # accepting sockets are never connected, they "spawn" new + # sockets that are connected + self.handle_accept() + elif not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_read() + else: + self.handle_read() + + def handle_connect_event(self): + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise socket.error(err, _strerror(err)) + self.handle_connect() + self.connected = True + self.connecting = False + + def handle_write_event(self): + if self.accepting: + # Accepting sockets shouldn't get a write event. + # We will pretend it didn't happen. + return + + if not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_write() + + def handle_expt_event(self): + # handle_expt_event() is called if there might be an error on the + # socket, or if there is OOB data + # check for the error condition first + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # we can get here when select.select() says that there is an + # exceptional condition on the socket + # since there is an error, we'll go ahead and close the socket + # like we would in a subclassed handle_read() that received no + # data + self.handle_close() + else: + self.handle_expt() + + def handle_error(self): + nil, t, v, tbinfo = self.compact_traceback() + + # sometimes a user repr method will crash. + try: + self_repr = repr(self) + except: # pragma: no cover + self_repr = "<__repr__(self) failed for object at %0x>" % id(self) + + self.log_info( + "uncaptured python exception, closing channel %s (%s:%s %s)" + % (self_repr, t, v, tbinfo), + "error", + ) + self.handle_close() + + def handle_expt(self): + self.log_info("unhandled incoming priority event", "warning") + + def handle_read(self): + self.log_info("unhandled read event", "warning") + + def handle_write(self): + self.log_info("unhandled write event", "warning") + + def handle_connect(self): + self.log_info("unhandled connect event", "warning") + + def handle_accept(self): + pair = self.accept() + if pair is not None: + self.handle_accepted(*pair) + + def handle_accepted(self, sock, addr): + sock.close() + self.log_info("unhandled accepted event", "warning") + + def handle_close(self): + self.log_info("unhandled close event", "warning") + self.close() + + +# --------------------------------------------------------------------------- +# adds simple buffered output capability, useful for simple clients. +# [for more sophisticated usage use asynchat.async_chat] +# --------------------------------------------------------------------------- + + +class dispatcher_with_send(dispatcher): + def __init__(self, sock=None, map=None): + dispatcher.__init__(self, sock, map) + self.out_buffer = b"" + + def initiate_send(self): + num_sent = 0 + num_sent = dispatcher.send(self, self.out_buffer[:65536]) + self.out_buffer = self.out_buffer[num_sent:] + + handle_write = initiate_send + + def writable(self): + return (not self.connected) or len(self.out_buffer) + + def send(self, data): + if self.debug: # pragma: no cover + self.log_info("sending %s" % repr(data)) + self.out_buffer = self.out_buffer + data + self.initiate_send() + + +def close_all(map=None, ignore_all=False): + if map is None: # pragma: no cover + map = socket_map + for x in list(map.values()): # list() FBO py3 + try: + x.close() + except socket.error as x: + if x.args[0] == EBADF: + pass + elif not ignore_all: + raise + except _reraised_exceptions: + raise + except: + if not ignore_all: + raise + map.clear() + + +# Asynchronous File I/O: +# +# After a little research (reading man pages on various unixen, and +# digging through the linux kernel), I've determined that select() +# isn't meant for doing asynchronous file i/o. +# Heartening, though - reading linux/mm/filemap.c shows that linux +# supports asynchronous read-ahead. So _MOST_ of the time, the data +# will be sitting in memory for us already when we go to read it. +# +# What other OS's (besides NT) support async file i/o? [VMS?] +# +# Regardless, this is useful for pipes, and stdin/stdout... + +if os.name == "posix": + + class file_wrapper: + # Here we override just enough to make a file + # look like a socket for the purposes of asyncore. + # The passed fd is automatically os.dup()'d + + def __init__(self, fd): + self.fd = os.dup(fd) + + def __del__(self): + if self.fd >= 0: + warnings.warn("unclosed file %r" % self, compat.ResourceWarning) + self.close() + + def recv(self, *args): + return os.read(self.fd, *args) + + def send(self, *args): + return os.write(self.fd, *args) + + def getsockopt(self, level, optname, buflen=None): # pragma: no cover + if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: + return 0 + raise NotImplementedError( + "Only asyncore specific behaviour " "implemented." + ) + + read = recv + write = send + + def close(self): + if self.fd < 0: + return + fd = self.fd + self.fd = -1 + os.close(fd) + + def fileno(self): + return self.fd + + class file_dispatcher(dispatcher): + def __init__(self, fd, map=None): + dispatcher.__init__(self, None, map) + self.connected = True + try: + fd = fd.fileno() + except AttributeError: + pass + self.set_file(fd) + # set it to non-blocking mode + compat.set_nonblocking(fd) + + def set_file(self, fd): + self.socket = file_wrapper(fd) + self._fileno = self.socket.fileno() + self.add_channel() |