summaryrefslogtreecommitdiffhomepage
path: root/libs/waitress
diff options
context:
space:
mode:
authorLouis Vézina <[email protected]>2020-04-15 00:02:44 -0400
committerLouis Vézina <[email protected]>2020-04-15 00:02:44 -0400
commit1b0e721a9d4b88bfbfea823798de92713d50826b (patch)
tree9139c5ca46fe1391f540a0190170edf08d30a588 /libs/waitress
parent02551f2486531cfdb83576ced380b72507fb2da0 (diff)
downloadbazarr-1b0e721a9d4b88bfbfea823798de92713d50826b.tar.gz
bazarr-1b0e721a9d4b88bfbfea823798de92713d50826b.zip
WIP
Diffstat (limited to 'libs/waitress')
-rw-r--r--libs/waitress/__init__.py45
-rw-r--r--libs/waitress/__main__.py3
-rw-r--r--libs/waitress/adjustments.py515
-rw-r--r--libs/waitress/buffers.py308
-rw-r--r--libs/waitress/channel.py414
-rw-r--r--libs/waitress/compat.py179
-rw-r--r--libs/waitress/parser.py413
-rw-r--r--libs/waitress/proxy_headers.py333
-rw-r--r--libs/waitress/receiver.py186
-rw-r--r--libs/waitress/rfc7230.py52
-rw-r--r--libs/waitress/runner.py286
-rw-r--r--libs/waitress/server.py436
-rw-r--r--libs/waitress/task.py570
-rw-r--r--libs/waitress/tests/__init__.py2
-rw-r--r--libs/waitress/tests/fixtureapps/__init__.py1
-rw-r--r--libs/waitress/tests/fixtureapps/badcl.py11
-rw-r--r--libs/waitress/tests/fixtureapps/echo.py56
-rw-r--r--libs/waitress/tests/fixtureapps/error.py21
-rw-r--r--libs/waitress/tests/fixtureapps/filewrapper.py93
-rw-r--r--libs/waitress/tests/fixtureapps/getline.py17
-rw-r--r--libs/waitress/tests/fixtureapps/groundhog1.jpgbin0 -> 45448 bytes
-rw-r--r--libs/waitress/tests/fixtureapps/nocl.py23
-rw-r--r--libs/waitress/tests/fixtureapps/runner.py6
-rw-r--r--libs/waitress/tests/fixtureapps/sleepy.py12
-rw-r--r--libs/waitress/tests/fixtureapps/toolarge.py7
-rw-r--r--libs/waitress/tests/fixtureapps/writecb.py14
-rw-r--r--libs/waitress/tests/test_adjustments.py481
-rw-r--r--libs/waitress/tests/test_buffers.py523
-rw-r--r--libs/waitress/tests/test_channel.py882
-rw-r--r--libs/waitress/tests/test_compat.py22
-rw-r--r--libs/waitress/tests/test_functional.py1667
-rw-r--r--libs/waitress/tests/test_init.py51
-rw-r--r--libs/waitress/tests/test_parser.py732
-rw-r--r--libs/waitress/tests/test_proxy_headers.py724
-rw-r--r--libs/waitress/tests/test_receiver.py242
-rw-r--r--libs/waitress/tests/test_regression.py147
-rw-r--r--libs/waitress/tests/test_runner.py191
-rw-r--r--libs/waitress/tests/test_server.py533
-rw-r--r--libs/waitress/tests/test_task.py1001
-rw-r--r--libs/waitress/tests/test_trigger.py111
-rw-r--r--libs/waitress/tests/test_utilities.py140
-rw-r--r--libs/waitress/tests/test_wasyncore.py1761
-rw-r--r--libs/waitress/trigger.py203
-rw-r--r--libs/waitress/utilities.py320
-rw-r--r--libs/waitress/wasyncore.py693
45 files changed, 14427 insertions, 0 deletions
diff --git a/libs/waitress/__init__.py b/libs/waitress/__init__.py
new file mode 100644
index 000000000..e6e5911a5
--- /dev/null
+++ b/libs/waitress/__init__.py
@@ -0,0 +1,45 @@
+from waitress.server import create_server
+import logging
+
+
def serve(app, **kw):
    """Create a waitress server for *app* and run it until shutdown.

    All keyword arguments are forwarded to ``create_server``.  The
    underscore-prefixed keywords are test shims only:

    - ``_server``: factory called instead of ``create_server``
    - ``_quiet``: skip logging setup and the "Serving on" banner
    - ``_profile``: run the server loop under the profiler
    """
    _server = kw.pop("_server", create_server)  # test shim
    _quiet = kw.pop("_quiet", False)  # test shim
    _profile = kw.pop("_profile", False)  # test shim
    if not _quiet:  # pragma: no cover
        # idempotent if logging has already been set up
        logging.basicConfig()
    server = _server(app, **kw)
    if not _quiet:  # pragma: no cover
        server.print_listen("Serving on http://{}:{}")
    if _profile:  # pragma: no cover
        profile("server.run()", globals(), locals(), (), False)
    else:
        server.run()
+
+
def serve_paste(app, global_conf, **kw):
    """PasteDeploy-compatible entry point.

    ``global_conf`` is accepted for signature compatibility but ignored;
    remaining keyword arguments are forwarded to ``serve``.  Returns 0 so
    Paste treats the (blocking) call as a success on shutdown.
    """
    serve(app, **kw)
    return 0
+
+
def profile(cmd, globals, locals, sort_order, callers):  # pragma: no cover
    """Run ``cmd`` under the profiler and print profiling output at shutdown.

    ``cmd`` is executed via ``profile.runctx`` in the supplied ``globals``
    and ``locals``.  ``sort_order`` is a tuple of ``pstats`` sort keys
    (defaults to cumulative/calls/time when falsy); when ``callers`` is
    true, caller statistics are printed instead of plain stats.
    """
    import os
    import profile
    import pstats
    import tempfile

    fd, fn = tempfile.mkstemp()
    # mkstemp returns an open descriptor; the profiler only needs the file
    # *name*, so close it immediately to avoid leaking the descriptor.
    os.close(fd)
    try:
        profile.runctx(cmd, globals, locals, fn)
        stats = pstats.Stats(fn)
        stats.strip_dirs()
        # calls,time,cumulative and cumulative,calls,time are useful
        stats.sort_stats(*(sort_order or ("cumulative", "calls", "time")))
        if callers:
            stats.print_callers(0.3)
        else:
            stats.print_stats(0.3)
    finally:
        os.remove(fn)
diff --git a/libs/waitress/__main__.py b/libs/waitress/__main__.py
new file mode 100644
index 000000000..9bcd07e59
--- /dev/null
+++ b/libs/waitress/__main__.py
@@ -0,0 +1,3 @@
+from waitress.runner import run # pragma nocover
+
+run() # pragma nocover
diff --git a/libs/waitress/adjustments.py b/libs/waitress/adjustments.py
new file mode 100644
index 000000000..93439eab8
--- /dev/null
+++ b/libs/waitress/adjustments.py
@@ -0,0 +1,515 @@
+##############################################################################
+#
+# Copyright (c) 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""Adjustments are tunable parameters.
+"""
+import getopt
+import socket
+import warnings
+
+from .proxy_headers import PROXY_HEADERS
+from .compat import (
+ PY2,
+ WIN,
+ string_types,
+ HAS_IPV6,
+)
+
+truthy = frozenset(("t", "true", "y", "yes", "on", "1"))
+
+KNOWN_PROXY_HEADERS = frozenset(
+ header.lower().replace("_", "-") for header in PROXY_HEADERS
+)
+
+
def asbool(s):
    """Return ``True`` if the case-lowered value of string input ``s`` is
    any of ``t``, ``true``, ``y``, ``yes``, ``on``, or ``1``; otherwise
    return ``False``.  If ``s`` is ``None``, return ``False``.  If ``s``
    is already a boolean, return it unchanged."""
    if s is None:
        return False
    if isinstance(s, bool):
        return s
    s = str(s).strip()
    return s.lower() in truthy
+
+
def asoctal(s):
    """Convert the given octal string (e.g. ``"600"``) to an actual number."""
    return int(s, 8)
+
+
def aslist_cronly(value):
    """Split a newline-delimited string into a list of stripped, non-empty
    lines.  Non-string input is simply converted to a list."""
    if isinstance(value, string_types):
        value = filter(None, [x.strip() for x in value.splitlines()])
    return list(value)
+
+
def aslist(value):
    """Return a list of strings, separating the input based on newlines
    and also splitting each line on whitespace."""
    values = aslist_cronly(value)
    result = []
    for value in values:
        subvalues = value.split()
        result.extend(subvalues)
    return result
+
+
def asset(value):
    """Like ``aslist`` but return the items as a (deduplicated) set."""
    return set(aslist(value))
+
+
def slash_fixed_str(s):
    """Normalize a URL prefix: strip whitespace, collapse any number of
    leading slashes into exactly one, and drop trailing slashes.  An
    empty (or whitespace-only) input stays the empty string."""
    s = s.strip()
    if s:
        # always have a leading slash, replace any number of leading slashes
        # with a single slash, and strip any trailing slashes
        s = "/" + s.lstrip("/").rstrip("/")
    return s
+
+
def str_iftruthy(s):
    """Return ``str(s)`` when ``s`` is truthy, otherwise ``None``."""
    return str(s) if s else None
+
+
def as_socket_list(sockets):
    """Return only the elements of ``sockets`` that are actual
    ``socket.socket`` instances; anything else is silently dropped."""
    return [sock for sock in sockets if isinstance(sock, socket.socket)]
+
+
+class _str_marker(str):
+ pass
+
+
+class _int_marker(int):
+ pass
+
+
+class _bool_marker(object):
+ pass
+
+
class Adjustments(object):
    """This class contains tunable parameters.

    Each name in ``_params`` maps to a casting function used to coerce
    (possibly string) input into the proper type; class attributes hold
    the documented defaults.  ``__init__`` validates mutually exclusive
    options and resolves ``listen`` into concrete socket specifications.
    """

    _params = (
        ("host", str),
        ("port", int),
        ("ipv4", asbool),
        ("ipv6", asbool),
        ("listen", aslist),
        ("threads", int),
        ("trusted_proxy", str_iftruthy),
        ("trusted_proxy_count", int),
        ("trusted_proxy_headers", asset),
        ("log_untrusted_proxy_headers", asbool),
        ("clear_untrusted_proxy_headers", asbool),
        ("url_scheme", str),
        ("url_prefix", slash_fixed_str),
        ("backlog", int),
        ("recv_bytes", int),
        ("send_bytes", int),
        ("outbuf_overflow", int),
        ("outbuf_high_watermark", int),
        ("inbuf_overflow", int),
        ("connection_limit", int),
        ("cleanup_interval", int),
        ("channel_timeout", int),
        ("log_socket_errors", asbool),
        ("max_request_header_size", int),
        ("max_request_body_size", int),
        ("expose_tracebacks", asbool),
        ("ident", str_iftruthy),
        ("asyncore_loop_timeout", int),
        ("asyncore_use_poll", asbool),
        ("unix_socket", str),
        ("unix_socket_perms", asoctal),
        ("sockets", as_socket_list),
    )

    _param_map = dict(_params)

    # hostname or IP address to listen on
    host = _str_marker("0.0.0.0")

    # TCP port to listen on
    port = _int_marker(8080)

    listen = ["{}:{}".format(host, port)]

    # number of threads available for tasks
    threads = 4

    # Host allowed to override ``wsgi.url_scheme`` via header
    trusted_proxy = None

    # How many proxies we trust when chained
    #
    #  X-Forwarded-For: 192.0.2.1, "[2001:db8::1]"
    #
    # or
    #
    #  Forwarded: for=192.0.2.1, For="[2001:db8::1]"
    #
    # means there were (potentially), two proxies involved. If we know there is
    # only 1 valid proxy, then that initial IP address "192.0.2.1" is not
    # trusted and we completely ignore it. If there are two trusted proxies in
    # the path, this value should be set to a higher number.
    trusted_proxy_count = None

    # Which of the proxy headers should we trust, this is a set where you
    # either specify forwarded or one or more of forwarded-host, forwarded-for,
    # forwarded-proto, forwarded-port.
    trusted_proxy_headers = set()

    # Would you like waitress to log warnings about untrusted proxy headers
    # that were encountered while processing the proxy headers? This only makes
    # sense to set when you have a trusted_proxy, and you expect the upstream
    # proxy server to filter invalid headers
    log_untrusted_proxy_headers = False

    # Should waitress clear any proxy headers that are not deemed trusted from
    # the environ? Change to True by default in 2.x
    clear_untrusted_proxy_headers = _bool_marker

    # default ``wsgi.url_scheme`` value
    url_scheme = "http"

    # default ``SCRIPT_NAME`` value, also helps reset ``PATH_INFO``
    # when nonempty
    url_prefix = ""

    # server identity (sent in Server: header)
    ident = "waitress"

    # backlog is the value waitress passes to socket.listen().  This is the
    # maximum number of incoming TCP connections that will wait in an OS
    # queue for an available channel.  From listen(1): "If a connection
    # request arrives when the queue is full, the client may receive an error
    # with an indication of ECONNREFUSED or, if the underlying protocol
    # supports retransmission, the request may be ignored so that a later
    # reattempt at connection succeeds."
    backlog = 1024

    # recv_bytes is the argument to pass to socket.recv().
    recv_bytes = 8192

    # deprecated setting controls how many bytes will be buffered before
    # being flushed to the socket
    send_bytes = 1

    # A tempfile should be created if the pending output is larger than
    # outbuf_overflow, which is measured in bytes. The default is 1MB. This
    # is conservative.
    outbuf_overflow = 1048576

    # The app_iter will pause when pending output is larger than this value
    # in bytes.
    outbuf_high_watermark = 16777216

    # A tempfile should be created if the pending input is larger than
    # inbuf_overflow, which is measured in bytes. The default is 512K. This
    # is conservative.
    inbuf_overflow = 524288

    # Stop creating new channels if too many are already active (integer).
    # Each channel consumes at least one file descriptor, and, depending on
    # the input and output body sizes, potentially up to three.  The default
    # is conservative, but you may need to increase the number of file
    # descriptors available to the Waitress process on most platforms in
    # order to safely change it (see ``ulimit -a`` "open files" setting).
    # Note that this doesn't control the maximum number of TCP connections
    # that can be waiting for processing; the ``backlog`` argument controls
    # that.
    connection_limit = 100

    # Minimum seconds between cleaning up inactive channels.
    cleanup_interval = 30

    # Maximum seconds to leave an inactive connection open.
    channel_timeout = 120

    # Boolean: turn off to not log premature client disconnects.
    log_socket_errors = True

    # maximum number of bytes of all request headers combined (256K default)
    max_request_header_size = 262144

    # maximum number of bytes in request body (1GB default)
    max_request_body_size = 1073741824

    # expose tracebacks of uncaught exceptions
    expose_tracebacks = False

    # Path to a Unix domain socket to use.
    unix_socket = None

    # Octal permission bits to set on the Unix domain socket.
    unix_socket_perms = 0o600

    # The socket options to set on receiving a connection.  It is a list of
    # (level, optname, value) tuples.  TCP_NODELAY disables the Nagle
    # algorithm for writes (Waitress already buffers its writes).
    socket_options = [
        (socket.SOL_TCP, socket.TCP_NODELAY, 1),
    ]

    # The asyncore.loop timeout value
    asyncore_loop_timeout = 1

    # The asyncore.loop flag to use poll() instead of the default select().
    asyncore_use_poll = False

    # Enable IPv4 by default
    ipv4 = True

    # Enable IPv6 by default
    ipv6 = True

    # A list of sockets that waitress will use to accept connections. They can
    # be used for e.g. socket activation
    sockets = []

    def __init__(self, **kw):
        """Validate keyword adjustments and normalize them onto ``self``.

        Raises ValueError for unknown adjustment names, mutually
        exclusive combinations (listen/host/port/sockets/unix_socket),
        unresolvable host/port specifications, and inconsistent
        trusted-proxy settings.
        """
        if "listen" in kw and ("host" in kw or "port" in kw):
            raise ValueError("host or port may not be set if listen is set.")

        if "listen" in kw and "sockets" in kw:
            raise ValueError("socket may not be set if listen is set.")

        if "sockets" in kw and ("host" in kw or "port" in kw):
            raise ValueError("host or port may not be set if sockets is set.")

        if "sockets" in kw and "unix_socket" in kw:
            raise ValueError("unix_socket may not be set if sockets is set")

        if "unix_socket" in kw and ("host" in kw or "port" in kw):
            raise ValueError("unix_socket may not be set if host or port is set")

        if "unix_socket" in kw and "listen" in kw:
            raise ValueError("unix_socket may not be set if listen is set")

        if "send_bytes" in kw:
            warnings.warn(
                "send_bytes will be removed in a future release", DeprecationWarning,
            )

        # Cast every supplied value through its registered caster.
        for k, v in kw.items():
            if k not in self._param_map:
                raise ValueError("Unknown adjustment %r" % k)
            setattr(self, k, self._param_map[k](v))

        # If host or port was supplied explicitly (marker subclass lost),
        # synthesize the listen list from them.
        if not isinstance(self.host, _str_marker) or not isinstance(
            self.port, _int_marker
        ):
            self.listen = ["{}:{}".format(self.host, self.port)]

        enabled_families = socket.AF_UNSPEC

        if not self.ipv4 and not HAS_IPV6:  # pragma: no cover
            raise ValueError(
                "IPv4 is disabled but IPv6 is not available. Cowardly refusing to start."
            )

        if self.ipv4 and not self.ipv6:
            enabled_families = socket.AF_INET

        if not self.ipv4 and self.ipv6 and HAS_IPV6:
            enabled_families = socket.AF_INET6

        wanted_sockets = []
        hp_pairs = []
        for i in self.listen:
            if ":" in i:
                (host, port) = i.rsplit(":", 1)

                # IPv6 we need to make sure that we didn't split on the address
                if "]" in port:  # pragma: nocover
                    (host, port) = (i, str(self.port))
            else:
                (host, port) = (i, str(self.port))

            if WIN and PY2:  # pragma: no cover
                try:
                    # Try turning the port into an integer
                    port = int(port)

                except Exception:
                    raise ValueError(
                        "Windows does not support service names instead of port numbers"
                    )

            try:
                if "[" in host and "]" in host:  # pragma: nocover
                    host = host.strip("[").rstrip("]")

                if host == "*":
                    host = None

                for s in socket.getaddrinfo(
                    host,
                    port,
                    enabled_families,
                    socket.SOCK_STREAM,
                    socket.IPPROTO_TCP,
                    socket.AI_PASSIVE,
                ):
                    (family, socktype, proto, _, sockaddr) = s

                    # It seems that getaddrinfo() may sometimes happily return
                    # the same result multiple times, this of course makes
                    # bind() very unhappy...
                    #
                    # Split on %, and drop the zone-index from the host in the
                    # sockaddr. Works around a bug in OS X whereby
                    # getaddrinfo() returns the same link-local interface with
                    # two different zone-indices (which makes no sense what so
                    # ever...) yet treats them equally when we attempt to bind().
                    if (
                        sockaddr[1] == 0
                        or (sockaddr[0].split("%", 1)[0], sockaddr[1]) not in hp_pairs
                    ):
                        wanted_sockets.append((family, socktype, proto, sockaddr))
                        hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1]))

            except Exception:
                raise ValueError("Invalid host/port specified.")

        if self.trusted_proxy_count is not None and self.trusted_proxy is None:
            raise ValueError(
                "trusted_proxy_count has no meaning without setting trusted_proxy"
            )

        elif self.trusted_proxy_count is None:
            self.trusted_proxy_count = 1

        if self.trusted_proxy_headers and self.trusted_proxy is None:
            raise ValueError(
                "trusted_proxy_headers has no meaning without setting trusted_proxy"
            )

        if self.trusted_proxy_headers:
            self.trusted_proxy_headers = {
                header.lower() for header in self.trusted_proxy_headers
            }

            unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS
            if unknown_values:
                raise ValueError(
                    "Received unknown trusted_proxy_headers value (%s) expected one "
                    "of %s"
                    % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS))
                )

            if (
                "forwarded" in self.trusted_proxy_headers
                and self.trusted_proxy_headers - {"forwarded"}
            ):
                raise ValueError(
                    "The Forwarded proxy header and the "
                    "X-Forwarded-{By,Host,Proto,Port,For} headers are mutually "
                    "exclusive. Can't trust both!"
                )

        elif self.trusted_proxy is not None:
            warnings.warn(
                "No proxy headers were marked as trusted, but trusted_proxy was set. "
                "Implicitly trusting X-Forwarded-Proto for backwards compatibility. "
                "This will be removed in future versions of waitress.",
                DeprecationWarning,
            )
            self.trusted_proxy_headers = {"x-forwarded-proto"}

        if self.clear_untrusted_proxy_headers is _bool_marker:
            warnings.warn(
                "In future versions of Waitress clear_untrusted_proxy_headers will be "
                "set to True by default. You may opt-out by setting this value to "
                "False, or opt-in explicitly by setting this to True.",
                DeprecationWarning,
            )
            self.clear_untrusted_proxy_headers = False

        self.listen = wanted_sockets

        self.check_sockets(self.sockets)

    @classmethod
    def parse_args(cls, argv):
        """Pre-parse command line arguments for input into __init__.  Note
        that this does not cast values into adjustment types, it just
        creates a dictionary suitable for passing into __init__, where
        __init__ does the casting.
        """
        long_opts = ["help", "call"]
        for opt, cast in cls._params:
            opt = opt.replace("_", "-")
            if cast is asbool:
                # booleans get paired --opt / --no-opt flags
                long_opts.append(opt)
                long_opts.append("no-" + opt)
            else:
                long_opts.append(opt + "=")

        kw = {
            "help": False,
            "call": False,
        }

        opts, args = getopt.getopt(argv, "", long_opts)
        for opt, value in opts:
            param = opt.lstrip("-").replace("-", "_")

            # --listen may be repeated; accumulate space-separated values
            if param == "listen":
                kw["listen"] = "{} {}".format(kw.get("listen", ""), value)
                continue

            if param.startswith("no_"):
                param = param[3:]
                kw[param] = "false"
            elif param in ("help", "call"):
                kw[param] = True
            elif cls._param_map[param] is asbool:
                kw[param] = "true"
            else:
                kw[param] = value

        return kw, args

    @classmethod
    def check_sockets(cls, sockets):
        """Validate pre-bound sockets: only Internet or UNIX stream sockets
        are supported, and the two kinds may not be mixed."""
        has_unix_socket = False
        has_inet_socket = False
        has_unsupported_socket = False
        for sock in sockets:
            if (
                sock.family == socket.AF_INET or sock.family == socket.AF_INET6
            ) and sock.type == socket.SOCK_STREAM:
                has_inet_socket = True
            elif (
                hasattr(socket, "AF_UNIX")
                and sock.family == socket.AF_UNIX
                and sock.type == socket.SOCK_STREAM
            ):
                has_unix_socket = True
            else:
                has_unsupported_socket = True
        if has_unix_socket and has_inet_socket:
            raise ValueError("Internet and UNIX sockets may not be mixed.")
        if has_unsupported_socket:
            raise ValueError("Only Internet or UNIX stream sockets may be used.")
diff --git a/libs/waitress/buffers.py b/libs/waitress/buffers.py
new file mode 100644
index 000000000..04f6b4274
--- /dev/null
+++ b/libs/waitress/buffers.py
@@ -0,0 +1,308 @@
+##############################################################################
+#
+# Copyright (c) 2001-2004 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""Buffers
+"""
+from io import BytesIO
+
+# copy_bytes controls the size of temp. strings for shuffling data around.
+COPY_BYTES = 1 << 18 # 256K
+
+# The maximum number of bytes to buffer in a simple string.
+STRBUF_LIMIT = 8192
+
+
class FileBasedBuffer(object):
    """Buffer of bytes backed by a seekable file object.

    The file's current position is the read cursor; ``remain`` tracks the
    number of unread bytes between that cursor and the end of the file.
    """

    remain = 0

    def __init__(self, file, from_buffer=None):
        self.file = file
        if from_buffer is not None:
            # Copy the whole contents of from_buffer into our file while
            # preserving its read position, so this buffer resumes exactly
            # where the old one left off.
            from_file = from_buffer.getfile()
            read_pos = from_file.tell()
            from_file.seek(0)
            while True:
                data = from_file.read(COPY_BYTES)
                if not data:
                    break
                file.write(data)
            self.remain = int(file.tell() - read_pos)
            from_file.seek(read_pos)
            file.seek(read_pos)

    def __len__(self):
        return self.remain

    def __nonzero__(self):
        # Always truthy, even when empty: callers test "is there a buffer
        # object" rather than "is it non-empty".
        return True

    __bool__ = __nonzero__  # py3

    def append(self, s):
        """Append bytes at the end of the file without disturbing the
        read cursor."""
        file = self.file
        read_pos = file.tell()
        file.seek(0, 2)
        file.write(s)
        file.seek(read_pos)
        self.remain = self.remain + len(s)

    def get(self, numbytes=-1, skip=False):
        """Read up to ``numbytes`` bytes (all remaining when negative).

        When ``skip`` is false the read cursor is restored afterwards, so
        a subsequent call returns the same data again.
        """
        file = self.file
        if not skip:
            read_pos = file.tell()
        if numbytes < 0:
            # Read all
            res = file.read()
        else:
            res = file.read(numbytes)
        if skip:
            self.remain -= len(res)
        else:
            file.seek(read_pos)
        return res

    def skip(self, numbytes, allow_prune=0):
        """Advance the read cursor past ``numbytes`` bytes.

        Raises ValueError when fewer than ``numbytes`` remain.  The
        ``allow_prune`` flag is accepted for interface compatibility with
        OverflowableBuffer and is otherwise unused here.
        """
        if self.remain < numbytes:
            raise ValueError(
                "Can't skip %d bytes in buffer of %d bytes" % (numbytes, self.remain)
            )
        self.file.seek(numbytes, 1)
        self.remain = self.remain - numbytes

    def newfile(self):
        # Subclass responsibility: return a fresh empty backing file.
        raise NotImplementedError()

    def prune(self):
        """Discard already-consumed data by copying the unread tail into a
        fresh file obtained from ``newfile()``."""
        file = self.file
        if self.remain == 0:
            read_pos = file.tell()
            file.seek(0, 2)
            sz = file.tell()
            file.seek(read_pos)
            if sz == 0:
                # Nothing to prune.
                return
        nf = self.newfile()
        while True:
            data = file.read(COPY_BYTES)
            if not data:
                break
            nf.write(data)
        self.file = nf

    def getfile(self):
        return self.file

    def close(self):
        if hasattr(self.file, "close"):
            self.file.close()
        self.remain = 0
+
+
class TempfileBasedBuffer(FileBasedBuffer):
    """FileBasedBuffer backed by an anonymous binary temporary file on
    disk; used when buffered data overflows the in-memory threshold."""

    def __init__(self, from_buffer=None):
        FileBasedBuffer.__init__(self, self.newfile(), from_buffer)

    def newfile(self):
        from tempfile import TemporaryFile

        return TemporaryFile("w+b")
+
+
class BytesIOBasedBuffer(FileBasedBuffer):
    """FileBasedBuffer backed by an in-memory ``BytesIO`` object."""

    def __init__(self, from_buffer=None):
        if from_buffer is not None:
            FileBasedBuffer.__init__(self, BytesIO(), from_buffer)
        else:
            # Shortcut. :-)
            self.file = BytesIO()

    def newfile(self):
        return BytesIO()
+
+
+def _is_seekable(fp):
+ if hasattr(fp, "seekable"):
+ return fp.seekable()
+ return hasattr(fp, "seek") and hasattr(fp, "tell")
+
+
class ReadOnlyFileBasedBuffer(FileBasedBuffer):
    """Read-only buffer wrapping an application-supplied file object.

    Used as ``wsgi.file_wrapper``.  When the wrapped file is not
    seekable, the buffer is consumed by iterating it in ``block_size``
    chunks instead.
    """

    def __init__(self, file, block_size=32768):
        self.file = file
        self.block_size = block_size  # for __iter__

    def prepare(self, size=None):
        """Compute ``remain`` from the file size, optionally capped at
        ``size``.  Returns the resulting ``remain`` (0 when the file is
        not seekable)."""
        if _is_seekable(self.file):
            start_pos = self.file.tell()
            self.file.seek(0, 2)
            end_pos = self.file.tell()
            self.file.seek(start_pos)
            fsize = end_pos - start_pos
            if size is None:
                self.remain = fsize
            else:
                self.remain = min(fsize, size)
        return self.remain

    def get(self, numbytes=-1, skip=False):
        # never read more than self.remain (it can be user-specified)
        if numbytes == -1 or numbytes > self.remain:
            numbytes = self.remain
        file = self.file
        if not skip:
            read_pos = file.tell()
        res = file.read(numbytes)
        if skip:
            self.remain -= len(res)
        else:
            file.seek(read_pos)
        return res

    def __iter__(self):  # called by task if self.filelike has no seek/tell
        return self

    def next(self):
        val = self.file.read(self.block_size)
        if not val:
            raise StopIteration
        return val

    __next__ = next  # py3

    def append(self, s):
        # Read-only by design.
        raise NotImplementedError
+
+
class OverflowableBuffer(object):
    """
    This buffer implementation has four stages:
    - No data
    - Bytes-based buffer
    - BytesIO-based buffer
    - Temporary file storage
    The first two stages are fastest for simple transfers.

    ``overflow`` is the byte threshold at which the buffer is promoted
    from memory to a temporary file.
    """

    overflowed = False
    buf = None
    strbuf = b""  # Bytes-based buffer.

    def __init__(self, overflow):
        # overflow is the maximum to be stored in a StringIO buffer.
        self.overflow = overflow

    def __len__(self):
        buf = self.buf
        if buf is not None:
            # use buf.__len__ rather than len(buf) FBO of not getting
            # OverflowError on Python 2
            return buf.__len__()
        else:
            return self.strbuf.__len__()

    def __nonzero__(self):
        # use self.__len__ rather than len(self) FBO of not getting
        # OverflowError on Python 2
        return self.__len__() > 0

    __bool__ = __nonzero__  # py3

    def _create_buffer(self):
        """Promote the plain-bytes stage to a real buffer object, choosing
        large (tempfile) or small (BytesIO) based on current size."""
        strbuf = self.strbuf
        if len(strbuf) >= self.overflow:
            self._set_large_buffer()
        else:
            self._set_small_buffer()
        buf = self.buf
        if strbuf:
            buf.append(self.strbuf)
            self.strbuf = b""
        return buf

    def _set_small_buffer(self):
        self.buf = BytesIOBasedBuffer(self.buf)
        self.overflowed = False

    def _set_large_buffer(self):
        self.buf = TempfileBasedBuffer(self.buf)
        self.overflowed = True

    def append(self, s):
        """Add bytes, transparently upgrading through the storage stages
        as the pending size grows."""
        buf = self.buf
        if buf is None:
            strbuf = self.strbuf
            if len(strbuf) + len(s) < STRBUF_LIMIT:
                self.strbuf = strbuf + s
                return
            buf = self._create_buffer()
        buf.append(s)
        # use buf.__len__ rather than len(buf) FBO of not getting
        # OverflowError on Python 2
        sz = buf.__len__()
        if not self.overflowed:
            if sz >= self.overflow:
                self._set_large_buffer()

    def get(self, numbytes=-1, skip=False):
        """Return up to ``numbytes`` pending bytes; with ``skip`` true the
        returned bytes are also consumed."""
        buf = self.buf
        if buf is None:
            strbuf = self.strbuf
            if not skip:
                return strbuf
            buf = self._create_buffer()
        return buf.get(numbytes, skip)

    def skip(self, numbytes, allow_prune=False):
        """Consume ``numbytes`` pending bytes without returning them."""
        buf = self.buf
        if buf is None:
            if allow_prune and numbytes == len(self.strbuf):
                # We could slice instead of converting to
                # a buffer, but that would eat up memory in
                # large transfers.
                self.strbuf = b""
                return
            buf = self._create_buffer()
        buf.skip(numbytes, allow_prune)

    def prune(self):
        """
        A potentially expensive operation that removes all data
        already retrieved from the buffer.
        """
        buf = self.buf
        if buf is None:
            self.strbuf = b""
            return
        buf.prune()
        if self.overflowed:
            # use buf.__len__ rather than len(buf) FBO of not getting
            # OverflowError on Python 2
            sz = buf.__len__()
            if sz < self.overflow:
                # Revert to a faster buffer.
                self._set_small_buffer()

    def getfile(self):
        buf = self.buf
        if buf is None:
            buf = self._create_buffer()
        return buf.getfile()

    def close(self):
        buf = self.buf
        if buf is not None:
            buf.close()
diff --git a/libs/waitress/channel.py b/libs/waitress/channel.py
new file mode 100644
index 000000000..a8bc76f74
--- /dev/null
+++ b/libs/waitress/channel.py
@@ -0,0 +1,414 @@
+##############################################################################
+#
+# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+import socket
+import threading
+import time
+import traceback
+
+from waitress.buffers import (
+ OverflowableBuffer,
+ ReadOnlyFileBasedBuffer,
+)
+
+from waitress.parser import HTTPRequestParser
+
+from waitress.task import (
+ ErrorTask,
+ WSGITask,
+)
+
+from waitress.utilities import InternalServerError
+
+from . import wasyncore
+
+
class ClientDisconnected(Exception):
    """Raised when attempting to write to a closed socket.

    ``HTTPChannel.service`` catches this so a task can stop iterating
    its app_iter once the client has gone away.
    """
+
+
class HTTPChannel(wasyncore.dispatcher, object):
    """
    Represents one client connection.

    Setting self.requests = [somerequest] prevents more requests from being
    received until the out buffers have been flushed.

    Setting self.requests = [] allows more requests to be received.
    """

    task_class = WSGITask  # class used to execute ordinary requests
    error_task_class = ErrorTask  # class used to render error responses
    parser_class = HTTPRequestParser  # class used to parse incoming bytes

    request = None  # A request parser instance
    last_activity = 0  # Time of last activity
    will_close = False  # set to True to close the socket.
    close_when_flushed = False  # set to True to close the socket when flushed
    requests = ()  # currently pending requests
    sent_continue = False  # used as a latch after sending 100 continue
    total_outbufs_len = 0  # total bytes ready to send
    current_outbuf_count = 0  # total bytes written to current outbuf
+
+ #
+ # ASYNCHRONOUS METHODS (including __init__)
+ #
+
    def __init__(
        self, server, sock, addr, adj, map=None,
    ):
        """
        server: the server that accepted the connection (used for
            add_task/pull_trigger/active_channels)
        sock: the connected client socket
        addr: the client address
        adj: an Adjustments (configuration) instance
        map: optional wasyncore socket map
        """
        self.server = server
        self.adj = adj
        self.outbufs = [OverflowableBuffer(adj.outbuf_overflow)]
        self.creation_time = self.last_activity = time.time()
        # size send() chunks to the kernel's socket send buffer
        self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)

        # task_lock used to push/pop requests
        self.task_lock = threading.Lock()
        # outbuf_lock used to access any outbuf (expected to use an RLock)
        self.outbuf_lock = threading.Condition()

        wasyncore.dispatcher.__init__(self, sock, map=map)

        # Don't let wasyncore.dispatcher throttle self.addr on us.
        self.addr = addr
+
+ def writable(self):
+ # if there's data in the out buffer or we've been instructed to close
+ # the channel (possibly by our server maintenance logic), run
+ # handle_write
+ return self.total_outbufs_len or self.will_close or self.close_when_flushed
+
    def handle_write(self):
        # Precondition: there's data in the out buffer to be sent, or
        # there's a pending will_close request
        if not self.connected:
            # we don't want to close the channel twice
            return

        # try to flush any pending output
        if not self.requests:
            # 1. There are no running tasks, so we don't need to try to lock
            #    the outbuf before sending
            # 2. The data in the out buffer should be sent as soon as possible
            #    because it's either data left over from task output
            #    or a 100 Continue line sent within "received".
            flush = self._flush_some
        elif self.total_outbufs_len >= self.adj.send_bytes:
            # 1. There's a running task, so we need to try to lock
            #    the outbuf before sending
            # 2. Only try to send if the data in the out buffer is larger
            #    than self.adj_bytes to avoid TCP fragmentation
            flush = self._flush_some_if_lockable
        else:
            # 1. There's not enough data in the out buffer to bother to send
            #    right now.
            flush = None

        if flush:
            try:
                flush()
            except socket.error:
                if self.adj.log_socket_errors:
                    self.logger.exception("Socket error")
                self.will_close = True
            except Exception:
                self.logger.exception("Unexpected exception when flushing")
                self.will_close = True

        # everything queued has now gone out: convert a deferred close
        # into an immediate one
        if self.close_when_flushed and not self.total_outbufs_len:
            self.close_when_flushed = False
            self.will_close = True

        if self.will_close:
            self.handle_close()
+
+ def readable(self):
+ # We might want to create a new task. We can only do this if:
+ # 1. We're not already about to close the connection.
+ # 2. There's no already currently running task(s).
+ # 3. There's no data in the output buffer that needs to be sent
+ # before we potentially create a new task.
+ return not (self.will_close or self.requests or self.total_outbufs_len)
+
    def handle_read(self):
        """Read available bytes from the socket and feed them to the
        request parser via ``received``."""
        try:
            data = self.recv(self.adj.recv_bytes)
        except socket.error:
            if self.adj.log_socket_errors:
                self.logger.exception("Socket error")
            self.handle_close()
            return
        if data:
            self.last_activity = time.time()
            self.received(data)
+
    def received(self, data):
        """
        Receives input asynchronously and assigns one or more requests to the
        channel.

        Returns False when ``data`` is empty, True otherwise.
        """
        # Preconditions: there's no task(s) already running
        request = self.request
        requests = []

        if not data:
            return False

        while data:
            if request is None:
                request = self.parser_class(self.adj)
            n = request.received(data)  # number of bytes consumed by the parser
            if request.expect_continue and request.headers_finished:
                # guaranteed by parser to be a 1.1 request
                request.expect_continue = False
                if not self.sent_continue:
                    # there's no current task, so we don't need to try to
                    # lock the outbuf to append to it.
                    outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
                    self.outbufs[-1].append(outbuf_payload)
                    self.current_outbuf_count += len(outbuf_payload)
                    self.total_outbufs_len += len(outbuf_payload)
                    self.sent_continue = True
                    self._flush_some()
                    request.completed = False
            if request.completed:
                # The request (with the body) is ready to use.
                self.request = None
                if not request.empty:
                    requests.append(request)
                request = None
            else:
                self.request = request
            if n >= len(data):
                break
            data = data[n:]

        if requests:
            # hand the parsed requests to a worker thread; readable()
            # returns False until they have all been serviced
            self.requests = requests
            self.server.add_task(self)

        return True
+
    def _flush_some_if_lockable(self):
        # Since our task may be appending to the outbuf, we try to acquire
        # the lock, but we don't block if we can't.
        if self.outbuf_lock.acquire(False):
            try:
                self._flush_some()

                if self.total_outbufs_len < self.adj.outbuf_high_watermark:
                    # let a task blocked in
                    # _flush_outbufs_below_high_watermark continue
                    self.outbuf_lock.notify()
            finally:
                self.outbuf_lock.release()
+
    def _flush_some(self):
        """Send as much pending output as the socket will accept.

        Returns True when any bytes were actually written.
        """
        # Send as much data as possible to our client

        sent = 0
        dobreak = False

        while True:
            outbuf = self.outbufs[0]
            # use outbuf.__len__ rather than len(outbuf) FBO of not getting
            # OverflowError on 32-bit Python
            outbuflen = outbuf.__len__()
            while outbuflen > 0:
                chunk = outbuf.get(self.sendbuf_len)
                num_sent = self.send(chunk)
                if num_sent:
                    outbuf.skip(num_sent, True)
                    outbuflen -= num_sent
                    sent += num_sent
                    self.total_outbufs_len -= num_sent
                else:
                    # failed to write anything, break out entirely
                    dobreak = True
                    break
            else:
                # inner while ran to completion: this outbuf is drained.
                # self.outbufs[-1] must always be a writable outbuf
                if len(self.outbufs) > 1:
                    toclose = self.outbufs.pop(0)
                    try:
                        toclose.close()
                    except Exception:
                        self.logger.exception("Unexpected error when closing an outbuf")
                else:
                    # caught up, done flushing for now
                    dobreak = True

            if dobreak:
                break

        if sent:
            self.last_activity = time.time()
            return True

        return False
+
    def handle_close(self):
        """Close all outbufs and the socket, waking any task thread
        blocked waiting for outbuf space."""
        with self.outbuf_lock:
            for outbuf in self.outbufs:
                try:
                    outbuf.close()
                except Exception:
                    self.logger.exception(
                        "Unknown exception while trying to close outbuf"
                    )
            self.total_outbufs_len = 0
            self.connected = False
            # wake a task blocked in _flush_outbufs_below_high_watermark
            # so it can notice the disconnect
            self.outbuf_lock.notify()
        wasyncore.dispatcher.close(self)
+
    def add_channel(self, map=None):
        """See wasyncore.dispatcher

        This hook keeps track of opened channels by registering them in
        the server's active_channels map, keyed by file descriptor.
        """
        wasyncore.dispatcher.add_channel(self, map)
        self.server.active_channels[self._fileno] = self
+
+ def del_channel(self, map=None):
+ """See wasyncore.dispatcher
+
+ This hook keeps track of closed channels.
+ """
+ fd = self._fileno # next line sets this to None
+ wasyncore.dispatcher.del_channel(self, map)
+ ac = self.server.active_channels
+ if fd in ac:
+ del ac[fd]
+
+ #
+ # SYNCHRONOUS METHODS
+ #
+
    def write_soon(self, data):
        """Queue ``data`` for asynchronous transmission to the client.

        Called from a task (worker) thread; may block until the outbufs
        drain below the high watermark.  Returns the number of bytes
        queued.  Raises ClientDisconnected once the socket has closed.
        """
        if not self.connected:
            # if the socket is closed then interrupt the task so that it
            # can cleanup possibly before the app_iter is exhausted
            raise ClientDisconnected
        if data:
            # the async mainloop might be popping data off outbuf; we can
            # block here waiting for it because we're in a task thread
            with self.outbuf_lock:
                self._flush_outbufs_below_high_watermark()
                if not self.connected:
                    raise ClientDisconnected
                num_bytes = len(data)
                if data.__class__ is ReadOnlyFileBasedBuffer:
                    # they used wsgi.file_wrapper
                    self.outbufs.append(data)
                    nextbuf = OverflowableBuffer(self.adj.outbuf_overflow)
                    self.outbufs.append(nextbuf)
                    self.current_outbuf_count = 0
                else:
                    if self.current_outbuf_count > self.adj.outbuf_high_watermark:
                        # rotate to a new buffer if the current buffer has hit
                        # the watermark to avoid it growing unbounded
                        nextbuf = OverflowableBuffer(self.adj.outbuf_overflow)
                        self.outbufs.append(nextbuf)
                        self.current_outbuf_count = 0
                    self.outbufs[-1].append(data)
                    self.current_outbuf_count += num_bytes
                self.total_outbufs_len += num_bytes
                if self.total_outbufs_len >= self.adj.send_bytes:
                    # enough is queued; ask the mainloop to start sending
                    self.server.pull_trigger()
            return num_bytes
        return 0
+
    def _flush_outbufs_below_high_watermark(self):
        """Block the calling task thread until the pending output drops
        below the configured high watermark (or the client disconnects)."""
        # check first to avoid locking if possible
        if self.total_outbufs_len > self.adj.outbuf_high_watermark:
            with self.outbuf_lock:
                while (
                    self.connected
                    and self.total_outbufs_len > self.adj.outbuf_high_watermark
                ):
                    # ask the mainloop to flush, then wait to be notified
                    # by _flush_some_if_lockable / handle_close
                    self.server.pull_trigger()
                    self.outbuf_lock.wait()
+
    def service(self):
        """Execute all pending requests.

        Runs in a task (worker) thread.  Each request is turned into a
        task; unexpected task failures are converted into an internal
        server error response when no output was written yet.
        """
        with self.task_lock:
            while self.requests:
                request = self.requests[0]
                if request.error:
                    task = self.error_task_class(self, request)
                else:
                    task = self.task_class(self, request)
                try:
                    task.service()
                except ClientDisconnected:
                    self.logger.info(
                        "Client disconnected while serving %s" % task.request.path
                    )
                    task.close_on_finish = True
                except Exception:
                    self.logger.exception(
                        "Exception while serving %s" % task.request.path
                    )
                    if not task.wrote_header:
                        # nothing has hit the wire yet, so we can still
                        # send a proper error response
                        if self.adj.expose_tracebacks:
                            body = traceback.format_exc()
                        else:
                            body = (
                                "The server encountered an unexpected "
                                "internal server error"
                            )
                        req_version = request.version
                        req_headers = request.headers
                        request = self.parser_class(self.adj)
                        request.error = InternalServerError(body)
                        # copy some original request attributes to fulfill
                        # HTTP 1.1 requirements
                        request.version = req_version
                        try:
                            request.headers["CONNECTION"] = req_headers["CONNECTION"]
                        except KeyError:
                            pass
                        task = self.error_task_class(self, request)
                        try:
                            task.service()  # must not fail
                        except ClientDisconnected:
                            task.close_on_finish = True
                    else:
                        task.close_on_finish = True
                # we cannot allow self.requests to drop to empty til
                # here; otherwise the mainloop gets confused
                if task.close_on_finish:
                    self.close_when_flushed = True
                    for request in self.requests:
                        request.close()
                    self.requests = []
                else:
                    # before processing a new request, ensure there is not too
                    # much data in the outbufs waiting to be flushed
                    # NB: currently readable() returns False while we are
                    # flushing data so we know no new requests will come in
                    # that we need to account for, otherwise it'd be better
                    # to do this check at the start of the request instead of
                    # at the end to account for consecutive service() calls
                    if len(self.requests) > 1:
                        self._flush_outbufs_below_high_watermark()
                    request = self.requests.pop(0)
                    request.close()

        if self.connected:
            self.server.pull_trigger()
        self.last_activity = time.time()
+
+ def cancel(self):
+ """ Cancels all pending / active requests """
+ self.will_close = True
+ self.connected = False
+ self.last_activity = time.time()
+ self.requests = []
diff --git a/libs/waitress/compat.py b/libs/waitress/compat.py
new file mode 100644
index 000000000..fe72a7610
--- /dev/null
+++ b/libs/waitress/compat.py
@@ -0,0 +1,179 @@
+import os
+import sys
+import types
+import platform
+import warnings
+
+try:
+ import urlparse
+except ImportError: # pragma: no cover
+ from urllib import parse as urlparse
+
+try:
+ import fcntl
+except ImportError: # pragma: no cover
+ fcntl = None # windows
+
# True if we are running on Python 2 / Python 3 respectively.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3

# True if we are running on Windows
WIN = platform.system() == "Windows"

if PY3:  # pragma: no cover
    # Aliases giving the Python 2 type-name spellings used elsewhere in
    # the package their Python 3 equivalents.
    string_types = (str,)
    integer_types = (int,)
    class_types = (type,)
    text_type = str
    binary_type = bytes
    long = int
else:
    string_types = (basestring,)
    integer_types = (int, long)
    class_types = (type, types.ClassType)
    text_type = unicode
    binary_type = str
    long = long
+
if PY3:  # pragma: no cover
    from urllib.parse import unquote_to_bytes

    def unquote_bytes_to_wsgi(bytestring):
        # PEP 3333: WSGI on Python 3 wants percent-decoded bytes
        # re-presented as a native str via latin-1.
        return unquote_to_bytes(bytestring).decode("latin-1")


else:
    from urlparse import unquote as unquote_to_bytes

    def unquote_bytes_to_wsgi(bytestring):
        # On Python 2 the native str type already is bytes; no decode.
        return unquote_to_bytes(bytestring)
+
+
def text_(s, encoding="latin-1", errors="strict"):
    """Coerce ``s`` to text: decode ``binary_type`` values using
    ``encoding``/``errors``; return any other value unchanged."""
    if not isinstance(s, binary_type):
        return s  # pragma: no cover
    return s.decode(encoding, errors)
+
+
if PY3:  # pragma: no cover

    def tostr(s):
        # Round-trip text through latin-1 so the result is a native str
        # restricted to byte-sized code points.
        if isinstance(s, text_type):
            s = s.encode("latin-1")
        return str(s, "latin-1", "strict")

    def tobytes(s):
        return bytes(s, "latin-1")


else:
    # On Python 2 the native str type already is bytes.
    tostr = str

    def tobytes(s):
        return s
+
+
if PY3:  # pragma: no cover
    import builtins

    exec_ = getattr(builtins, "exec")

    def reraise(tp, value, tb=None):
        # Re-raise an exception with an explicit traceback (PY3 form).
        if value is None:
            value = tp
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value

    del builtins

else:  # pragma: no cover

    def exec_(code, globs=None, locs=None):
        """Execute code in a namespace."""
        if globs is None:
            frame = sys._getframe(1)
            globs = frame.f_globals
            if locs is None:
                locs = frame.f_locals
            del frame
        elif locs is None:
            locs = globs
        exec("""exec code in globs, locs""")

    # The PY2 raise-with-traceback statement is a syntax error on PY3,
    # so it must be defined via exec_.
    exec_(
        """def reraise(tp, value, tb=None):
    raise tp, value, tb
"""
    )
+
# Python 2/3 import shims for stdlib modules that moved.
try:
    from StringIO import StringIO as NativeIO
except ImportError:  # pragma: no cover
    from io import StringIO as NativeIO

try:
    import httplib
except ImportError:  # pragma: no cover
    from http import client as httplib

# Largest native integer: sys.maxint on PY2, sys.maxsize on PY3.
try:
    MAXINT = sys.maxint
except AttributeError:  # pragma: no cover
    MAXINT = sys.maxsize
+
+
# Fix for issue reported in https://github.com/Pylons/waitress/issues/138,
# Python on Windows may not define IPPROTO_IPV6 in socket.
import socket

HAS_IPV6 = socket.has_ipv6

if hasattr(socket, "IPPROTO_IPV6") and hasattr(socket, "IPV6_V6ONLY"):
    IPPROTO_IPV6 = socket.IPPROTO_IPV6
    IPV6_V6ONLY = socket.IPV6_V6ONLY
else:  # pragma: no cover
    if WIN:
        # Hard-coded values matching the Windows SDK headers.
        IPPROTO_IPV6 = 41
        IPV6_V6ONLY = 27
    else:
        warnings.warn(
            "OS does not support required IPv6 socket flags. This is requirement "
            "for Waitress. Please open an issue at https://github.com/Pylons/waitress. "
            "IPv6 support has been disabled.",
            RuntimeWarning,
        )
        HAS_IPV6 = False
+
+
def set_nonblocking(fd):  # pragma: no cover
    """Put file descriptor ``fd`` into non-blocking mode.

    Uses ``os.set_blocking`` where available (Python >= 3.5) and falls
    back to toggling ``O_NONBLOCK`` via ``fcntl`` elsewhere.  Raises
    RuntimeError on platforms with neither facility.
    """
    # Feature-detect instead of the old `PY3 and sys.version_info[1] >= 5`
    # test, which would misclassify any future major version whose minor
    # number is below 5; behavior is otherwise identical.
    if hasattr(os, "set_blocking"):
        os.set_blocking(fd, False)
    elif fcntl is None:
        raise RuntimeError("no fcntl module present")
    else:
        flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
        flags = flags | os.O_NONBLOCK
        fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+
+
# ResourceWarning is new in Python 3; alias it to UserWarning on PY2 so
# the rest of the package can reference a single name.
if PY3:
    ResourceWarning = ResourceWarning
else:
    ResourceWarning = UserWarning
+
+
def qualname(cls):
    """Return the most specific dotted name available for ``cls``.

    ``__qualname__`` (Python 3) includes the enclosing class/function
    path; Python 2 lacks it, so fall back to ``__name__``.
    """
    # Feature-detect the attribute instead of consulting the module-level
    # PY3 flag: identical results, but the function is self-contained.
    return getattr(cls, "__qualname__", cls.__name__)
+
+
+try:
+ import thread
+except ImportError:
+ # py3
+ import _thread as thread
diff --git a/libs/waitress/parser.py b/libs/waitress/parser.py
new file mode 100644
index 000000000..fef8a3da6
--- /dev/null
+++ b/libs/waitress/parser.py
@@ -0,0 +1,413 @@
+##############################################################################
+#
+# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""HTTP Request Parser
+
+This server uses asyncore to accept connections and do initial
+processing but threads to do work.
+"""
+import re
+from io import BytesIO
+
+from waitress.buffers import OverflowableBuffer
+from waitress.compat import tostr, unquote_bytes_to_wsgi, urlparse
+from waitress.receiver import ChunkedReceiver, FixedStreamReceiver
+from waitress.utilities import (
+ BadRequest,
+ RequestEntityTooLarge,
+ RequestHeaderFieldsTooLarge,
+ ServerNotImplemented,
+ find_double_newline,
+)
+from .rfc7230 import HEADER_FIELD
+
+
class ParsingError(Exception):
    """Raised when the request line or a header cannot be parsed;
    ``received`` converts it into a BadRequest error response."""

    pass
+
+
class TransferEncodingNotImplemented(Exception):
    """Raised for Transfer-Encoding values other than chunked;
    converted into a ServerNotImplemented error response."""

    pass
+
class HTTPRequestParser(object):
    """A structure that collects the HTTP request.

    Once the stream is completed, the instance is passed to
    a server task constructor.
    """

    completed = False  # Set once request is completed.
    empty = False  # Set if no request was made.
    expect_continue = False  # client sent "Expect: 100-continue" header
    headers_finished = False  # True when headers have been read
    header_plus = b""  # accumulated, not-yet-terminated header bytes
    chunked = False  # True for Transfer-Encoding: chunked requests
    content_length = 0  # parsed Content-Length (0 when absent)
    header_bytes_received = 0  # checked against max_request_header_size
    body_bytes_received = 0  # checked against max_request_body_size
    body_rcv = None  # receiver collecting the body, once one is expected
    version = "1.0"  # HTTP version of the request
    error = None  # set to an error instance when the request is rejected
    connection_close = False  # True when this must be the socket's last request

    # Other attributes: first_line, header, headers, command, uri, version,
    # path, query, fragment

    def __init__(self, adj):
        """
        adj is an Adjustments object.
        """
        # headers is a mapping containing keys translated to uppercase
        # with dashes turned into underscores.
        self.headers = {}
        self.adj = adj
+
    def received(self, data):
        """
        Receives the HTTP stream for one request. Returns the number of
        bytes consumed. Sets the completed flag once both the header and the
        body have been received.
        """
        if self.completed:
            return 0  # Can't consume any more.

        datalen = len(data)
        br = self.body_rcv
        if br is None:
            # In header.
            max_header = self.adj.max_request_header_size

            s = self.header_plus + data
            index = find_double_newline(s)
            consumed = 0

            if index >= 0:
                # If the headers have ended, and we also have part of the body
                # message in data we still want to validate we aren't going
                # over our limit for received headers.
                self.header_bytes_received += index
                consumed = datalen - (len(s) - index)
            else:
                self.header_bytes_received += datalen
                consumed = datalen

            # If the first line + headers is over the max length, we return a
            # RequestHeaderFieldsTooLarge error rather than continuing to
            # attempt to parse the headers.
            if self.header_bytes_received >= max_header:
                # parse a synthetic request line so downstream code has a
                # usable command/version to render the error with
                self.parse_header(b"GET / HTTP/1.0\r\n")
                self.error = RequestHeaderFieldsTooLarge(
                    "exceeds max_header of %s" % max_header
                )
                self.completed = True
                return consumed

            if index >= 0:
                # Header finished.
                header_plus = s[:index]

                # Remove preceding blank lines. This is suggested by
                # https://tools.ietf.org/html/rfc7230#section-3.5 to support
                # clients sending an extra CR LF after another request when
                # using HTTP pipelining
                header_plus = header_plus.lstrip()

                if not header_plus:
                    self.empty = True
                    self.completed = True
                else:
                    try:
                        self.parse_header(header_plus)
                    except ParsingError as e:
                        self.error = BadRequest(e.args[0])
                        self.completed = True
                    except TransferEncodingNotImplemented as e:
                        self.error = ServerNotImplemented(e.args[0])
                        self.completed = True
                    else:
                        if self.body_rcv is None:
                            # no content-length header and not a t-e: chunked
                            # request
                            self.completed = True

                        if self.content_length > 0:
                            max_body = self.adj.max_request_body_size
                            # we won't accept this request if the content-length
                            # is too large

                            if self.content_length >= max_body:
                                self.error = RequestEntityTooLarge(
                                    "exceeds max_body of %s" % max_body
                                )
                                self.completed = True
                self.headers_finished = True

                return consumed

            # Header not finished yet.
            self.header_plus = s

            return datalen
        else:
            # In body.
            consumed = br.received(data)
            self.body_bytes_received += consumed
            max_body = self.adj.max_request_body_size

            if self.body_bytes_received >= max_body:
                # this will only be raised during t-e: chunked requests
                self.error = RequestEntityTooLarge("exceeds max_body of %s" % max_body)
                self.completed = True
            elif br.error:
                # garbage in chunked encoding input probably
                self.error = br.error
                self.completed = True
            elif br.completed:
                # The request (with the body) is ready to use.
                self.completed = True

                if self.chunked:
                    # We've converted the chunked transfer encoding request
                    # body into a normal request body, so we know its content
                    # length; set the header here. We already popped the
                    # TRANSFER_ENCODING header in parse_header, so this will
                    # appear to the client to be an entirely non-chunked HTTP
                    # request with a valid content-length.
                    self.headers["CONTENT_LENGTH"] = str(br.__len__())

            return consumed
+
    def parse_header(self, header_plus):
        """
        Parses the header_plus block of text (the headers plus the
        first line of the request).

        Raises ParsingError for malformed input and
        TransferEncodingNotImplemented for unsupported transfer codings.
        """
        index = header_plus.find(b"\r\n")
        if index >= 0:
            first_line = header_plus[:index].rstrip()
            header = header_plus[index + 2 :]
        else:
            raise ParsingError("HTTP message header invalid")

        if b"\r" in first_line or b"\n" in first_line:
            raise ParsingError("Bare CR or LF found in HTTP message")

        self.first_line = first_line  # for testing

        lines = get_header_lines(header)

        headers = self.headers
        for line in lines:
            header = HEADER_FIELD.match(line)

            if not header:
                raise ParsingError("Invalid header")

            key, value = header.group("name", "value")

            if b"_" in key:
                # TODO(xistence): Should we drop this request instead?
                continue

            # Only strip off whitespace that is considered valid whitespace by
            # RFC7230, don't strip the rest
            value = value.strip(b" \t")
            key1 = tostr(key.upper().replace(b"-", b"_"))
            # If a header already exists, we append subsequent values
            # separated by a comma. Applications already need to handle
            # the comma separated values, as HTTP front ends might do
            # the concatenation for you (behavior specified in RFC2616).
            try:
                headers[key1] += tostr(b", " + value)
            except KeyError:
                headers[key1] = tostr(value)

        # command, uri, version will be bytes
        command, uri, version = crack_first_line(first_line)
        version = tostr(version)
        command = tostr(command)
        self.command = command
        self.version = version
        (
            self.proxy_scheme,
            self.proxy_netloc,
            self.path,
            self.query,
            self.fragment,
        ) = split_uri(uri)
        self.url_scheme = self.adj.url_scheme
        connection = headers.get("CONNECTION", "")

        if version == "1.0":
            # HTTP/1.0 defaults to closing unless keep-alive is requested
            if connection.lower() != "keep-alive":
                self.connection_close = True

        if version == "1.1":
            # since the server buffers data from chunked transfers and clients
            # never need to deal with chunked requests, downstream clients
            # should not see the HTTP_TRANSFER_ENCODING header; we pop it
            # here
            te = headers.pop("TRANSFER_ENCODING", "")

            # NB: We can not just call bare strip() here because it will also
            # remove other non-printable characters that we explicitly do not
            # want removed so that if someone attempts to smuggle a request
            # with these characters we don't fall prey to it.
            #
            # For example \x85 is stripped by default, but it is not considered
            # valid whitespace to be stripped by RFC7230.
            encodings = [
                encoding.strip(" \t").lower() for encoding in te.split(",") if encoding
            ]

            for encoding in encodings:
                # Out of the transfer-codings listed in
                # https://tools.ietf.org/html/rfc7230#section-4 we only support
                # chunked at this time.

                # Note: the identity transfer-coding was removed in RFC7230:
                # https://tools.ietf.org/html/rfc7230#appendix-A.2 and is thus
                # not supported
                if encoding not in {"chunked"}:
                    raise TransferEncodingNotImplemented(
                        "Transfer-Encoding requested is not supported."
                    )

            if encodings and encodings[-1] == "chunked":
                self.chunked = True
                buf = OverflowableBuffer(self.adj.inbuf_overflow)
                self.body_rcv = ChunkedReceiver(buf)
            elif encodings:  # pragma: nocover
                raise TransferEncodingNotImplemented(
                    "Transfer-Encoding requested is not supported."
                )

            expect = headers.get("EXPECT", "").lower()
            self.expect_continue = expect == "100-continue"
            if connection.lower() == "close":
                self.connection_close = True

        if not self.chunked:
            try:
                cl = int(headers.get("CONTENT_LENGTH", 0))
            except ValueError:
                raise ParsingError("Content-Length is invalid")

            self.content_length = cl
            if cl > 0:
                buf = OverflowableBuffer(self.adj.inbuf_overflow)
                self.body_rcv = FixedStreamReceiver(cl, buf)
+
+ def get_body_stream(self):
+ body_rcv = self.body_rcv
+ if body_rcv is not None:
+ return body_rcv.getfile()
+ else:
+ return BytesIO()
+
+ def close(self):
+ body_rcv = self.body_rcv
+ if body_rcv is not None:
+ body_rcv.getbuf().close()
+
+
def split_uri(uri):
    """Split a raw request-target (bytes) into its five components.

    Returns (scheme, netloc, path, query, fragment): the path is
    percent-decoded for WSGI, the rest are native strings.
    """
    scheme = b""
    netloc = b""
    path = b""
    query = b""
    fragment = b""

    # A target beginning with "//" would be treated by urlsplit as a
    # scheme-less netloc, thereby losing the original intent of the
    # request. Here we shamelessly stole 4 lines of code from the CPython
    # stdlib to parse out the fragment and query but leave the path alone.
    # See
    # https://github.com/python/cpython/blob/8c9e9b0cd5b24dfbf1424d1f253d02de80e8f5ef/Lib/urllib/parse.py#L465-L468
    # and https://github.com/Pylons/waitress/issues/260
    if uri.startswith(b"//"):
        path = uri

        if b"#" in path:
            path, fragment = path.split(b"#", 1)

        if b"?" in path:
            path, query = path.split(b"?", 1)
    else:
        try:
            scheme, netloc, path, query, fragment = urlparse.urlsplit(uri)
        except UnicodeError:
            raise ParsingError("Bad URI")

    return (
        tostr(scheme),
        tostr(netloc),
        unquote_bytes_to_wsgi(path),
        tostr(query),
        tostr(fragment),
    )
+
+
def get_header_lines(header):
    """Split a raw header block on CRLF boundaries, folding obs-fold
    continuation lines (leading space/tab) into the previous line."""
    folded = []
    for piece in header.split(b"\r\n"):
        if not piece:
            continue

        if b"\r" in piece or b"\n" in piece:
            raise ParsingError('Bare CR or LF found in header line "%s"' % tostr(piece))

        if piece[:1] in (b" ", b"\t"):
            if not folded:
                # a continuation line with nothing to continue; see
                # https://corte.si/posts/code/pathod/pythonservers/index.html
                raise ParsingError('Malformed header line "%s"' % tostr(piece))
            folded[-1] += piece
        else:
            folded.append(piece)
    return folded
+
+
# Regex for the HTTP request line: "METHOD SP request-target [SP
# HTTP/version]".  Group 1 is the method, group 2 the target (optionally
# absolute-form with scheme://host[:port]), group 5 the version digits
# when an " HTTP/x.y" suffix is present.
first_line_re = re.compile(
    b"([^ ]+) "
    b"((?:[^ :?#]+://[^ ?#/]*(?:[0-9]{1,5})?)?[^ ]+)"
    b"(( HTTP/([0-9.]+))$|$)"
)
+
+
def crack_first_line(line):
    """Parse the request line into ``(method, uri, version)`` bytes.

    Returns ``(b"", b"", b"")`` when the line does not look like an
    HTTP request line at all; raises ParsingError for a non-uppercase
    method.
    """
    match = first_line_re.match(line)
    if match is None or match.end() != len(line):
        return b"", b"", b""

    version = match.group(5) if match.group(3) else b""
    method = match.group(1)

    # the request methods that are currently defined are all uppercase:
    # https://www.iana.org/assignments/http-methods/http-methods.xhtml and
    # the request method is case sensitive according to
    # https://tools.ietf.org/html/rfc7231#section-4.1

    # By disallowing anything but uppercase methods we save poor
    # unsuspecting souls from sending lowercase HTTP methods to waitress
    # and having the request complete, while servers like nginx drop the
    # request onto the floor.
    if method != method.upper():
        raise ParsingError('Malformed HTTP method "%s"' % tostr(method))

    return method, match.group(2), version
diff --git a/libs/waitress/proxy_headers.py b/libs/waitress/proxy_headers.py
new file mode 100644
index 000000000..1df8b8eba
--- /dev/null
+++ b/libs/waitress/proxy_headers.py
@@ -0,0 +1,333 @@
+from collections import namedtuple
+
+from .utilities import logger, undquote, BadRequest
+
+
# WSGI-environ style names (sans the HTTP_ prefix) of every proxy header
# this module knows about; headers in this set that are not explicitly
# trusted are candidates for removal.
PROXY_HEADERS = frozenset(
    {
        "X_FORWARDED_FOR",
        "X_FORWARDED_HOST",
        "X_FORWARDED_PROTO",
        "X_FORWARDED_PORT",
        "X_FORWARDED_BY",
        "FORWARDED",
    }
)

# One parsed element of an RFC 7239 Forwarded header ("for_" avoids the
# Python keyword).
Forwarded = namedtuple("Forwarded", ["by", "for_", "host", "proto"])
+
+
class MalformedProxyHeader(Exception):
    """Raised when a trusted proxy header cannot be parsed.

    Carries the header name, a human-readable reason, and the offending
    raw value so callers can log a useful warning.
    """

    def __init__(self, header, reason, value):
        # Keep the parts around individually for structured logging.
        self.header = header
        self.reason = reason
        self.value = value
        super(MalformedProxyHeader, self).__init__(header, reason, value)
+
+
def proxy_headers_middleware(
    app,
    trusted_proxy=None,
    trusted_proxy_count=1,
    trusted_proxy_headers=None,
    clear_untrusted=True,
    log_untrusted=False,
    logger=logger,
):
    """Wrap *app* so that trusted proxy headers are folded into the environ.

    When the peer is trusted, the configured proxy headers are parsed and
    applied; a 400 response is returned if one of them is malformed.
    Untrusted proxy headers are optionally stripped from the environ.
    """

    def translate_proxy_headers(environ, start_response):
        untrusted = PROXY_HEADERS
        peer = environ["REMOTE_ADDR"]

        if trusted_proxy == "*" or peer == trusted_proxy:
            try:
                untrusted = parse_proxy_headers(
                    environ,
                    trusted_proxy_count=trusted_proxy_count,
                    trusted_proxy_headers=trusted_proxy_headers,
                    logger=logger,
                )
            except MalformedProxyHeader as ex:
                logger.warning(
                    'Malformed proxy header "%s" from "%s": %s value: %s',
                    ex.header,
                    peer,
                    ex.reason,
                    ex.value,
                )
                error = BadRequest('Header "{0}" malformed.'.format(ex.header))
                return error.wsgi_response(environ, start_response)

        if clear_untrusted:
            # Drop whatever proxy headers we did not explicitly trust.
            clear_untrusted_headers(
                environ, untrusted, log_warning=log_untrusted, logger=logger,
            )

        return app(environ, start_response)

    return translate_proxy_headers
+
+
def parse_proxy_headers(
    environ, trusted_proxy_count, trusted_proxy_headers, logger=logger,
):
    """Apply trusted proxy headers to *environ* in place.

    Parses the X-Forwarded-* family and/or the RFC 7239 ``Forwarded``
    header (which takes priority when trusted and present), rewriting
    REMOTE_ADDR/REMOTE_PORT, HTTP_HOST/SERVER_NAME/SERVER_PORT and
    wsgi.url_scheme as appropriate.  Returns the set of proxy header
    names (environ-style, without the HTTP_ prefix) that remain
    untrusted and should be stripped by the caller.

    Raises MalformedProxyHeader when a trusted header cannot be parsed.
    """
    if trusted_proxy_headers is None:
        trusted_proxy_headers = set()

    forwarded_for = []
    forwarded_host = forwarded_proto = forwarded_port = forwarded = ""
    client_addr = None
    # Start by distrusting everything; names are removed as each header
    # is successfully parsed.
    untrusted_headers = set(PROXY_HEADERS)

    def raise_for_multiple_values():
        raise ValueError("Unspecified behavior for multiple values found in header",)

    if "x-forwarded-for" in trusted_proxy_headers and "HTTP_X_FORWARDED_FOR" in environ:
        try:
            forwarded_for = []

            for forward_hop in environ["HTTP_X_FORWARDED_FOR"].split(","):
                forward_hop = forward_hop.strip()
                forward_hop = undquote(forward_hop)

                # Make sure that all IPv6 addresses are surrounded by brackets,
                # this is assuming that the IPv6 representation here does not
                # include a port number.

                if "." not in forward_hop and (
                    ":" in forward_hop and forward_hop[-1] != "]"
                ):
                    forwarded_for.append("[{}]".format(forward_hop))
                else:
                    forwarded_for.append(forward_hop)

            # Only the last trusted_proxy_count hops are believable; the
            # oldest of those is taken as the client address.
            forwarded_for = forwarded_for[-trusted_proxy_count:]
            client_addr = forwarded_for[0]

            untrusted_headers.remove("X_FORWARDED_FOR")
        except Exception as ex:
            raise MalformedProxyHeader(
                "X-Forwarded-For", str(ex), environ["HTTP_X_FORWARDED_FOR"],
            )

    if (
        "x-forwarded-host" in trusted_proxy_headers
        and "HTTP_X_FORWARDED_HOST" in environ
    ):
        try:
            forwarded_host_multiple = []

            for forward_host in environ["HTTP_X_FORWARDED_HOST"].split(","):
                forward_host = forward_host.strip()
                forward_host = undquote(forward_host)
                forwarded_host_multiple.append(forward_host)

            forwarded_host_multiple = forwarded_host_multiple[-trusted_proxy_count:]
            forwarded_host = forwarded_host_multiple[0]

            untrusted_headers.remove("X_FORWARDED_HOST")
        except Exception as ex:
            raise MalformedProxyHeader(
                "X-Forwarded-Host", str(ex), environ["HTTP_X_FORWARDED_HOST"],
            )

    if "x-forwarded-proto" in trusted_proxy_headers:
        try:
            forwarded_proto = undquote(environ.get("HTTP_X_FORWARDED_PROTO", ""))
            if "," in forwarded_proto:
                raise_for_multiple_values()
            untrusted_headers.remove("X_FORWARDED_PROTO")
        except Exception as ex:
            raise MalformedProxyHeader(
                "X-Forwarded-Proto", str(ex), environ["HTTP_X_FORWARDED_PROTO"],
            )

    if "x-forwarded-port" in trusted_proxy_headers:
        try:
            forwarded_port = undquote(environ.get("HTTP_X_FORWARDED_PORT", ""))
            if "," in forwarded_port:
                raise_for_multiple_values()
            untrusted_headers.remove("X_FORWARDED_PORT")
        except Exception as ex:
            raise MalformedProxyHeader(
                "X-Forwarded-Port", str(ex), environ["HTTP_X_FORWARDED_PORT"],
            )

    if "x-forwarded-by" in trusted_proxy_headers:
        # Waitress itself does not use X-Forwarded-By, but we can not
        # remove it so it can get set in the environ
        untrusted_headers.remove("X_FORWARDED_BY")

    if "forwarded" in trusted_proxy_headers:
        forwarded = environ.get("HTTP_FORWARDED", None)
        # Trusting Forwarded resets distrust to everything except it.
        untrusted_headers = PROXY_HEADERS - {"FORWARDED"}

    # If the Forwarded header exists, it gets priority
    if forwarded:
        proxies = []
        try:
            for forwarded_element in forwarded.split(","):
                # Remove whitespace that may have been introduced when
                # appending a new entry
                forwarded_element = forwarded_element.strip()

                forwarded_for = forwarded_host = forwarded_proto = ""
                forwarded_port = forwarded_by = ""

                for pair in forwarded_element.split(";"):
                    pair = pair.lower()

                    if not pair:
                        continue

                    token, equals, value = pair.partition("=")

                    if equals != "=":
                        raise ValueError('Invalid forwarded-pair missing "="')

                    if token.strip() != token:
                        raise ValueError("Token may not be surrounded by whitespace")

                    if value.strip() != value:
                        raise ValueError("Value may not be surrounded by whitespace")

                    if token == "by":
                        forwarded_by = undquote(value)

                    elif token == "for":
                        forwarded_for = undquote(value)

                    elif token == "host":
                        forwarded_host = undquote(value)

                    elif token == "proto":
                        forwarded_proto = undquote(value)

                    else:
                        logger.warning("Unknown Forwarded token: %s" % token)

                proxies.append(
                    Forwarded(
                        forwarded_by, forwarded_for, forwarded_host, forwarded_proto
                    )
                )
        except Exception as ex:
            raise MalformedProxyHeader(
                "Forwarded", str(ex), environ["HTTP_FORWARDED"],
            )

        proxies = proxies[-trusted_proxy_count:]

        # Iterate backwards and fill in some values, the oldest entry that
        # contains the information we expect is the one we use. We expect
        # that intermediate proxies may re-write the host header or proto,
        # but the oldest entry is the one that contains the information the
        # client expects when generating URL's
        #
        # Forwarded: for="[2001:db8::1]";host="example.com:8443";proto="https"
        # Forwarded: for=192.0.2.1;host="example.internal:8080"
        #
        # (After HTTPS header folding) should mean that we use as values:
        #
        # Host: example.com
        # Protocol: https
        # Port: 8443

        for proxy in proxies[::-1]:
            client_addr = proxy.for_ or client_addr
            forwarded_host = proxy.host or forwarded_host
            forwarded_proto = proxy.proto or forwarded_proto

    if forwarded_proto:
        forwarded_proto = forwarded_proto.lower()

        if forwarded_proto not in {"http", "https"}:
            raise MalformedProxyHeader(
                "Forwarded Proto=" if forwarded else "X-Forwarded-Proto",
                "unsupported proto value",
                forwarded_proto,
            )

        # Set the URL scheme to the proxy provided proto
        environ["wsgi.url_scheme"] = forwarded_proto

        if not forwarded_port:
            # No explicit port: fall back to the scheme's default.
            if forwarded_proto == "http":
                forwarded_port = "80"

            if forwarded_proto == "https":
                forwarded_port = "443"

    if forwarded_host:
        # A trailing "]" means a bracketed IPv6 literal without a port.
        if ":" in forwarded_host and forwarded_host[-1] != "]":
            host, port = forwarded_host.rsplit(":", 1)
            host, port = host.strip(), str(port)

            # We trust the port in the Forwarded Host/X-Forwarded-Host over
            # X-Forwarded-Port, or whatever we got from Forwarded
            # Proto/X-Forwarded-Proto.

            if forwarded_port != port:
                forwarded_port = port

            # We trust the proxy server's forwarded Host
            environ["SERVER_NAME"] = host
            environ["HTTP_HOST"] = forwarded_host
        else:
            # We trust the proxy server's forwarded Host
            environ["SERVER_NAME"] = forwarded_host
            environ["HTTP_HOST"] = forwarded_host

            if forwarded_port:
                # Append the port to HTTP_HOST only when it is not the
                # default for the current scheme.
                if forwarded_port not in {"443", "80"}:
                    environ["HTTP_HOST"] = "{}:{}".format(
                        forwarded_host, forwarded_port
                    )
                elif forwarded_port == "80" and environ["wsgi.url_scheme"] != "http":
                    environ["HTTP_HOST"] = "{}:{}".format(
                        forwarded_host, forwarded_port
                    )
                elif forwarded_port == "443" and environ["wsgi.url_scheme"] != "https":
                    environ["HTTP_HOST"] = "{}:{}".format(
                        forwarded_host, forwarded_port
                    )

    if forwarded_port:
        environ["SERVER_PORT"] = str(forwarded_port)

    if client_addr:
        if ":" in client_addr and client_addr[-1] != "]":
            addr, port = client_addr.rsplit(":", 1)
            environ["REMOTE_ADDR"] = strip_brackets(addr.strip())
            environ["REMOTE_PORT"] = port.strip()
        else:
            environ["REMOTE_ADDR"] = strip_brackets(client_addr.strip())
        environ["REMOTE_HOST"] = environ["REMOTE_ADDR"]

    return untrusted_headers
+
+
def strip_brackets(addr):
    """Return *addr* without surrounding IPv6 brackets, if present.

    Uses startswith/endswith rather than addr[0]/addr[-1] so that an
    empty string is returned unchanged instead of raising IndexError.
    """
    if addr.startswith("[") and addr.endswith("]"):
        return addr[1:-1]
    return addr
+
+
def clear_untrusted_headers(
    environ, untrusted_headers, log_warning=False, logger=logger
):
    """Remove the given proxy headers from *environ*.

    Optionally logs one warning listing the headers that were actually
    present and removed, in their canonical dashed form.
    """
    removed = []
    for header in untrusted_headers:
        # pop() with a False default distinguishes "absent" from any value.
        if environ.pop("HTTP_" + header, False) is not False:
            removed.append(header)

    if log_warning and removed:
        friendly = ", ".join(
            "-".join(part.capitalize() for part in header.split("_"))
            for header in removed
        )
        logger.warning(
            "Removed untrusted headers (%s). Waitress recommends these be "
            "removed upstream.",
            friendly,
        )
diff --git a/libs/waitress/receiver.py b/libs/waitress/receiver.py
new file mode 100644
index 000000000..5d1568d51
--- /dev/null
+++ b/libs/waitress/receiver.py
@@ -0,0 +1,186 @@
+##############################################################################
+#
+# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""Data Chunk Receiver
+"""
+
+from waitress.utilities import BadRequest, find_double_newline
+
+
class FixedStreamReceiver(object):
    """Consumes exactly ``cl`` bytes of request body into ``buf``.

    Implements IStreamConsumer: ``received`` reports how many bytes of
    each chunk belong to the body, and ``completed`` flips once the
    declared content length has been consumed.
    """

    completed = False
    error = None

    def __init__(self, cl, buf):
        # cl: declared content length; buf: append()-able body buffer.
        self.remain = cl
        self.buf = buf

    def __len__(self):
        return len(self.buf)

    def received(self, data):
        """See IStreamConsumer; returns the number of bytes consumed."""
        remaining = self.remain
        if remaining < 1:
            # Nothing left to read; mark complete to avoid spinning.
            self.completed = True
            return 0

        datalen = len(data)
        if datalen < remaining:
            # The whole chunk still belongs to the body.
            self.buf.append(data)
            self.remain = remaining - datalen
            return datalen

        # The body ends inside (or exactly at the end of) this chunk.
        self.buf.append(data[:remaining])
        self.remain = 0
        self.completed = True
        return remaining

    def getfile(self):
        return self.buf.getfile()

    def getbuf(self):
        return self.buf
+
+
class ChunkedReceiver(object):
    """Incremental parser for a chunked (Transfer-Encoding: chunked) body.

    State machine driven by ``received``: it alternates between reading a
    hex size control line, the chunk payload, the chunk's trailing CRLF,
    and finally an (optional) trailer section terminated by a blank line.
    Parse problems are recorded in ``self.error`` rather than raised.
    """

    chunk_remainder = 0        # bytes still expected for the current chunk
    validate_chunk_end = False # next bytes must be the chunk's CRLF terminator
    control_line = b""         # partially-received size line
    chunk_end = b""            # partially-received chunk terminator
    all_chunks_received = False
    trailer = b""              # partially-received trailer section
    completed = False
    error = None

    # max_control_line = 1024
    # max_trailer = 65536

    def __init__(self, buf):
        self.buf = buf

    def __len__(self):
        return self.buf.__len__()

    def received(self, s):
        # Returns the number of bytes consumed.

        if self.completed:
            return 0
        orig_size = len(s)

        while s:
            rm = self.chunk_remainder

            if rm > 0:
                # Receive the remainder of a chunk.
                to_write = s[:rm]
                self.buf.append(to_write)
                written = len(to_write)
                s = s[written:]

                self.chunk_remainder -= written

                if self.chunk_remainder == 0:
                    self.validate_chunk_end = True
            elif self.validate_chunk_end:
                # Expect the CRLF that terminates the chunk payload;
                # prepend any terminator bytes saved from a prior call.
                s = self.chunk_end + s

                pos = s.find(b"\r\n")

                if pos < 0 and len(s) < 2:
                    # Not enough data yet to decide; stash and wait.
                    self.chunk_end = s
                    s = b""
                else:
                    self.chunk_end = b""
                    if pos == 0:
                        # Chop off the terminating CR LF from the chunk
                        s = s[2:]
                    else:
                        self.error = BadRequest("Chunk not properly terminated")
                        self.all_chunks_received = True

                    # Always exit this loop
                    self.validate_chunk_end = False
            elif not self.all_chunks_received:
                # Receive a control line.
                s = self.control_line + s
                pos = s.find(b"\r\n")

                if pos < 0:
                    # Control line not finished.
                    self.control_line = s
                    s = b""
                else:
                    # Control line finished.
                    line = s[:pos]
                    s = s[pos + 2 :]
                    self.control_line = b""
                    line = line.strip()

                    if line:
                        # Begin a new chunk.
                        semi = line.find(b";")

                        if semi >= 0:
                            # discard extension info.
                            line = line[:semi]
                        try:
                            sz = int(line.strip(), 16)  # hexadecimal
                        except ValueError:  # garbage in input
                            self.error = BadRequest("garbage in chunked encoding input")
                            sz = 0

                        if sz > 0:
                            # Start a new chunk.
                            self.chunk_remainder = sz
                        else:
                            # Finished chunks.
                            self.all_chunks_received = True
                    # else expect a control line.
            else:
                # Receive the trailer.
                trailer = self.trailer + s

                if trailer.startswith(b"\r\n"):
                    # No trailer.
                    self.completed = True

                    return orig_size - (len(trailer) - 2)
                pos = find_double_newline(trailer)

                if pos < 0:
                    # Trailer not finished.
                    self.trailer = trailer
                    s = b""
                else:
                    # Finished the trailer.
                    self.completed = True
                    self.trailer = trailer[:pos]

                    return orig_size - (len(trailer) - pos)

        return orig_size

    def getfile(self):
        return self.buf.getfile()

    def getbuf(self):
        return self.buf
diff --git a/libs/waitress/rfc7230.py b/libs/waitress/rfc7230.py
new file mode 100644
index 000000000..cd33c9064
--- /dev/null
+++ b/libs/waitress/rfc7230.py
@@ -0,0 +1,52 @@
+"""
+This contains a bunch of RFC7230 definitions and regular expressions that are
+needed to properly parse HTTP messages.
+"""
+
+import re
+
+from .compat import tobytes
+
# Whitespace building blocks from RFC 7230 section 3.2.3:
# OWS = optional whitespace, RWS = required whitespace, BWS = "bad"
# (tolerated) whitespace.  The lazy quantifiers keep surrounding
# captures greedy.
WS = "[ \t]"
OWS = WS + "{0,}?"
RWS = WS + "{1,}?"
BWS = OWS

# RFC 7230 Section 3.2.6 "Field Value Components":
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
#    / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
#    / DIGIT / ALPHA
# obs-text = %x80-FF
TCHAR = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]"
OBS_TEXT = r"\x80-\xff"

TOKEN = TCHAR + "{1,}"

# RFC 5234 Appendix B.1 "Core Rules":
# VCHAR = %x21-7E
#    ; visible (printing) characters
VCHAR = r"\x21-\x7e"

# header-field = field-name ":" OWS field-value OWS
# field-name = token
# field-value = *( field-content / obs-fold )
# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
# field-vchar = VCHAR / obs-text

# Errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
# changes field-content to:
#
# field-content = field-vchar [ 1*( SP / HTAB / field-vchar )
#    field-vchar ]

FIELD_VCHAR = "[" + VCHAR + OBS_TEXT + "]"
# Field content is more greedy than the ABNF, in that it will match the whole value
FIELD_CONTENT = FIELD_VCHAR + "+(?:[ \t]+" + FIELD_VCHAR + "+)*"
# Which allows the field value here to just see if there is even a value in the first place
FIELD_VALUE = "(?:" + FIELD_CONTENT + ")?"

# Compiled matcher for a complete header field line; exposes named groups
# "name" and "value".
HEADER_FIELD = re.compile(
    tobytes(
        "^(?P<name>" + TOKEN + "):" + OWS + "(?P<value>" + FIELD_VALUE + ")" + OWS + "$"
    )
)
diff --git a/libs/waitress/runner.py b/libs/waitress/runner.py
new file mode 100644
index 000000000..2495084f0
--- /dev/null
+++ b/libs/waitress/runner.py
@@ -0,0 +1,286 @@
+##############################################################################
+#
+# Copyright (c) 2013 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""Command line runner.
+"""
+
+from __future__ import print_function, unicode_literals
+
+import getopt
+import os
+import os.path
+import re
+import sys
+
+from waitress import serve
+from waitress.adjustments import Adjustments
+
# Usage text printed by show_help(); {0} is replaced with the program name.
# NOTE: this is a runtime string shown to users, not documentation — do not
# edit casually.
HELP = """\
Usage:

    {0} [OPTS] MODULE:OBJECT

Standard options:

    --help
        Show this information.

    --call
        Call the given object to get the WSGI application.

    --host=ADDR
        Hostname or IP address on which to listen, default is '0.0.0.0',
        which means "all IP addresses on this host".

        Note: May not be used together with --listen

    --port=PORT
        TCP port on which to listen, default is '8080'

        Note: May not be used together with --listen

    --listen=ip:port
        Tell waitress to listen on an ip port combination.

        Example:

            --listen=127.0.0.1:8080
            --listen=[::1]:8080
            --listen=*:8080

        This option may be used multiple times to listen on multiple sockets.
        A wildcard for the hostname is also supported and will bind to both
        IPv4/IPv6 depending on whether they are enabled or disabled.

    --[no-]ipv4
        Toggle on/off IPv4 support.

        Example:

            --no-ipv4

        This will disable IPv4 socket support. This affects wildcard matching
        when generating the list of sockets.

    --[no-]ipv6
        Toggle on/off IPv6 support.

        Example:

            --no-ipv6

        This will turn on IPv6 socket support. This affects wildcard matching
        when generating a list of sockets.

    --unix-socket=PATH
        Path of Unix socket. If a socket path is specified, a Unix domain
        socket is made instead of the usual inet domain socket.

        Not available on Windows.

    --unix-socket-perms=PERMS
        Octal permissions to use for the Unix domain socket, default is
        '600'.

    --url-scheme=STR
        Default wsgi.url_scheme value, default is 'http'.

    --url-prefix=STR
        The ``SCRIPT_NAME`` WSGI environment value.  Setting this to anything
        except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be
        the value passed minus any trailing slashes you add, and it will cause
        the ``PATH_INFO`` of any request which is prefixed with this value to
        be stripped of the prefix.  Default is the empty string.

    --ident=STR
        Server identity used in the 'Server' header in responses. Default
        is 'waitress'.

Tuning options:

    --threads=INT
        Number of threads used to process application logic, default is 4.

    --backlog=INT
        Connection backlog for the server. Default is 1024.

    --recv-bytes=INT
        Number of bytes to request when calling socket.recv(). Default is
        8192.

    --send-bytes=INT
        Number of bytes to send to socket.send(). Default is 18000.
        Multiples of 9000 should avoid partly-filled TCP packets.

    --outbuf-overflow=INT
        A temporary file should be created if the pending output is larger
        than this. Default is 1048576 (1MB).

    --outbuf-high-watermark=INT
        The app_iter will pause when pending output is larger than this value
        and will resume once enough data is written to the socket to fall below
        this threshold. Default is 16777216 (16MB).

    --inbuf-overflow=INT
        A temporary file should be created if the pending input is larger
        than this. Default is 524288 (512KB).

    --connection-limit=INT
        Stop creating new channels if too many are already active.
        Default is 100.

    --cleanup-interval=INT
        Minimum seconds between cleaning up inactive channels. Default
        is 30. See '--channel-timeout'.

    --channel-timeout=INT
        Maximum number of seconds to leave inactive connections open.
        Default is 120. 'Inactive' is defined as 'has received no data
        from the client and has sent no data to the client'.

    --[no-]log-socket-errors
        Toggle whether premature client disconnect tracebacks ought to be
        logged. On by default.

    --max-request-header-size=INT
        Maximum size of all request headers combined. Default is 262144
        (256KB).

    --max-request-body-size=INT
        Maximum size of request body. Default is 1073741824 (1GB).

    --[no-]expose-tracebacks
        Toggle whether to expose tracebacks of unhandled exceptions to the
        client. Off by default.

    --asyncore-loop-timeout=INT
        The timeout value in seconds passed to asyncore.loop(). Default is 1.

    --asyncore-use-poll
        The use_poll argument passed to ``asyncore.loop()``. Helps overcome
        open file descriptors limit. Default is False.

"""

# Matches "dotted.module:dotted.object" (case-insensitive, verbose mode);
# used by match() to validate the MODULE:OBJECT command line argument.
RUNNER_PATTERN = re.compile(
    r"""
    ^
    (?P<module>
        [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)*
    )
    :
    (?P<object>
        [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)*
    )
    $
    """,
    re.I | re.X,
)
+
+
def match(obj_name):
    """Split ``"module:object"`` into its two parts, validating the format.

    Raises ValueError when *obj_name* is not a pair of dotted names
    joined by a colon.
    """
    found = RUNNER_PATTERN.match(obj_name)
    if found is None:
        raise ValueError("Malformed application '{0}'".format(obj_name))
    return found.group("module"), found.group("object")
+
+
def resolve(module_name, object_name):
    """Import *module_name* and walk the dotted *object_name* within it."""
    # Each segment is coerced via str() because some Python 2.6.x releases
    # of ``__import__`` reject unicode entries in ``fromlist``
    # ("TypeError: Item in ``fromlist'' not a string"); casting is harmless
    # elsewhere, so it is kept for compatibility.
    segments = [str(part) for part in object_name.split(".")]
    target = __import__(module_name, fromlist=segments[:1])
    for part in segments:
        target = getattr(target, part)
    return target
+
+
def show_help(stream, name, error=None):  # pragma: no cover
    """Write the usage text, preceded by *error* when given, to *stream*."""
    if error is not None:
        stream.write("Error: {0}\n\n".format(error))
    stream.write(HELP.format(name) + "\n")
+
+
def show_exception(stream):
    """Describe the currently-active exception (type and args) on *stream*."""
    exc_type, exc_value = sys.exc_info()[:2]
    print(
        ("There was an exception ({0}) importing your module.\n").format(
            exc_type.__name__,
        ),
        file=stream,
    )
    args = getattr(exc_value, "args", None)
    if not args:
        print("It had no arguments.", file=stream)
    else:
        print("It had these arguments: ", file=stream)
        for idx, arg in enumerate(args, start=1):
            print("{0}. {1}\n".format(idx, arg), file=stream)
+
+
def run(argv=sys.argv, _serve=serve):
    """Command line entry point.

    Parses options, resolves MODULE:OBJECT into a WSGI app, and serves
    it.  Returns a process exit code: 0 on success, 1 on usage, import,
    or resolution errors.
    """
    prog = os.path.basename(argv[0])

    try:
        opts, positional = Adjustments.parse_args(argv[1:])
    except getopt.GetoptError as exc:
        show_help(sys.stderr, prog, str(exc))
        return 1

    if opts["help"]:
        show_help(sys.stdout, prog)
        return 0

    if len(positional) != 1:
        show_help(sys.stderr, prog, "Specify one application only")
        return 1

    try:
        module, obj_name = match(positional[0])
    except ValueError as exc:
        show_help(sys.stderr, prog, str(exc))
        show_exception(sys.stderr)
        return 1

    # The application module may live in the invocation directory.
    sys.path.append(os.getcwd())

    # Import the module and resolve the WSGI object within it.
    try:
        app = resolve(module, obj_name)
    except ImportError:
        show_help(sys.stderr, prog, "Bad module '{0}'".format(module))
        show_exception(sys.stderr)
        return 1
    except AttributeError:
        show_help(sys.stderr, prog, "Bad object name '{0}'".format(obj_name))
        show_exception(sys.stderr)
        return 1

    if opts["call"]:
        # --call: the resolved object is a factory returning the WSGI app.
        app = app()

    # 'call' and 'help' belong to the runner, not to waitress itself.
    del opts["call"], opts["help"]

    _serve(app, **opts)
    return 0
diff --git a/libs/waitress/server.py b/libs/waitress/server.py
new file mode 100644
index 000000000..ae566994f
--- /dev/null
+++ b/libs/waitress/server.py
@@ -0,0 +1,436 @@
+##############################################################################
+#
+# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+
+import os
+import os.path
+import socket
+import time
+
+from waitress import trigger
+from waitress.adjustments import Adjustments
+from waitress.channel import HTTPChannel
+from waitress.task import ThreadedTaskDispatcher
+from waitress.utilities import cleanup_unix_socket
+
+from waitress.compat import (
+ IPPROTO_IPV6,
+ IPV6_V6ONLY,
+)
+from . import wasyncore
+from .proxy_headers import proxy_headers_middleware
+
+
def create_server(
    application,
    map=None,
    _start=True,  # test shim
    _sock=None,  # test shim
    _dispatcher=None,  # test shim
    **kw  # adjustments
):
    """Build (but do not start looping) a waitress server for *application*.

    Typical use::

        if __name__ == '__main__':
            server = create_server(app)
            server.run()

    Returns a single TcpWSGIServer/UnixWSGIServer when exactly one listen
    socket results from the adjustments, otherwise a MultiSocketServer
    wrapping all of them.  Raises ValueError when *application* is None.
    """
    if application is None:
        raise ValueError(
            'The "app" passed to ``create_server`` was ``None``.  You forgot '
            "to return a WSGI app within your application."
        )
    adj = Adjustments(**kw)

    if map is None:  # pragma: nocover
        map = {}

    # One dispatcher is shared by every listening server created below.
    dispatcher = _dispatcher
    if dispatcher is None:
        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(adj.threads)

    if adj.unix_socket and hasattr(socket, "AF_UNIX"):
        sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None)
        return UnixWSGIServer(
            application,
            map,
            _start,
            _sock,
            dispatcher=dispatcher,
            adj=adj,
            sockinfo=sockinfo,
        )

    effective_listen = []
    last_serv = None
    if not adj.sockets:
        for sockinfo in adj.listen:
            # When TcpWSGIServer is called, it registers itself in the map. This
            # side-effect is all we need it for, so we don't store a reference to
            # or return it to the user.
            last_serv = TcpWSGIServer(
                application,
                map,
                _start,
                _sock,
                dispatcher=dispatcher,
                adj=adj,
                sockinfo=sockinfo,
            )
            effective_listen.append(
                (last_serv.effective_host, last_serv.effective_port)
            )

    # Pre-bound sockets supplied by the caller (adj.sockets) are wrapped
    # without re-binding them (bind_socket=False).
    for sock in adj.sockets:
        sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname())
        if sock.family == socket.AF_INET or sock.family == socket.AF_INET6:
            last_serv = TcpWSGIServer(
                application,
                map,
                _start,
                sock,
                dispatcher=dispatcher,
                adj=adj,
                bind_socket=False,
                sockinfo=sockinfo,
            )
            effective_listen.append(
                (last_serv.effective_host, last_serv.effective_port)
            )
        elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            last_serv = UnixWSGIServer(
                application,
                map,
                _start,
                sock,
                dispatcher=dispatcher,
                adj=adj,
                bind_socket=False,
                sockinfo=sockinfo,
            )
            effective_listen.append(
                (last_serv.effective_host, last_serv.effective_port)
            )

    # We are running a single server, so we can just return the last server,
    # saves us from having to create one more object
    if len(effective_listen) == 1:
        # In this case we have no need to use a MultiSocketServer
        return last_serv

    # Return a class that has a utility function to print out the sockets it's
    # listening on, and has a .run() function. All of the TcpWSGIServers
    # registered themselves in the map above.
    return MultiSocketServer(map, adj, effective_listen, dispatcher)
+
+
+# This class is only ever used if we have multiple listen sockets. It allows
+# the serve() API to call .run() which starts the wasyncore loop, and catches
+# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down.
class MultiSocketServer(object):
    """Coordinator used when listening on more than one socket.

    The individual servers have already registered themselves in *map*;
    this object just drives the wasyncore loop over that shared map,
    catches SystemExit/KeyboardInterrupt so shutdown is clean, and can
    print the effective listen addresses.
    """

    asyncore = wasyncore  # test shim

    def __init__(
        self, map=None, adj=None, effective_listen=None, dispatcher=None,
    ):
        self.adj = adj
        self.map = map
        self.effective_listen = effective_listen
        self.task_dispatcher = dispatcher

    def print_listen(self, format_str):  # pragma: nocover
        for pair in self.effective_listen:
            pair = list(pair)
            if ":" in pair[0]:
                # Bracket IPv6 literals for display.
                pair[0] = "[{}]".format(pair[0])
            print(format_str.format(*pair))

    def run(self):
        try:
            self.asyncore.loop(
                timeout=self.adj.asyncore_loop_timeout,
                map=self.map,
                use_poll=self.adj.asyncore_use_poll,
            )
        except (SystemExit, KeyboardInterrupt):
            self.close()

    def close(self):
        self.task_dispatcher.shutdown()
        wasyncore.close_all(self.map)
+
+
class BaseWSGIServer(wasyncore.dispatcher, object):
    """Common machinery for TCP and Unix-socket WSGI servers.

    Subclasses implement the socket-family specifics
    (``bind_server_socket``/``getsockname``); this base handles channel
    accounting, accepting connections, the run loop and periodic channel
    maintenance.
    """

    channel_class = HTTPChannel
    next_channel_cleanup = 0  # monotonic-ish deadline for maintenance()
    socketmod = socket  # test shim
    asyncore = wasyncore  # test shim

    def __init__(
        self,
        application,
        map=None,
        _start=True,  # test shim
        _sock=None,  # test shim
        dispatcher=None,  # dispatcher
        adj=None,  # adjustments
        sockinfo=None,  # opaque object
        bind_socket=True,
        **kw
    ):
        if adj is None:
            adj = Adjustments(**kw)

        if adj.trusted_proxy or adj.clear_untrusted_proxy_headers:
            # wrap the application to deal with proxy headers
            # we wrap it here because webtest subclasses the TcpWSGIServer
            # directly and thus doesn't run any code that's in create_server
            application = proxy_headers_middleware(
                application,
                trusted_proxy=adj.trusted_proxy,
                trusted_proxy_count=adj.trusted_proxy_count,
                trusted_proxy_headers=adj.trusted_proxy_headers,
                clear_untrusted=adj.clear_untrusted_proxy_headers,
                log_untrusted=adj.log_untrusted_proxy_headers,
                logger=self.logger,
            )

        if map is None:
            # use a nonglobal socket map by default to hopefully prevent
            # conflicts with apps and libs that use the wasyncore global socket
            # map ala https://github.com/Pylons/waitress/issues/63
            map = {}
        if sockinfo is None:
            sockinfo = adj.listen[0]

        self.sockinfo = sockinfo
        self.family = sockinfo[0]
        self.socktype = sockinfo[1]
        self.application = application
        self.adj = adj
        self.trigger = trigger.trigger(map)
        if dispatcher is None:
            dispatcher = ThreadedTaskDispatcher()
            dispatcher.set_thread_count(self.adj.threads)

        self.task_dispatcher = dispatcher
        self.asyncore.dispatcher.__init__(self, _sock, map=map)
        if _sock is None:
            self.create_socket(self.family, self.socktype)
            if self.family == socket.AF_INET6:  # pragma: nocover
                # Keep the IPv6 socket IPv6-only; IPv4 gets its own socket.
                self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)

        self.set_reuse_addr()

        if bind_socket:
            self.bind_server_socket()

        self.effective_host, self.effective_port = self.getsockname()
        self.server_name = self.get_server_name(self.effective_host)
        self.active_channels = {}
        if _start:
            self.accept_connections()

    def bind_server_socket(self):
        raise NotImplementedError  # pragma: no cover

    def get_server_name(self, ip):
        """Given an IP or hostname, try to determine the server name."""

        if not ip:
            raise ValueError("Requires an IP to get the server name")

        server_name = str(ip)

        # If we are bound to all IP's, just return the current hostname, only
        # fall-back to "localhost" if we fail to get the hostname
        if server_name == "0.0.0.0" or server_name == "::":
            try:
                return str(self.socketmod.gethostname())
            except (socket.error, UnicodeDecodeError):  # pragma: no cover
                # We also deal with UnicodeDecodeError in case of Windows with
                # non-ascii hostname
                return "localhost"

        # Now let's try and convert the IP address to a proper hostname
        try:
            server_name = self.socketmod.gethostbyaddr(server_name)[0]
        except (socket.error, UnicodeDecodeError):  # pragma: no cover
            # We also deal with UnicodeDecodeError in case of Windows with
            # non-ascii hostname
            pass

        # If it contains an IPv6 literal, make sure to surround it with
        # brackets
        if ":" in server_name and "[" not in server_name:
            server_name = "[{}]".format(server_name)

        return server_name

    def getsockname(self):
        raise NotImplementedError  # pragma: no cover

    def accept_connections(self):
        self.accepting = True
        self.socket.listen(self.adj.backlog)  # Get around asyncore NT limit

    def add_task(self, task):
        # Hand work off to the shared thread dispatcher.
        self.task_dispatcher.add_task(task)

    def readable(self):
        # Piggy-back periodic channel maintenance on readiness polling.
        now = time.time()
        if now >= self.next_channel_cleanup:
            self.next_channel_cleanup = now + self.adj.cleanup_interval
            self.maintenance(now)
        # Stop accepting when at the configured connection limit.
        return self.accepting and len(self._map) < self.adj.connection_limit

    def writable(self):
        return False

    def handle_read(self):
        pass

    def handle_connect(self):
        pass

    def handle_accept(self):
        try:
            v = self.accept()
            if v is None:
                return
            conn, addr = v
        except socket.error:
            # Linux: On rare occasions we get a bogus socket back from
            # accept.  socketmodule.c:makesockaddr complains that the
            # address family is unknown.  We don't want the whole server
            # to shut down because of this.
            if self.adj.log_socket_errors:
                self.logger.warning("server accept() threw an exception", exc_info=True)
            return
        self.set_socket_options(conn)
        addr = self.fix_addr(addr)
        # The channel registers itself in self._map as a side effect.
        self.channel_class(self, conn, addr, self.adj, map=self._map)

    def run(self):
        try:
            self.asyncore.loop(
                timeout=self.adj.asyncore_loop_timeout,
                map=self._map,
                use_poll=self.adj.asyncore_use_poll,
            )
        except (SystemExit, KeyboardInterrupt):
            self.task_dispatcher.shutdown()

    def pull_trigger(self):
        self.trigger.pull_trigger()

    def set_socket_options(self, conn):
        pass

    def fix_addr(self, addr):
        return addr

    def maintenance(self, now):
        """
        Closes channels that have not had any activity in a while.

        The timeout is configured through adj.channel_timeout (seconds).
        """
        cutoff = now - self.adj.channel_timeout
        for channel in self.active_channels.values():
            if (not channel.requests) and channel.last_activity < cutoff:
                channel.will_close = True

    def print_listen(self, format_str):  # pragma: nocover
        print(format_str.format(self.effective_host, self.effective_port))

    def close(self):
        self.trigger.close()
        return wasyncore.dispatcher.close(self)
+
+
class TcpWSGIServer(BaseWSGIServer):
    """WSGI server bound to an AF_INET / AF_INET6 socket."""

    def bind_server_socket(self):
        (_, _, _, sockaddr) = self.sockinfo
        self.bind(sockaddr)

    def getsockname(self):
        """Return (host, service) for the bound socket, resolving the host
        name when possible and falling back to numeric forms otherwise."""
        try:
            return self.socketmod.getnameinfo(
                self.socket.getsockname(), self.socketmod.NI_NUMERICSERV
            )
        except Exception:  # pragma: no cover
            # This only happens on Linux because a DNS issue is considered a
            # temporary failure that will raise (even when NI_NAMEREQD is not
            # set).  Instead we try again, but this time we just ask for the
            # numerichost and the numericserv (port) and return those.  It is
            # better than nothing.
            #
            # Narrowed from a bare ``except:`` so that SystemExit and
            # KeyboardInterrupt are never swallowed here.
            return self.socketmod.getnameinfo(
                self.socket.getsockname(),
                self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV,
            )

    def set_socket_options(self, conn):
        # Apply the adjustments-configured options to each accepted socket.
        for (level, optname, value) in self.adj.socket_options:
            conn.setsockopt(level, optname, value)
+
+
# Unix domain sockets are unavailable on some platforms (notably Windows),
# so the Unix server class only exists when the socket module supports them.
if hasattr(socket, "AF_UNIX"):

    class UnixWSGIServer(BaseWSGIServer):
        """WSGI server bound to an AF_UNIX (filesystem) socket."""

        def __init__(
            self,
            application,
            map=None,
            _start=True,  # test shim
            _sock=None,  # test shim
            dispatcher=None,  # dispatcher
            adj=None,  # adjustments
            sockinfo=None,  # opaque object
            **kw
        ):
            if sockinfo is None:
                sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None)

            super(UnixWSGIServer, self).__init__(
                application,
                map=map,
                _start=_start,
                _sock=_sock,
                dispatcher=dispatcher,
                adj=adj,
                sockinfo=sockinfo,
                **kw
            )

        def bind_server_socket(self):
            # Remove any stale socket file left by a previous run, then
            # bind and apply the configured permissions.
            cleanup_unix_socket(self.adj.unix_socket)
            self.bind(self.adj.unix_socket)
            if os.path.exists(self.adj.unix_socket):
                os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms)

        def getsockname(self):
            # Unix sockets have no host/port; report a fixed pseudo-host.
            return ("unix", self.socket.getsockname())

        def fix_addr(self, addr):
            return ("localhost", None)

        def get_server_name(self, ip):
            return "localhost"


# Compatibility alias.
WSGIServer = TcpWSGIServer
diff --git a/libs/waitress/task.py b/libs/waitress/task.py
new file mode 100644
index 000000000..8e7ab1888
--- /dev/null
+++ b/libs/waitress/task.py
@@ -0,0 +1,570 @@
+##############################################################################
+#
+# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+
+import socket
+import sys
+import threading
+import time
+from collections import deque
+
+from .buffers import ReadOnlyFileBasedBuffer
+from .compat import reraise, tobytes
+from .utilities import build_http_date, logger, queue_logger
+
# environ keys that keep their CGI names as-is (no HTTP_ prefix added)
rename_headers = {  # or keep them without the HTTP_ prefix added
    "CONTENT_LENGTH": "CONTENT_LENGTH",
    "CONTENT_TYPE": "CONTENT_TYPE",
}

# Hop-by-hop header names that a WSGI application must not set
# (PEP 3333); rejected in WSGITask's start_response.
hop_by_hop = frozenset(
    (
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
    )
)
+
+
class ThreadedTaskDispatcher(object):
    """A Task Dispatcher that creates a thread for each task.

    One lock protects both the task queue and the thread bookkeeping;
    ``queue_cv`` wakes idle workers when a task (or stop request) is
    available, and ``thread_exit_cv`` tells ``shutdown`` that a worker
    has exited.
    """

    stop_count = 0  # Number of threads that will stop soon.
    active_count = 0  # Number of currently active threads
    logger = logger
    queue_logger = queue_logger  # separate logger for queue-depth warnings

    def __init__(self):
        self.threads = set()  # thread numbers expected to be running
        self.queue = deque()
        self.lock = threading.Lock()
        self.queue_cv = threading.Condition(self.lock)
        self.thread_exit_cv = threading.Condition(self.lock)

    def start_new_thread(self, target, args):
        # daemon=True so a lingering worker never blocks interpreter exit
        t = threading.Thread(target=target, name="waitress", args=args)
        t.daemon = True
        t.start()

    def handler_thread(self, thread_no):
        """Worker loop: pop tasks and service them until told to stop."""
        while True:
            with self.lock:
                while not self.queue and self.stop_count == 0:
                    # Mark ourselves as idle before waiting to be
                    # woken up, then we will once again be active
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                if self.stop_count > 0:
                    self.active_count -= 1
                    self.stop_count -= 1
                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    break

                task = self.queue.popleft()
            # task.service() runs outside the lock so other workers proceed
            try:
                task.service()
            except BaseException:
                self.logger.exception("Exception when servicing %r", task)

    def set_thread_count(self, count):
        """Grow or shrink the worker pool to *count* threads."""
        with self.lock:
            threads = self.threads
            thread_no = 0
            running = len(threads) - self.stop_count
            while running < count:
                # Start threads.
                while thread_no in threads:
                    thread_no = thread_no + 1
                threads.add(thread_no)
                running += 1
                self.start_new_thread(self.handler_thread, (thread_no,))
                self.active_count += 1
                thread_no = thread_no + 1
            if running > count:
                # Stop threads.
                self.stop_count += running - count
            self.queue_cv.notify_all()

    def add_task(self, task):
        """Queue *task* for a worker; warn when the queue is backing up."""
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
            queue_size = len(self.queue)
            idle_threads = len(self.threads) - self.stop_count - self.active_count
            if queue_size > idle_threads:
                self.queue_logger.warning(
                    "Task queue depth is %d", queue_size - idle_threads
                )

    def shutdown(self, cancel_pending=True, timeout=5):
        """Stop all workers, waiting up to *timeout* seconds.

        Returns True when pending tasks were cancelled, False otherwise.
        """
        self.set_thread_count(0)
        # Ensure the threads shut down.
        threads = self.threads
        expiration = time.time() + timeout
        with self.lock:
            while threads:
                if time.time() >= expiration:
                    self.logger.warning("%d thread(s) still running", len(threads))
                    break
                self.thread_exit_cv.wait(0.1)
            if cancel_pending:
                # Cancel remaining tasks.
                queue = self.queue
                if len(queue) > 0:
                    self.logger.warning("Canceling %d pending task(s)", len(queue))
                while queue:
                    task = queue.popleft()
                    task.cancel()
                self.queue_cv.notify_all()
                return True
        return False
+
+
class Task(object):
    """Base class for one HTTP request/response cycle on a channel.

    Subclasses provide ``execute``; ``service`` drives the
    start/execute/finish sequence while ``write`` handles header
    generation, Content-Length accounting and chunked encoding.
    """

    close_on_finish = False  # close the connection after this response?
    status = "200 OK"
    wrote_header = False  # True once the response header has been queued
    start_time = 0
    content_length = None
    content_bytes_written = 0
    logged_write_excess = False  # warn only once about excess body bytes
    logged_write_no_body = False  # warn only once about an ignored body
    complete = False  # True once start_response has been called
    chunked_response = False  # using Transfer-Encoding: chunked?
    logger = logger

    def __init__(self, channel, request):
        self.channel = channel
        self.request = request
        self.response_headers = []
        version = request.version
        if version not in ("1.0", "1.1"):
            # fall back to a version we support.
            version = "1.0"
        self.version = version

    def service(self):
        """Run the task; on socket errors mark the connection for close
        and re-raise only when log_socket_errors is configured."""
        try:
            try:
                self.start()
                self.execute()
                self.finish()
            except socket.error:
                self.close_on_finish = True
                if self.channel.adj.log_socket_errors:
                    raise
        finally:
            pass

    @property
    def has_body(self):
        # 1xx, 204 and 304 responses must not carry a message body.
        return not (
            self.status.startswith("1")
            or self.status.startswith("204")
            or self.status.startswith("304")
        )

    def build_response_header(self):
        """Return the full response header block as bytes.

        Normalizes header capitalization, decides keep-alive vs close,
        supplies Content-Length / Transfer-Encoding, and fills in the
        Server and Date headers when missing.
        """
        version = self.version
        # Figure out whether the connection should be closed.
        connection = self.request.headers.get("CONNECTION", "").lower()
        response_headers = []
        content_length_header = None
        date_header = None
        server_header = None
        connection_close_header = None

        for (headername, headerval) in self.response_headers:
            headername = "-".join([x.capitalize() for x in headername.split("-")])

            if headername == "Content-Length":
                if self.has_body:
                    content_length_header = headerval
                else:
                    continue  # pragma: no cover

            if headername == "Date":
                date_header = headerval

            if headername == "Server":
                server_header = headerval

            if headername == "Connection":
                connection_close_header = headerval.lower()
            # replace with properly capitalized version
            response_headers.append((headername, headerval))

        if (
            content_length_header is None
            and self.content_length is not None
            and self.has_body
        ):
            content_length_header = str(self.content_length)
            response_headers.append(("Content-Length", content_length_header))

        def close_on_finish():
            # local helper; its name deliberately mirrors the
            # close_on_finish attribute it sets
            if connection_close_header is None:
                response_headers.append(("Connection", "close"))
            self.close_on_finish = True

        if version == "1.0":
            if connection == "keep-alive":
                if not content_length_header:
                    close_on_finish()
                else:
                    response_headers.append(("Connection", "Keep-Alive"))
            else:
                close_on_finish()

        elif version == "1.1":
            if connection == "close":
                close_on_finish()

            if not content_length_header:
                # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
                # for any response with a status code of 1xx, 204 or 304.

                if self.has_body:
                    response_headers.append(("Transfer-Encoding", "chunked"))
                    self.chunked_response = True

                if not self.close_on_finish:
                    close_on_finish()

            # under HTTP 1.1 keep-alive is default, no need to set the header
        else:
            raise AssertionError("neither HTTP/1.0 or HTTP/1.1")

        # Set the Server and Date field, if not yet specified. This is needed
        # if the server is used as a proxy.
        ident = self.channel.server.adj.ident

        if not server_header:
            if ident:
                response_headers.append(("Server", ident))
            else:
                response_headers.append(("Via", ident or "waitress"))

        if not date_header:
            response_headers.append(("Date", build_http_date(self.start_time)))

        self.response_headers = response_headers

        first_line = "HTTP/%s %s" % (self.version, self.status)
        # NB: sorting headers needs to preserve same-named-header order
        # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
        # rely on stable sort to keep relative position of same-named headers
        next_lines = [
            "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0])
        ]
        lines = [first_line] + next_lines
        res = "%s\r\n\r\n" % "\r\n".join(lines)

        return tobytes(res)

    def remove_content_length_header(self):
        """Drop any Content-Length entry from the pending headers."""
        response_headers = []

        for header_name, header_value in self.response_headers:
            if header_name.lower() == "content-length":
                continue  # pragma: nocover
            response_headers.append((header_name, header_value))

        self.response_headers = response_headers

    def start(self):
        # timestamp used later for the Date response header
        self.start_time = time.time()

    def finish(self):
        """Flush headers if needed and terminate a chunked response."""
        if not self.wrote_header:
            self.write(b"")
        if self.chunked_response:
            # not self.write, it will chunk it!
            self.channel.write_soon(b"0\r\n\r\n")

    def write(self, data):
        """Queue response bytes, emitting headers on the first call.

        Raises RuntimeError when called before start_response completed.
        """
        if not self.complete:
            raise RuntimeError("start_response was not called before body written")
        channel = self.channel
        if not self.wrote_header:
            rh = self.build_response_header()
            channel.write_soon(rh)
            self.wrote_header = True

        if data and self.has_body:
            towrite = data
            cl = self.content_length
            if self.chunked_response:
                # use chunked encoding response
                towrite = tobytes(hex(len(data))[2:].upper()) + b"\r\n"
                towrite += data + b"\r\n"
            elif cl is not None:
                # truncate at the advertised Content-Length
                towrite = data[: cl - self.content_bytes_written]
                self.content_bytes_written += len(towrite)
                if towrite != data and not self.logged_write_excess:
                    self.logger.warning(
                        "application-written content exceeded the number of "
                        "bytes specified by Content-Length header (%s)" % cl
                    )
                    self.logged_write_excess = True
            if towrite:
                channel.write_soon(towrite)
        elif data:
            # Cheat, and tell the application we have written all of the bytes,
            # even though the response shouldn't have a body and we are
            # ignoring it entirely.
            self.content_bytes_written += len(data)

            if not self.logged_write_no_body:
                self.logger.warning(
                    "application-written content was ignored due to HTTP "
                    "response that may not contain a message-body: (%s)" % self.status
                )
                self.logged_write_no_body = True
+
+
class ErrorTask(Task):
    """Task that renders a request/parser error as an HTTP response."""

    # No WSGI application runs for an error response, so the task is
    # considered complete from the outset.
    complete = True

    def execute(self):
        error = self.request.error
        status, headers, body = error.to_response()
        self.status = status
        self.response_headers.extend(headers)
        # The connection is about to be dropped (close_on_finish below),
        # so say so explicitly rather than leaving the client waiting on
        # a keep-alive that will never come.
        self.response_headers.append(("Connection", "close"))
        self.close_on_finish = True
        self.content_length = len(body)
        self.write(tobytes(body))
+
+
class WSGITask(Task):
    """A WSGI task produces a response from a WSGI application.
    """

    environ = None  # cached WSGI environ dict, built lazily

    def execute(self):
        """Call the WSGI application and stream its response to the channel."""
        environ = self.get_environment()

        def start_response(status, headers, exc_info=None):
            # The PEP 3333 start_response callable handed to the app.
            if self.complete and not exc_info:
                raise AssertionError(
                    "start_response called a second time without providing exc_info."
                )
            if exc_info:
                try:
                    if self.wrote_header:
                        # higher levels will catch and handle raised exception:
                        # 1. "service" method in task.py
                        # 2. "service" method in channel.py
                        # 3. "handler_thread" method in task.py
                        reraise(exc_info[0], exc_info[1], exc_info[2])
                    else:
                        # As per WSGI spec existing headers must be cleared
                        self.response_headers = []
                finally:
                    exc_info = None

            self.complete = True

            if not status.__class__ is str:
                raise AssertionError("status %s is not a string" % status)
            if "\n" in status or "\r" in status:
                raise ValueError(
                    "carriage return/line feed character present in status"
                )

            self.status = status

            # Prepare the headers for output
            for k, v in headers:
                if not k.__class__ is str:
                    raise AssertionError(
                        "Header name %r is not a string in %r" % (k, (k, v))
                    )
                if not v.__class__ is str:
                    raise AssertionError(
                        "Header value %r is not a string in %r" % (v, (k, v))
                    )

                if "\n" in v or "\r" in v:
                    raise ValueError(
                        "carriage return/line feed character present in header value"
                    )
                if "\n" in k or "\r" in k:
                    raise ValueError(
                        "carriage return/line feed character present in header name"
                    )

                kl = k.lower()
                if kl == "content-length":
                    self.content_length = int(v)
                elif kl in hop_by_hop:
                    raise AssertionError(
                        '%s is a "hop-by-hop" header; it cannot be used by '
                        "a WSGI application (see PEP 3333)" % k
                    )

            self.response_headers.extend(headers)

            # Return a method used to write the response data.
            return self.write

        # Call the application to handle the request and write a response
        app_iter = self.channel.server.application(environ, start_response)

        can_close_app_iter = True
        try:
            # wsgi.file_wrapper fast path: hand the buffer to the channel
            if app_iter.__class__ is ReadOnlyFileBasedBuffer:
                cl = self.content_length
                size = app_iter.prepare(cl)
                if size:
                    if cl != size:
                        if cl is not None:
                            self.remove_content_length_header()
                        self.content_length = size
                    self.write(b"")  # generate headers
                    # if the write_soon below succeeds then the channel will
                    # take over closing the underlying file via the channel's
                    # _flush_some or handle_close so we intentionally avoid
                    # calling close in the finally block
                    self.channel.write_soon(app_iter)
                    can_close_app_iter = False
                    return

            first_chunk_len = None
            for chunk in app_iter:
                if first_chunk_len is None:
                    first_chunk_len = len(chunk)
                    # Set a Content-Length header if one is not supplied.
                    # start_response may not have been called until first
                    # iteration as per PEP, so we must reinterrogate
                    # self.content_length here
                    if self.content_length is None:
                        app_iter_len = None
                        if hasattr(app_iter, "__len__"):
                            app_iter_len = len(app_iter)
                        if app_iter_len == 1:
                            self.content_length = first_chunk_len
                # transmit headers only after first iteration of the iterable
                # that returns a non-empty bytestring (PEP 3333)
                if chunk:
                    self.write(chunk)

            cl = self.content_length
            if cl is not None:
                if self.content_bytes_written != cl:
                    # close the connection so the client isn't sitting around
                    # waiting for more data when there are too few bytes
                    # to service content-length
                    self.close_on_finish = True
                    if self.request.command != "HEAD":
                        self.logger.warning(
                            "application returned too few bytes (%s) "
                            "for specified Content-Length (%s) via app_iter"
                            % (self.content_bytes_written, cl),
                        )
        finally:
            if can_close_app_iter and hasattr(app_iter, "close"):
                app_iter.close()

    def get_environment(self):
        """Returns a WSGI environment."""
        environ = self.environ
        if environ is not None:
            # Return the cached copy.
            return environ

        request = self.request
        path = request.path
        channel = self.channel
        server = channel.server
        url_prefix = server.adj.url_prefix

        if path.startswith("/"):
            # strip extra slashes at the beginning of a path that starts
            # with any number of slashes
            path = "/" + path.lstrip("/")

        if url_prefix:
            # NB: url_prefix is guaranteed by the configuration machinery to
            # be either the empty string or a string that starts with a single
            # slash and ends without any slashes
            if path == url_prefix:
                # if the path is the same as the url prefix, the SCRIPT_NAME
                # should be the url_prefix and PATH_INFO should be empty
                path = ""
            else:
                # if the path starts with the url prefix plus a slash,
                # the SCRIPT_NAME should be the url_prefix and PATH_INFO should
                # the value of path from the slash until its end
                url_prefix_with_trailing_slash = url_prefix + "/"
                if path.startswith(url_prefix_with_trailing_slash):
                    path = path[len(url_prefix) :]

        environ = {
            "REMOTE_ADDR": channel.addr[0],
            # Nah, we aren't actually going to look up the reverse DNS for
            # REMOTE_ADDR, but we will happily set this environment variable
            # for the WSGI application. Spec says we can just set this to
            # REMOTE_ADDR, so we do.
            "REMOTE_HOST": channel.addr[0],
            # try and set the REMOTE_PORT to something useful, but maybe None
            "REMOTE_PORT": str(channel.addr[1]),
            "REQUEST_METHOD": request.command.upper(),
            "SERVER_PORT": str(server.effective_port),
            "SERVER_NAME": server.server_name,
            "SERVER_SOFTWARE": server.adj.ident,
            "SERVER_PROTOCOL": "HTTP/%s" % self.version,
            "SCRIPT_NAME": url_prefix,
            "PATH_INFO": path,
            "QUERY_STRING": request.query,
            "wsgi.url_scheme": request.url_scheme,
            # the following environment variables are required by the WSGI spec
            "wsgi.version": (1, 0),
            # apps should use the logging module
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "wsgi.input": request.get_body_stream(),
            "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
            "wsgi.input_terminated": True,  # wsgi.input is EOF terminated
        }

        for key, value in dict(request.headers).items():
            value = value.strip()
            mykey = rename_headers.get(key, None)
            if mykey is None:
                mykey = "HTTP_" + key
            if mykey not in environ:
                environ[mykey] = value

        # cache the environ for this request
        self.environ = environ
        return environ
diff --git a/libs/waitress/tests/__init__.py b/libs/waitress/tests/__init__.py
new file mode 100644
index 000000000..b711d3609
--- /dev/null
+++ b/libs/waitress/tests/__init__.py
@@ -0,0 +1,2 @@
+#
+# This file is necessary to make this directory a package.
diff --git a/libs/waitress/tests/fixtureapps/__init__.py b/libs/waitress/tests/fixtureapps/__init__.py
new file mode 100644
index 000000000..f215a2b90
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/__init__.py
@@ -0,0 +1 @@
+# package (for -m)
diff --git a/libs/waitress/tests/fixtureapps/badcl.py b/libs/waitress/tests/fixtureapps/badcl.py
new file mode 100644
index 000000000..24067de41
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/badcl.py
@@ -0,0 +1,11 @@
def app(environ, start_response):  # pragma: no cover
    """Fixture app that can advertise a deliberately wrong Content-Length."""
    body = b"abcdefghi"
    actual = len(body)
    path = environ["PATH_INFO"]
    if path == "/short_body":
        advertised = actual + 1  # claims one byte more than is sent
    elif path == "/long_body":
        advertised = actual - 1  # claims one byte less than is sent
    else:
        advertised = actual
    start_response(
        "200 OK",
        [("Content-Length", str(advertised)), ("Content-Type", "text/plain")],
    )
    return [body]
diff --git a/libs/waitress/tests/fixtureapps/echo.py b/libs/waitress/tests/fixtureapps/echo.py
new file mode 100644
index 000000000..813bdacea
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/echo.py
@@ -0,0 +1,56 @@
+from collections import namedtuple
+import json
+
+
def app_body_only(environ, start_response):  # pragma: no cover
    """Echo the raw request body back as a text/plain response."""
    length = environ.get("CONTENT_LENGTH")
    if length is not None:
        length = int(length)
    payload = environ["wsgi.input"].read(length)
    start_response(
        "200 OK",
        [("Content-Length", str(len(payload))), ("Content-Type", "text/plain")],
    )
    return [payload]
+
+
def app(environ, start_response):  # pragma: no cover
    """Echo request metadata as JSON, a blank line, then the raw body."""
    length = environ.get("CONTENT_LENGTH")
    if length is not None:
        length = int(length)
    request_body = environ["wsgi.input"].read(length)
    length = str(len(request_body))
    meta = {
        "method": environ["REQUEST_METHOD"],
        "path_info": environ["PATH_INFO"],
        "script_name": environ["SCRIPT_NAME"],
        "query_string": environ["QUERY_STRING"],
        "content_length": length,
        "scheme": environ["wsgi.url_scheme"],
        "remote_addr": environ["REMOTE_ADDR"],
        "remote_host": environ["REMOTE_HOST"],
        "server_port": environ["SERVER_PORT"],
        "server_name": environ["SERVER_NAME"],
        # request headers with the HTTP_ prefix stripped
        "headers": {
            key[5:]: val for key, val in environ.items() if key.startswith("HTTP_")
        },
    }
    payload = json.dumps(meta).encode("utf8") + b"\r\n\r\n" + request_body
    start_response(
        "200 OK",
        [("Content-Length", str(len(payload))), ("Content-Type", "text/plain")],
    )
    return [payload]
+
+
# Structured view of what the echo fixture app reports back.
Echo = namedtuple(
    "Echo",
    [
        "method",
        "path_info",
        "script_name",
        "query_string",
        "content_length",
        "scheme",
        "remote_addr",
        "remote_host",
        "server_port",
        "server_name",
        "headers",
        "body",
    ],
)


def parse_response(response):
    """Split an echo-app response into its JSON metadata plus raw body."""
    meta_blob, raw_body = response.split(b"\r\n\r\n", 1)
    fields = json.loads(meta_blob.decode("utf8"))
    fields["body"] = raw_body
    return Echo(**fields)
diff --git a/libs/waitress/tests/fixtureapps/error.py b/libs/waitress/tests/fixtureapps/error.py
new file mode 100644
index 000000000..5afb1c542
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/error.py
@@ -0,0 +1,21 @@
def app(environ, start_response):  # pragma: no cover
    """Fixture app that fails in a different way depending on the path."""
    length = environ.get("CONTENT_LENGTH")
    if length is not None:
        length = int(length)
    payload = environ["wsgi.input"].read(length)
    path = environ["PATH_INFO"]
    if path == "/before_start_response":
        # Blow up before the response has started at all.
        raise ValueError("wrong")
    write = start_response(
        "200 OK",
        [("Content-Length", str(len(payload))), ("Content-Type", "text/plain")],
    )
    if path == "/after_write_cb":
        # Fail only after some data has gone through the write callback.
        write("abc")
    if path == "/in_generator":

        def faulty():
            yield "abc"
            raise ValueError

        return faulty()
    raise ValueError("wrong")
diff --git a/libs/waitress/tests/fixtureapps/filewrapper.py b/libs/waitress/tests/fixtureapps/filewrapper.py
new file mode 100644
index 000000000..63df5a6dc
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/filewrapper.py
@@ -0,0 +1,93 @@
+import io
+import os
+
+here = os.path.dirname(os.path.abspath(__file__))
+fn = os.path.join(here, "groundhog1.jpg")
+
+
class KindaFilelike(object):  # pragma: no cover
    """Minimal read-only file-like object backed by an in-memory bytestring."""

    def __init__(self, bytes):
        self.bytes = bytes

    def read(self, n):
        # Hand out the first n bytes and keep the remainder for next time.
        chunk, self.bytes = self.bytes[:n], self.bytes[n:]
        return chunk
+
+
class UnseekableIOBase(io.RawIOBase):  # pragma: no cover
    """RawIOBase wrapper around a BytesIO that denies seekability."""

    def __init__(self, bytes):
        self.buf = io.BytesIO(bytes)

    def readable(self):
        return True

    def writable(self):
        return False

    def seekable(self):
        # The whole point of this fixture: pretend seeking is impossible.
        return False

    def read(self, n):
        return self.buf.read(n)
+
+
def app(environ, start_response):  # pragma: no cover
    """Serve groundhog1.jpg through several wsgi.file_wrapper variants.

    /filelike* paths hand a real (seekable) file object to the wrapper;
    /notfilelike* paths use in-memory stand-ins. The *_shortcl/*_longcl
    variants deliberately advertise an incorrect Content-Length.
    """
    path_info = environ["PATH_INFO"]
    if path_info.startswith("/filelike"):
        f = open(fn, "rb")
        # determine the file size by seeking to the end
        f.seek(0, 2)
        cl = f.tell()
        f.seek(0)
        if path_info == "/filelike":
            headers = [
                ("Content-Length", str(cl)),
                ("Content-Type", "image/jpeg"),
            ]
        elif path_info == "/filelike_nocl":
            headers = [("Content-Type", "image/jpeg")]
        elif path_info == "/filelike_shortcl":
            # short content length
            headers = [
                ("Content-Length", "1"),
                ("Content-Type", "image/jpeg"),
            ]
        else:
            # long content length (/filelike_longcl)
            headers = [
                ("Content-Length", str(cl + 10)),
                ("Content-Type", "image/jpeg"),
            ]
    else:
        with open(fn, "rb") as fp:
            data = fp.read()
        cl = len(data)
        f = KindaFilelike(data)
        if path_info == "/notfilelike":
            headers = [
                ("Content-Length", str(len(data))),
                ("Content-Type", "image/jpeg"),
            ]
        elif path_info == "/notfilelike_iobase":
            headers = [
                ("Content-Length", str(len(data))),
                ("Content-Type", "image/jpeg"),
            ]
            f = UnseekableIOBase(data)
        elif path_info == "/notfilelike_nocl":
            headers = [("Content-Type", "image/jpeg")]
        elif path_info == "/notfilelike_shortcl":
            # short content length
            headers = [
                ("Content-Length", "1"),
                ("Content-Type", "image/jpeg"),
            ]
        else:
            # long content length (/notfilelike_longcl)
            headers = [
                ("Content-Length", str(cl + 10)),
                ("Content-Type", "image/jpeg"),
            ]

    start_response("200 OK", headers)
    return environ["wsgi.file_wrapper"](f, 8192)
diff --git a/libs/waitress/tests/fixtureapps/getline.py b/libs/waitress/tests/fixtureapps/getline.py
new file mode 100644
index 000000000..5e0ad3ae5
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/getline.py
@@ -0,0 +1,17 @@
+import sys
+
if __name__ == "__main__":
    # Tiny HTTP client run as a subprocess by functional tests: fetch the
    # URL given as argv[1] and write the first response line to stdout.
    try:
        from urllib.request import urlopen, URLError
    except ImportError:
        # Python 2 fallback; urllib2 exposes the same names.
        from urllib2 import urlopen, URLError

    url = sys.argv[1]
    # (An unused ``headers`` dict was removed here; urlopen was never
    # passed any custom headers.)
    try:
        resp = urlopen(url)
        line = resp.readline().decode("ascii")  # py3
    except URLError:
        line = "failed to read %s" % url
    sys.stdout.write(line)
    sys.stdout.flush()
diff --git a/libs/waitress/tests/fixtureapps/groundhog1.jpg b/libs/waitress/tests/fixtureapps/groundhog1.jpg
new file mode 100644
index 000000000..90f610ea0
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/groundhog1.jpg
Binary files differ
diff --git a/libs/waitress/tests/fixtureapps/nocl.py b/libs/waitress/tests/fixtureapps/nocl.py
new file mode 100644
index 000000000..f82bba0c8
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/nocl.py
@@ -0,0 +1,23 @@
def chunks(l, n):  # pragma: no cover
    """Yield successive n-sized slices of *l*."""
    start = 0
    while start < len(l):
        yield l[start : start + n]
        start += n
+
+
def gen(body):  # pragma: no cover
    """Stream *body* to the caller in 10-byte pieces."""
    for piece in chunks(body, 10):
        yield piece
+
+
def app(environ, start_response):  # pragma: no cover
    """Echo the request body without advertising a Content-Length."""
    length = environ.get("CONTENT_LENGTH")
    if length is not None:
        length = int(length)
    payload = environ["wsgi.input"].read(length)
    start_response("200 OK", [("Content-Type", "text/plain")])
    path = environ["PATH_INFO"]
    if path == "/list":
        return [payload]
    if path == "/list_lentwo":
        # a two-element iterable, to exercise the len()==2 case
        return [payload[0:1], payload[1:]]
    return gen(payload)
diff --git a/libs/waitress/tests/fixtureapps/runner.py b/libs/waitress/tests/fixtureapps/runner.py
new file mode 100644
index 000000000..1d66ad1cc
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/runner.py
@@ -0,0 +1,6 @@
def app():  # pragma: no cover
    """Placeholder application object used by the runner tests."""
    return None


def returns_app():  # pragma: no cover
    """App factory: hand back the module-level ``app`` callable."""
    return app
diff --git a/libs/waitress/tests/fixtureapps/sleepy.py b/libs/waitress/tests/fixtureapps/sleepy.py
new file mode 100644
index 000000000..2d171d8be
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/sleepy.py
@@ -0,0 +1,12 @@
+import time
+
+
def app(environ, start_response):  # pragma: no cover
    """Respond after a two-second pause when /sleepy is requested."""
    if environ["PATH_INFO"] == "/sleepy":
        time.sleep(2)  # simulate a slow application
        body = b"sleepy returned"
    else:
        body = b"notsleepy returned"
    start_response(
        "200 OK",
        [("Content-Length", str(len(body))), ("Content-Type", "text/plain")],
    )
    return [body]
diff --git a/libs/waitress/tests/fixtureapps/toolarge.py b/libs/waitress/tests/fixtureapps/toolarge.py
new file mode 100644
index 000000000..a0f36d2cc
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/toolarge.py
@@ -0,0 +1,7 @@
def app(environ, start_response):  # pragma: no cover
    """Return a small fixed body with a matching Content-Length."""
    payload = b"abcdef"
    start_response(
        "200 OK",
        [("Content-Length", str(len(payload))), ("Content-Type", "text/plain")],
    )
    return [payload]
diff --git a/libs/waitress/tests/fixtureapps/writecb.py b/libs/waitress/tests/fixtureapps/writecb.py
new file mode 100644
index 000000000..e1d2792e6
--- /dev/null
+++ b/libs/waitress/tests/fixtureapps/writecb.py
@@ -0,0 +1,14 @@
def app(environ, start_response):  # pragma: no cover
    """Exercise the write() callback returned by start_response."""
    path_info = environ["PATH_INFO"]
    if path_info == "/no_content_length":
        headers = []
    else:
        headers = [("Content-Length", "9")]
    write = start_response("200 OK", headers)
    # /long_body writes one byte more than advertised, /short_body one
    # byte less; everything else writes exactly nine bytes.
    bodies = {"/long_body": b"abcdefghij", "/short_body": b"abcdefgh"}
    write(bodies.get(path_info, b"abcdefghi"))
    return []
diff --git a/libs/waitress/tests/test_adjustments.py b/libs/waitress/tests/test_adjustments.py
new file mode 100644
index 000000000..303c1aa3a
--- /dev/null
+++ b/libs/waitress/tests/test_adjustments.py
@@ -0,0 +1,481 @@
+import sys
+import socket
+import warnings
+
+from waitress.compat import (
+ PY2,
+ WIN,
+)
+
+if sys.version_info[:2] == (2, 6): # pragma: no cover
+ import unittest2 as unittest
+else: # pragma: no cover
+ import unittest
+
+
class Test_asbool(unittest.TestCase):
    """Exercise waitress.adjustments.asbool over the common truthy/falsy inputs."""

    def _callFUT(self, s):
        from waitress.adjustments import asbool

        return asbool(s)

    def test_s_is_None(self):
        self.assertEqual(self._callFUT(None), False)

    def test_s_is_True(self):
        self.assertEqual(self._callFUT(True), True)

    def test_s_is_False(self):
        self.assertEqual(self._callFUT(False), False)

    def test_s_is_true(self):
        self.assertEqual(self._callFUT("True"), True)

    def test_s_is_false(self):
        self.assertEqual(self._callFUT("False"), False)

    def test_s_is_yes(self):
        self.assertEqual(self._callFUT("yes"), True)

    def test_s_is_on(self):
        self.assertEqual(self._callFUT("on"), True)

    def test_s_is_1(self):
        self.assertEqual(self._callFUT(1), True)
+
+
class Test_as_socket_list(unittest.TestCase):
    """as_socket_list should pass sockets through and drop everything else."""

    def test_only_sockets_in_list(self):
        from waitress.adjustments import as_socket_list

        socks = [
            socket.socket(socket.AF_INET, socket.SOCK_STREAM),
            socket.socket(socket.AF_INET6, socket.SOCK_STREAM),
        ]
        if hasattr(socket, "AF_UNIX"):
            socks.append(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM))
        self.assertEqual(socks, as_socket_list(socks))
        for candidate in socks:
            candidate.close()

    def test_not_only_sockets_in_list(self):
        from waitress.adjustments import as_socket_list

        mixed = [
            socket.socket(socket.AF_INET, socket.SOCK_STREAM),
            socket.socket(socket.AF_INET6, socket.SOCK_STREAM),
            {"something": "else"},
        ]
        self.assertEqual(as_socket_list(mixed), mixed[:2])
        for candidate in mixed:
            if isinstance(candidate, socket.socket):
                candidate.close()
+
+
class TestAdjustments(unittest.TestCase):
    """Tests for waitress.adjustments.Adjustments: string coercion of settings,
    listen/host/port/socket exclusivity rules, proxy-header validation, and
    deprecation warnings."""

    def _hasIPv6(self):  # pragma: nocover
        """Return True when the platform can actually resolve an IPv6 address."""
        if not socket.has_ipv6:
            return False

        try:
            socket.getaddrinfo(
                "::1",
                0,
                socket.AF_UNSPEC,
                socket.SOCK_STREAM,
                socket.IPPROTO_TCP,
                socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
            )

            return True
        except socket.gaierror as e:
            # Check to see what the error is
            if e.errno == socket.EAI_ADDRFAMILY:
                # IPv6 compiled in but no address configured on this host
                return False
            else:
                raise e

    def _makeOne(self, **kw):
        from waitress.adjustments import Adjustments

        return Adjustments(**kw)

    def test_goodvars(self):
        # All values deliberately passed as strings where possible to prove
        # that Adjustments coerces them to their proper types.
        inst = self._makeOne(
            host="localhost",
            port="8080",
            threads="5",
            trusted_proxy="192.168.1.1",
            trusted_proxy_headers={"forwarded"},
            trusted_proxy_count=2,
            log_untrusted_proxy_headers=True,
            url_scheme="https",
            backlog="20",
            recv_bytes="200",
            send_bytes="300",
            outbuf_overflow="400",
            inbuf_overflow="500",
            connection_limit="1000",
            cleanup_interval="1100",
            channel_timeout="1200",
            log_socket_errors="true",
            max_request_header_size="1300",
            max_request_body_size="1400",
            expose_tracebacks="true",
            ident="abc",
            asyncore_loop_timeout="5",
            asyncore_use_poll=True,
            unix_socket_perms="777",
            url_prefix="///foo/",
            ipv4=True,
            ipv6=False,
        )

        self.assertEqual(inst.host, "localhost")
        self.assertEqual(inst.port, 8080)
        self.assertEqual(inst.threads, 5)
        self.assertEqual(inst.trusted_proxy, "192.168.1.1")
        self.assertEqual(inst.trusted_proxy_headers, {"forwarded"})
        self.assertEqual(inst.trusted_proxy_count, 2)
        self.assertEqual(inst.log_untrusted_proxy_headers, True)
        self.assertEqual(inst.url_scheme, "https")
        self.assertEqual(inst.backlog, 20)
        self.assertEqual(inst.recv_bytes, 200)
        self.assertEqual(inst.send_bytes, 300)
        self.assertEqual(inst.outbuf_overflow, 400)
        self.assertEqual(inst.inbuf_overflow, 500)
        self.assertEqual(inst.connection_limit, 1000)
        self.assertEqual(inst.cleanup_interval, 1100)
        self.assertEqual(inst.channel_timeout, 1200)
        self.assertEqual(inst.log_socket_errors, True)
        self.assertEqual(inst.max_request_header_size, 1300)
        self.assertEqual(inst.max_request_body_size, 1400)
        self.assertEqual(inst.expose_tracebacks, True)
        self.assertEqual(inst.asyncore_loop_timeout, 5)
        self.assertEqual(inst.asyncore_use_poll, True)
        self.assertEqual(inst.ident, "abc")
        # "777" is parsed as octal
        self.assertEqual(inst.unix_socket_perms, 0o777)
        # "///foo/" is normalized to a single leading slash, no trailing slash
        self.assertEqual(inst.url_prefix, "/foo")
        self.assertEqual(inst.ipv4, True)
        self.assertEqual(inst.ipv6, False)

        bind_pairs = [
            sockaddr[:2]
            for (family, _, _, sockaddr) in inst.listen
            if family == socket.AF_INET
        ]

        # On Travis, somehow we start listening to two sockets when resolving
        # localhost...
        self.assertEqual(("127.0.0.1", 8080), bind_pairs[0])

    def test_goodvar_listen(self):
        # A bare address gets the default port 8080
        inst = self._makeOne(listen="127.0.0.1")

        bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen]

        self.assertEqual(bind_pairs, [("127.0.0.1", 8080)])

    def test_default_listen(self):
        inst = self._makeOne()

        bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen]

        self.assertEqual(bind_pairs, [("0.0.0.0", 8080)])

    def test_multiple_listen(self):
        # Space-separated listen values each become a bind pair, in order
        inst = self._makeOne(listen="127.0.0.1:9090 127.0.0.1:8080")

        bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen]

        self.assertEqual(bind_pairs, [("127.0.0.1", 9090), ("127.0.0.1", 8080)])

    def test_wildcard_listen(self):
        # "*" may expand to one or more families depending on the platform
        inst = self._makeOne(listen="*:8080")

        bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen]

        self.assertTrue(len(bind_pairs) >= 1)

    def test_ipv6_no_port(self):  # pragma: nocover
        if not self._hasIPv6():
            return

        inst = self._makeOne(listen="[::1]")

        bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen]

        self.assertEqual(bind_pairs, [("::1", 8080)])

    def test_bad_port(self):
        self.assertRaises(ValueError, self._makeOne, listen="127.0.0.1:test")

    def test_service_port(self):
        # Named services are resolved via getservbyname (http=80, https=443)
        if WIN and PY2:  # pragma: no cover
            # On Windows and Python 2 this is broken, so we raise a ValueError
            self.assertRaises(
                ValueError, self._makeOne, listen="127.0.0.1:http",
            )
            return

        inst = self._makeOne(listen="127.0.0.1:http 0.0.0.0:https")

        bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen]

        self.assertEqual(bind_pairs, [("127.0.0.1", 80), ("0.0.0.0", 443)])

    def test_dont_mix_host_port_listen(self):
        self.assertRaises(
            ValueError,
            self._makeOne,
            host="localhost",
            port="8080",
            listen="127.0.0.1:8080",
        )

    def test_good_sockets(self):
        sockets = [
            socket.socket(socket.AF_INET6, socket.SOCK_STREAM),
            socket.socket(socket.AF_INET, socket.SOCK_STREAM),
        ]
        inst = self._makeOne(sockets=sockets)
        self.assertEqual(inst.sockets, sockets)
        sockets[0].close()
        sockets[1].close()

    def test_dont_mix_sockets_and_listen(self):
        sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
        self.assertRaises(
            ValueError, self._makeOne, listen="127.0.0.1:8080", sockets=sockets
        )
        sockets[0].close()

    def test_dont_mix_sockets_and_host_port(self):
        sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
        self.assertRaises(
            ValueError, self._makeOne, host="localhost", port="8080", sockets=sockets
        )
        sockets[0].close()

    def test_dont_mix_sockets_and_unix_socket(self):
        sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
        self.assertRaises(
            ValueError, self._makeOne, unix_socket="./tmp/test", sockets=sockets
        )
        sockets[0].close()

    def test_dont_mix_unix_socket_and_host_port(self):
        self.assertRaises(
            ValueError,
            self._makeOne,
            unix_socket="./tmp/test",
            host="localhost",
            port="8080",
        )

    def test_dont_mix_unix_socket_and_listen(self):
        self.assertRaises(
            ValueError, self._makeOne, unix_socket="./tmp/test", listen="127.0.0.1:8080"
        )

    def test_dont_use_unsupported_socket_types(self):
        # Only SOCK_STREAM sockets are accepted; datagram sockets are rejected
        sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)]
        self.assertRaises(ValueError, self._makeOne, sockets=sockets)
        sockets[0].close()

    def test_dont_mix_forwarded_with_x_forwarded(self):
        with self.assertRaises(ValueError) as cm:
            self._makeOne(
                trusted_proxy="localhost",
                trusted_proxy_headers={"forwarded", "x-forwarded-for"},
            )

        self.assertIn("The Forwarded proxy header", str(cm.exception))

    def test_unknown_trusted_proxy_header(self):
        with self.assertRaises(ValueError) as cm:
            self._makeOne(
                trusted_proxy="localhost",
                trusted_proxy_headers={"forwarded", "x-forwarded-unknown"},
            )

        self.assertIn(
            "unknown trusted_proxy_headers value (x-forwarded-unknown)",
            str(cm.exception),
        )

    def test_trusted_proxy_count_no_trusted_proxy(self):
        with self.assertRaises(ValueError) as cm:
            self._makeOne(trusted_proxy_count=1)

        self.assertIn("trusted_proxy_count has no meaning", str(cm.exception))

    def test_trusted_proxy_headers_no_trusted_proxy(self):
        with self.assertRaises(ValueError) as cm:
            self._makeOne(trusted_proxy_headers={"forwarded"})

        self.assertIn("trusted_proxy_headers has no meaning", str(cm.exception))

    def test_trusted_proxy_headers_string_list(self):
        # Whitespace-separated strings are split into a set
        inst = self._makeOne(
            trusted_proxy="localhost",
            trusted_proxy_headers="x-forwarded-for x-forwarded-by",
        )
        self.assertEqual(
            inst.trusted_proxy_headers, {"x-forwarded-for", "x-forwarded-by"}
        )

    def test_trusted_proxy_headers_string_list_newlines(self):
        inst = self._makeOne(
            trusted_proxy="localhost",
            trusted_proxy_headers="x-forwarded-for\nx-forwarded-by\nx-forwarded-host",
        )
        self.assertEqual(
            inst.trusted_proxy_headers,
            {"x-forwarded-for", "x-forwarded-by", "x-forwarded-host"},
        )

    def test_no_trusted_proxy_headers_trusted_proxy(self):
        # trusted_proxy without trusted_proxy_headers warns about the implicit
        # legacy default
        with warnings.catch_warnings(record=True) as w:
            warnings.resetwarnings()
            warnings.simplefilter("always")
            self._makeOne(trusted_proxy="localhost")

            self.assertGreaterEqual(len(w), 1)
            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
            self.assertIn("Implicitly trusting X-Forwarded-Proto", str(w[0]))

    def test_clear_untrusted_proxy_headers(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.resetwarnings()
            warnings.simplefilter("always")
            self._makeOne(
                trusted_proxy="localhost", trusted_proxy_headers={"x-forwarded-for"}
            )

            self.assertGreaterEqual(len(w), 1)
            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
            self.assertIn(
                "clear_untrusted_proxy_headers will be set to True", str(w[0])
            )

    def test_deprecated_send_bytes(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.resetwarnings()
            warnings.simplefilter("always")
            self._makeOne(send_bytes=1)

            self.assertGreaterEqual(len(w), 1)
            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
            self.assertIn("send_bytes", str(w[0]))

    def test_badvar(self):
        # Unknown keyword arguments are rejected
        self.assertRaises(ValueError, self._makeOne, nope=True)

    def test_ipv4_disabled(self):
        self.assertRaises(
            ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080"
        )

    def test_ipv6_disabled(self):
        self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080")

    def test_server_header_removable(self):
        # ident=None or "" suppresses the Server header entirely
        inst = self._makeOne(ident=None)
        self.assertEqual(inst.ident, None)

        inst = self._makeOne(ident="")
        self.assertEqual(inst.ident, None)

        inst = self._makeOne(ident="specific_header")
        self.assertEqual(inst.ident, "specific_header")
+
+
class TestCLI(unittest.TestCase):
    """Tests for Adjustments.parse_args command-line option parsing."""

    def parse(self, argv):
        from waitress.adjustments import Adjustments

        return Adjustments.parse_args(argv)

    def assertDictSubset(self, subset, full):
        """Assert every key/value pair in *subset* is present in *full*.

        Replacement for unittest's assertDictContainsSubset, which has been
        deprecated since Python 3.2 and removed in modern Pythons: merging the
        expected subset into a copy of the full dict and comparing is the
        documented equivalent.
        """
        self.assertEqual(full, dict(full, **subset))

    def test_noargs(self):
        opts, args = self.parse([])
        self.assertDictEqual(opts, {"call": False, "help": False})
        self.assertSequenceEqual(args, [])

    def test_help(self):
        opts, args = self.parse(["--help"])
        self.assertDictEqual(opts, {"call": False, "help": True})
        self.assertSequenceEqual(args, [])

    def test_call(self):
        opts, args = self.parse(["--call"])
        self.assertDictEqual(opts, {"call": True, "help": False})
        self.assertSequenceEqual(args, [])

    def test_both(self):
        opts, args = self.parse(["--call", "--help"])
        self.assertDictEqual(opts, {"call": True, "help": True})
        self.assertSequenceEqual(args, [])

    def test_positive_boolean(self):
        opts, args = self.parse(["--expose-tracebacks"])
        self.assertDictSubset({"expose_tracebacks": "true"}, opts)
        self.assertSequenceEqual(args, [])

    def test_negative_boolean(self):
        opts, args = self.parse(["--no-expose-tracebacks"])
        self.assertDictSubset({"expose_tracebacks": "false"}, opts)
        self.assertSequenceEqual(args, [])

    def test_cast_params(self):
        opts, args = self.parse(
            ["--host=localhost", "--port=80", "--unix-socket-perms=777"]
        )
        self.assertDictSubset(
            {"host": "localhost", "port": "80", "unix_socket_perms": "777"}, opts
        )
        self.assertSequenceEqual(args, [])

    def test_listen_params(self):
        opts, args = self.parse(["--listen=test:80"])

        # repeated --listen values accumulate into one space-separated string
        self.assertDictSubset({"listen": " test:80"}, opts)
        self.assertSequenceEqual(args, [])

    def test_multiple_listen_params(self):
        opts, args = self.parse(["--listen=test:80", "--listen=test:8080"])

        self.assertDictSubset({"listen": " test:80 test:8080"}, opts)
        self.assertSequenceEqual(args, [])

    def test_bad_param(self):
        import getopt

        self.assertRaises(getopt.GetoptError, self.parse, ["--no-host"])
+
+
if hasattr(socket, "AF_UNIX"):

    class TestUnixSocket(unittest.TestCase):
        """Defined only on platforms that actually support AF_UNIX sockets."""

        def _makeOne(self, **kw):
            from waitress.adjustments import Adjustments

            return Adjustments(**kw)

        def test_dont_mix_internet_and_unix_sockets(self):
            socks = [
                socket.socket(socket.AF_INET, socket.SOCK_STREAM),
                socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
            ]
            self.assertRaises(ValueError, self._makeOne, sockets=socks)
            for candidate in socks:
                candidate.close()
diff --git a/libs/waitress/tests/test_buffers.py b/libs/waitress/tests/test_buffers.py
new file mode 100644
index 000000000..a1330ac1b
--- /dev/null
+++ b/libs/waitress/tests/test_buffers.py
@@ -0,0 +1,523 @@
+import unittest
+import io
+
+
class TestFileBasedBuffer(unittest.TestCase):
    """Tests for waitress.buffers.FileBasedBuffer using BytesIO stand-ins."""

    def _makeOne(self, file=None, from_buffer=None):
        from waitress.buffers import FileBasedBuffer

        buf = FileBasedBuffer(file, from_buffer=from_buffer)
        # register for tearDown so file handles are not leaked between tests
        self.buffers_to_close.append(buf)
        return buf

    def setUp(self):
        self.buffers_to_close = []

    def tearDown(self):
        for buf in self.buffers_to_close:
            buf.close()

    def test_ctor_from_buffer_None(self):
        inst = self._makeOne("file")
        self.assertEqual(inst.file, "file")

    def test_ctor_from_buffer(self):
        from_buffer = io.BytesIO(b"data")
        from_buffer.getfile = lambda *x: from_buffer
        f = io.BytesIO()
        inst = self._makeOne(f, from_buffer)
        self.assertEqual(inst.file, f)
        del from_buffer.getfile
        # the four bytes of the source buffer were copied in
        self.assertEqual(inst.remain, 4)
        from_buffer.close()

    def test___len__(self):
        inst = self._makeOne()
        inst.remain = 10
        self.assertEqual(len(inst), 10)

    def test___nonzero__(self):
        # A FileBasedBuffer is truthy regardless of how much data remains
        inst = self._makeOne()
        inst.remain = 10
        self.assertEqual(bool(inst), True)
        inst.remain = 0
        self.assertEqual(bool(inst), True)

    def test_append(self):
        f = io.BytesIO(b"data")
        inst = self._makeOne(f)
        inst.append(b"data2")
        self.assertEqual(f.getvalue(), b"datadata2")
        self.assertEqual(inst.remain, 5)

    def test_get_skip_true(self):
        f = io.BytesIO(b"data")
        inst = self._makeOne(f)
        result = inst.get(100, skip=True)
        self.assertEqual(result, b"data")
        # remain starts at 0 here, so skipping 4 bytes drives it negative
        self.assertEqual(inst.remain, -4)

    def test_get_skip_false(self):
        f = io.BytesIO(b"data")
        inst = self._makeOne(f)
        result = inst.get(100, skip=False)
        self.assertEqual(result, b"data")
        self.assertEqual(inst.remain, 0)

    def test_get_skip_bytes_less_than_zero(self):
        # A negative byte count means "read everything"
        f = io.BytesIO(b"data")
        inst = self._makeOne(f)
        result = inst.get(-1, skip=False)
        self.assertEqual(result, b"data")
        self.assertEqual(inst.remain, 0)

    def test_skip_remain_gt_bytes(self):
        f = io.BytesIO(b"d")
        inst = self._makeOne(f)
        inst.remain = 1
        inst.skip(1)
        self.assertEqual(inst.remain, 0)

    def test_skip_remain_lt_bytes(self):
        # Skipping more than remains is a programming error
        f = io.BytesIO(b"d")
        inst = self._makeOne(f)
        inst.remain = 1
        self.assertRaises(ValueError, inst.skip, 2)

    def test_newfile(self):
        # newfile is abstract on the base class
        inst = self._makeOne()
        self.assertRaises(NotImplementedError, inst.newfile)

    def test_prune_remain_notzero(self):
        # prune copies the unread tail into a fresh file
        f = io.BytesIO(b"d")
        inst = self._makeOne(f)
        inst.remain = 1
        nf = io.BytesIO()
        inst.newfile = lambda *x: nf
        inst.prune()
        self.assertTrue(inst.file is not f)
        self.assertEqual(nf.getvalue(), b"d")

    def test_prune_remain_zero_tell_notzero(self):
        f = io.BytesIO(b"d")
        inst = self._makeOne(f)
        nf = io.BytesIO(b"d")
        inst.newfile = lambda *x: nf
        inst.remain = 0
        inst.prune()
        self.assertTrue(inst.file is not f)
        self.assertEqual(nf.getvalue(), b"d")

    def test_prune_remain_zero_tell_zero(self):
        # Nothing read and nothing remaining: prune is a no-op
        f = io.BytesIO()
        inst = self._makeOne(f)
        inst.remain = 0
        inst.prune()
        self.assertTrue(inst.file is f)

    def test_close(self):
        f = io.BytesIO()
        inst = self._makeOne(f)
        inst.close()
        self.assertTrue(f.closed)
        # already closed; keep tearDown from closing it a second time
        self.buffers_to_close.remove(inst)
+
+
class TestTempfileBasedBuffer(unittest.TestCase):
    """Tests for the tempfile-backed overflow buffer."""

    def setUp(self):
        self.buffers_to_close = []

    def tearDown(self):
        for pending in self.buffers_to_close:
            pending.close()

    def _makeOne(self, from_buffer=None):
        from waitress.buffers import TempfileBasedBuffer

        made = TempfileBasedBuffer(from_buffer=from_buffer)
        self.buffers_to_close.append(made)
        return made

    def test_newfile(self):
        result = self._makeOne().newfile()
        # a real temp file exposes fileno()
        self.assertTrue(hasattr(result, "fileno"))
        result.close()
+ r.close()
+
+
class TestBytesIOBasedBuffer(unittest.TestCase):
    """Tests for the in-memory BytesIO-backed buffer."""

    def _makeOne(self, from_buffer=None):
        from waitress.buffers import BytesIOBasedBuffer

        return BytesIOBasedBuffer(from_buffer=from_buffer)

    def test_ctor_from_buffer_not_None(self):
        source = io.BytesIO()
        source.getfile = lambda *x: source
        self.assertTrue(hasattr(self._makeOne(source).file, "read"))

    def test_ctor_from_buffer_None(self):
        self.assertTrue(hasattr(self._makeOne().file, "read"))

    def test_newfile(self):
        fresh = self._makeOne().newfile()
        self.assertTrue(hasattr(fresh, "read"))
+
+
class TestReadOnlyFileBasedBuffer(unittest.TestCase):
    """Tests for ReadOnlyFileBasedBuffer, the wrapper used for wsgi.file_wrapper."""

    def _makeOne(self, file, block_size=8192):
        from waitress.buffers import ReadOnlyFileBasedBuffer

        buf = ReadOnlyFileBasedBuffer(file, block_size)
        self.buffers_to_close.append(buf)
        return buf

    def setUp(self):
        self.buffers_to_close = []

    def tearDown(self):
        for buf in self.buffers_to_close:
            buf.close()

    def test_prepare_not_seekable(self):
        # KindaFilelike has no seek/tell, so prepare() cannot size the file
        f = KindaFilelike(b"abc")
        inst = self._makeOne(f)
        result = inst.prepare()
        self.assertEqual(result, False)
        self.assertEqual(inst.remain, 0)

    def test_prepare_not_seekable_closeable(self):
        # when the wrapped file has close(), the buffer grows one too
        f = KindaFilelike(b"abc", close=1)
        inst = self._makeOne(f)
        result = inst.prepare()
        self.assertEqual(result, False)
        self.assertEqual(inst.remain, 0)
        self.assertTrue(hasattr(inst, "close"))

    def test_prepare_seekable_closeable(self):
        # tell() reports 0 then 10, so prepare() computes a 10-byte span and
        # seeks back to the start
        f = Filelike(b"abc", close=1, tellresults=[0, 10])
        inst = self._makeOne(f)
        result = inst.prepare()
        self.assertEqual(result, 10)
        self.assertEqual(inst.remain, 10)
        self.assertEqual(inst.file.seeked, 0)
        self.assertTrue(hasattr(inst, "close"))

    def test_get_numbytes_neg_one(self):
        # numbytes=-1 reads up to remain; without skip the position is restored
        f = io.BytesIO(b"abcdef")
        inst = self._makeOne(f)
        inst.remain = 2
        result = inst.get(-1)
        self.assertEqual(result, b"ab")
        self.assertEqual(inst.remain, 2)
        self.assertEqual(f.tell(), 0)

    def test_get_numbytes_gt_remain(self):
        # reads are clamped to remain
        f = io.BytesIO(b"abcdef")
        inst = self._makeOne(f)
        inst.remain = 2
        result = inst.get(3)
        self.assertEqual(result, b"ab")
        self.assertEqual(inst.remain, 2)
        self.assertEqual(f.tell(), 0)

    def test_get_numbytes_lt_remain(self):
        f = io.BytesIO(b"abcdef")
        inst = self._makeOne(f)
        inst.remain = 2
        result = inst.get(1)
        self.assertEqual(result, b"a")
        self.assertEqual(inst.remain, 2)
        self.assertEqual(f.tell(), 0)

    def test_get_numbytes_gt_remain_withskip(self):
        # with skip=True the file position advances and remain is consumed
        f = io.BytesIO(b"abcdef")
        inst = self._makeOne(f)
        inst.remain = 2
        result = inst.get(3, skip=True)
        self.assertEqual(result, b"ab")
        self.assertEqual(inst.remain, 0)
        self.assertEqual(f.tell(), 2)

    def test_get_numbytes_lt_remain_withskip(self):
        f = io.BytesIO(b"abcdef")
        inst = self._makeOne(f)
        inst.remain = 2
        result = inst.get(1, skip=True)
        self.assertEqual(result, b"a")
        self.assertEqual(inst.remain, 1)
        self.assertEqual(f.tell(), 1)

    def test___iter__(self):
        # iteration yields block_size chunks until the file is exhausted
        data = b"a" * 10000
        f = io.BytesIO(data)
        inst = self._makeOne(f)
        r = b""
        for val in inst:
            r += val
        self.assertEqual(r, data)

    def test_append(self):
        # the buffer is read-only by design
        inst = self._makeOne(None)
        self.assertRaises(NotImplementedError, inst.append, "a")
+
+
class TestOverflowableBuffer(unittest.TestCase):
    """Tests for OverflowableBuffer, which starts as an in-memory string
    buffer ("strbuf"), promotes itself to a BytesIO-based buffer, and spills
    to a tempfile once the configured overflow threshold is crossed."""

    def _makeOne(self, overflow=10):
        from waitress.buffers import OverflowableBuffer

        buf = OverflowableBuffer(overflow)
        self.buffers_to_close.append(buf)
        return buf

    def setUp(self):
        self.buffers_to_close = []

    def tearDown(self):
        for buf in self.buffers_to_close:
            buf.close()

    def test___len__buf_is_None(self):
        inst = self._makeOne()
        self.assertEqual(len(inst), 0)

    def test___len__buf_is_not_None(self):
        inst = self._makeOne()
        inst.buf = b"abc"
        self.assertEqual(len(inst), 3)
        # inst.buf is plain bytes (no close()); keep tearDown away from it
        self.buffers_to_close.remove(inst)

    def test___nonzero__(self):
        inst = self._makeOne()
        inst.buf = b"abc"
        self.assertEqual(bool(inst), True)
        inst.buf = b""
        self.assertEqual(bool(inst), False)
        self.buffers_to_close.remove(inst)

    def test___nonzero___on_int_overflow_buffer(self):
        # __nonzero__ must not compute len() in a way that overflows C ints
        inst = self._makeOne()

        class int_overflow_buf(bytes):
            def __len__(self):
                # maxint + 1
                return 0x7FFFFFFFFFFFFFFF + 1

        inst.buf = int_overflow_buf()
        self.assertEqual(bool(inst), True)
        inst.buf = b""
        self.assertEqual(bool(inst), False)
        self.buffers_to_close.remove(inst)

    def test__create_buffer_large(self):
        # strbuf longer than the overflow threshold (10) goes to a tempfile
        from waitress.buffers import TempfileBasedBuffer

        inst = self._makeOne()
        inst.strbuf = b"x" * 11
        inst._create_buffer()
        self.assertEqual(inst.buf.__class__, TempfileBasedBuffer)
        self.assertEqual(inst.buf.get(100), b"x" * 11)
        self.assertEqual(inst.strbuf, b"")

    def test__create_buffer_small(self):
        # strbuf within the threshold becomes an in-memory BytesIO buffer
        from waitress.buffers import BytesIOBasedBuffer

        inst = self._makeOne()
        inst.strbuf = b"x" * 5
        inst._create_buffer()
        self.assertEqual(inst.buf.__class__, BytesIOBasedBuffer)
        self.assertEqual(inst.buf.get(100), b"x" * 5)
        self.assertEqual(inst.strbuf, b"")

    def test_append_with_len_more_than_max_int(self):
        from waitress.compat import MAXINT

        inst = self._makeOne()
        inst.overflowed = True
        buf = DummyBuffer(length=MAXINT)
        inst.buf = buf
        result = inst.append(b"x")
        # we don't want this to throw an OverflowError on Python 2 (see
        # https://github.com/Pylons/waitress/issues/47)
        self.assertEqual(result, None)
        self.buffers_to_close.remove(inst)

    def test_append_buf_None_not_longer_than_srtbuf_limit(self):
        # small appends stay in the plain strbuf
        inst = self._makeOne()
        inst.strbuf = b"x" * 5
        inst.append(b"hello")
        self.assertEqual(inst.strbuf, b"xxxxxhello")

    def test_append_buf_None_longer_than_strbuf_limit(self):
        # crossing the strbuf limit (8192) promotes strbuf into a real buffer
        inst = self._makeOne(10000)
        inst.strbuf = b"x" * 8192
        inst.append(b"hello")
        self.assertEqual(inst.strbuf, b"")
        self.assertEqual(len(inst.buf), 8197)

    def test_append_overflow(self):
        inst = self._makeOne(10)
        inst.strbuf = b"x" * 8192
        inst.append(b"hello")
        self.assertEqual(inst.strbuf, b"")
        self.assertEqual(len(inst.buf), 8197)

    def test_append_sz_gt_overflow(self):
        from waitress.buffers import BytesIOBasedBuffer

        # NOTE(review): f is passed as the ctor's *overflow* argument here;
        # it appears harmless only because inst.overflow is overwritten to 2
        # below before append() consults it — confirm against upstream.
        f = io.BytesIO(b"data")
        inst = self._makeOne(f)
        buf = BytesIOBasedBuffer()
        inst.buf = buf
        inst.overflow = 2
        inst.append(b"data2")
        self.assertEqual(f.getvalue(), b"data")
        self.assertTrue(inst.overflowed)
        # overflowing swaps in a new (tempfile-backed) buffer
        self.assertNotEqual(inst.buf, buf)

    def test_get_buf_None_skip_False(self):
        inst = self._makeOne()
        inst.strbuf = b"x" * 5
        r = inst.get(5)
        self.assertEqual(r, b"xxxxx")

    def test_get_buf_None_skip_True(self):
        # get(skip=True) forces promotion of strbuf into a real buffer
        inst = self._makeOne()
        inst.strbuf = b"x" * 5
        r = inst.get(5, skip=True)
        self.assertFalse(inst.buf is None)
        self.assertEqual(r, b"xxxxx")

    def test_skip_buf_None(self):
        inst = self._makeOne()
        inst.strbuf = b"data"
        inst.skip(4)
        self.assertEqual(inst.strbuf, b"")
        self.assertNotEqual(inst.buf, None)

    def test_skip_buf_None_allow_prune_True(self):
        # with allow_prune=True a fully-consumed buffer is dropped entirely
        inst = self._makeOne()
        inst.strbuf = b"data"
        inst.skip(4, True)
        self.assertEqual(inst.strbuf, b"")
        self.assertEqual(inst.buf, None)

    def test_prune_buf_None(self):
        inst = self._makeOne()
        inst.prune()
        self.assertEqual(inst.strbuf, b"")

    def test_prune_with_buf(self):
        # prune() is delegated to the underlying buffer
        inst = self._makeOne()

        class Buf(object):
            def prune(self):
                self.pruned = True

        inst.buf = Buf()
        inst.prune()
        self.assertEqual(inst.buf.pruned, True)
        self.buffers_to_close.remove(inst)

    def test_prune_with_buf_overflow(self):
        inst = self._makeOne()

        # local stand-in; intentionally shadows the module-level DummyBuffer
        class DummyBuffer(io.BytesIO):
            def getfile(self):
                return self

            def prune(self):
                return True

            def __len__(self):
                return 5

            def close(self):
                pass

        buf = DummyBuffer(b"data")
        inst.buf = buf
        inst.overflowed = True
        inst.overflow = 10
        inst.prune()
        # pruned below the overflow threshold: demoted back to a small buffer
        self.assertNotEqual(inst.buf, buf)

    def test_prune_with_buflen_more_than_max_int(self):
        from waitress.compat import MAXINT

        inst = self._makeOne()
        inst.overflowed = True
        buf = DummyBuffer(length=MAXINT + 1)
        inst.buf = buf
        result = inst.prune()
        # we don't want this to throw an OverflowError on Python 2 (see
        # https://github.com/Pylons/waitress/issues/47)
        self.assertEqual(result, None)

    def test_getfile_buf_None(self):
        inst = self._makeOne()
        f = inst.getfile()
        self.assertTrue(hasattr(f, "read"))

    def test_getfile_buf_not_None(self):
        inst = self._makeOne()
        buf = io.BytesIO()
        buf.getfile = lambda *x: buf
        inst.buf = buf
        f = inst.getfile()
        self.assertEqual(f, buf)

    def test_close_nobuf(self):
        inst = self._makeOne()
        inst.buf = None
        self.assertEqual(inst.close(), None)  # doesnt raise
        self.buffers_to_close.remove(inst)

    def test_close_withbuf(self):
        class Buffer(object):
            def close(self):
                self.closed = True

        buf = Buffer()
        inst = self._makeOne()
        inst.buf = buf
        inst.close()
        self.assertTrue(buf.closed)
        self.buffers_to_close.remove(inst)
+
+
class KindaFilelike(object):
    """Fake file for buffer tests: deliberately has no seek/tell, so the code
    under test treats it as non-seekable.  When *close* is given, a ``close``
    attribute is attached whose call returns that value — the tests check for
    the mere presence of ``close``, not its behavior."""

    def __init__(self, bytes, close=None, tellresults=None):
        self.bytes = bytes
        self.tellresults = tellresults
        if close is not None:
            self.close = lambda: close
+
+
class Filelike(KindaFilelike):
    """Seekable fake file: records the last seek target in ``seeked`` and
    replays the queued ``tellresults`` values one per tell() call."""

    def seek(self, v, whence=0):
        self.seeked = v

    def tell(self):
        return self.tellresults.pop(0)
+
+
class DummyBuffer(object):
    """Buffer stand-in that tracks nothing but its reported length; prune()
    and close() are deliberate no-ops."""

    def __init__(self, length=0):
        self.length = length

    def __len__(self):
        return self.length

    def append(self, s):
        self.length += len(s)

    def prune(self):
        pass

    def close(self):
        pass
diff --git a/libs/waitress/tests/test_channel.py b/libs/waitress/tests/test_channel.py
new file mode 100644
index 000000000..14ef5a0ec
--- /dev/null
+++ b/libs/waitress/tests/test_channel.py
@@ -0,0 +1,882 @@
+import unittest
+import io
+
+
class TestHTTPChannel(unittest.TestCase):
    """Unit tests for waitress.channel.HTTPChannel.

    The channel is wired to test doubles (DummySock, DummyLock,
    DummyServer, DummyAdjustments, ...) so each test can drive one
    code path and assert on the recorded side effects.
    """

    def _makeOne(self, sock, addr, adj, map=None):
        # Build a channel attached to a fresh DummyServer.
        from waitress.channel import HTTPChannel

        server = DummyServer()
        return HTTPChannel(server, sock, addr, adj=adj, map=map)

    def _makeOneWithMap(self, adj=None):
        # Convenience: channel + DummySock + its socket map, with the
        # outbuf lock replaced by a recording DummyLock.
        if adj is None:
            adj = DummyAdjustments()
        sock = DummySock()
        map = {}
        inst = self._makeOne(sock, "127.0.0.1", adj, map=map)
        inst.outbuf_lock = DummyLock()
        return inst, sock, map

    def test_ctor(self):
        inst, _, map = self._makeOneWithMap()
        self.assertEqual(inst.addr, "127.0.0.1")
        # 2048 comes from DummySock.getsockopt.
        self.assertEqual(inst.sendbuf_len, 2048)
        # 100 comes from DummySock.fileno.
        self.assertEqual(map[100], inst)

    def test_total_outbufs_len_an_outbuf_size_gt_sys_maxint(self):
        from waitress.compat import MAXINT

        inst, _, map = self._makeOneWithMap()

        class DummyBuffer(object):
            chunks = []

            def append(self, data):
                self.chunks.append(data)

        class DummyData(object):
            def __len__(self):
                return MAXINT

        inst.total_outbufs_len = 1
        inst.outbufs = [DummyBuffer()]
        inst.write_soon(DummyData())
        # we are testing that this method does not raise an OverflowError
        # (see https://github.com/Pylons/waitress/issues/47)
        self.assertEqual(inst.total_outbufs_len, MAXINT + 1)

    def test_writable_something_in_outbuf(self):
        inst, sock, map = self._makeOneWithMap()
        inst.total_outbufs_len = 3
        self.assertTrue(inst.writable())

    def test_writable_nothing_in_outbuf(self):
        inst, sock, map = self._makeOneWithMap()
        self.assertFalse(inst.writable())

    def test_writable_nothing_in_outbuf_will_close(self):
        inst, sock, map = self._makeOneWithMap()
        inst.will_close = True
        self.assertTrue(inst.writable())

    def test_handle_write_not_connected(self):
        inst, sock, map = self._makeOneWithMap()
        inst.connected = False
        self.assertFalse(inst.handle_write())

    def test_handle_write_with_requests(self):
        # While requests are pending the channel must not flush.
        inst, sock, map = self._makeOneWithMap()
        inst.requests = True
        inst.last_activity = 0
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.last_activity, 0)

    def test_handle_write_no_request_with_outbuf(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = []
        inst.outbufs = [DummyBuffer(b"abc")]
        inst.total_outbufs_len = len(inst.outbufs[0])
        inst.last_activity = 0
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertNotEqual(inst.last_activity, 0)
        self.assertEqual(sock.sent, b"abc")

    def test_handle_write_outbuf_raises_socketerror(self):
        import socket

        inst, sock, map = self._makeOneWithMap()
        inst.requests = []
        outbuf = DummyBuffer(b"abc", socket.error)
        inst.outbufs = [outbuf]
        inst.total_outbufs_len = len(outbuf)
        inst.last_activity = 0
        inst.logger = DummyLogger()
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.last_activity, 0)
        self.assertEqual(sock.sent, b"")
        self.assertEqual(len(inst.logger.exceptions), 1)
        self.assertTrue(outbuf.closed)

    def test_handle_write_outbuf_raises_othererror(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = []
        outbuf = DummyBuffer(b"abc", IOError)
        inst.outbufs = [outbuf]
        inst.total_outbufs_len = len(outbuf)
        inst.last_activity = 0
        inst.logger = DummyLogger()
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.last_activity, 0)
        self.assertEqual(sock.sent, b"")
        self.assertEqual(len(inst.logger.exceptions), 1)
        self.assertTrue(outbuf.closed)

    def test_handle_write_no_requests_no_outbuf_will_close(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = []
        outbuf = DummyBuffer(b"")
        inst.outbufs = [outbuf]
        inst.will_close = True
        inst.last_activity = 0
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.connected, False)
        self.assertEqual(sock.closed, True)
        self.assertEqual(inst.last_activity, 0)
        self.assertTrue(outbuf.closed)

    def test_handle_write_no_requests_outbuf_gt_send_bytes(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = [True]
        inst.outbufs = [DummyBuffer(b"abc")]
        inst.total_outbufs_len = len(inst.outbufs[0])
        inst.adj.send_bytes = 2
        inst.will_close = False
        inst.last_activity = 0
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.will_close, False)
        self.assertTrue(inst.outbuf_lock.acquired)
        self.assertEqual(sock.sent, b"abc")

    def test_handle_write_close_when_flushed(self):
        # close_when_flushed converts to will_close once the data is out.
        inst, sock, map = self._makeOneWithMap()
        outbuf = DummyBuffer(b"abc")
        inst.outbufs = [outbuf]
        inst.total_outbufs_len = len(outbuf)
        inst.will_close = False
        inst.close_when_flushed = True
        inst.last_activity = 0
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.will_close, True)
        self.assertEqual(inst.close_when_flushed, False)
        self.assertEqual(sock.sent, b"abc")
        self.assertTrue(outbuf.closed)

    def test_readable_no_requests_not_will_close(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = []
        inst.will_close = False
        self.assertEqual(inst.readable(), True)

    def test_readable_no_requests_will_close(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = []
        inst.will_close = True
        self.assertEqual(inst.readable(), False)

    def test_readable_with_requests(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = True
        self.assertEqual(inst.readable(), False)

    def test_handle_read_no_error(self):
        inst, sock, map = self._makeOneWithMap()
        inst.will_close = False
        inst.recv = lambda *arg: b"abc"
        inst.last_activity = 0
        L = []
        inst.received = lambda x: L.append(x)
        result = inst.handle_read()
        self.assertEqual(result, None)
        self.assertNotEqual(inst.last_activity, 0)
        self.assertEqual(L, [b"abc"])

    def test_handle_read_error(self):
        import socket

        inst, sock, map = self._makeOneWithMap()
        inst.will_close = False

        def recv(b):
            raise socket.error

        inst.recv = recv
        inst.last_activity = 0
        inst.logger = DummyLogger()
        result = inst.handle_read()
        self.assertEqual(result, None)
        self.assertEqual(inst.last_activity, 0)
        self.assertEqual(len(inst.logger.exceptions), 1)

    def test_write_soon_empty_byte(self):
        inst, sock, map = self._makeOneWithMap()
        wrote = inst.write_soon(b"")
        self.assertEqual(wrote, 0)
        self.assertEqual(len(inst.outbufs[0]), 0)

    def test_write_soon_nonempty_byte(self):
        inst, sock, map = self._makeOneWithMap()
        wrote = inst.write_soon(b"a")
        self.assertEqual(wrote, 1)
        self.assertEqual(len(inst.outbufs[0]), 1)

    def test_write_soon_filewrapper(self):
        # A ReadOnlyFileBasedBuffer is spliced into outbufs as-is, with a
        # fresh overflowable buffer appended after it.
        from waitress.buffers import ReadOnlyFileBasedBuffer

        f = io.BytesIO(b"abc")
        wrapper = ReadOnlyFileBasedBuffer(f, 8192)
        wrapper.prepare()
        inst, sock, map = self._makeOneWithMap()
        outbufs = inst.outbufs
        orig_outbuf = outbufs[0]
        wrote = inst.write_soon(wrapper)
        self.assertEqual(wrote, 3)
        self.assertEqual(len(outbufs), 3)
        self.assertEqual(outbufs[0], orig_outbuf)
        self.assertEqual(outbufs[1], wrapper)
        self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer")

    def test_write_soon_disconnected(self):
        from waitress.channel import ClientDisconnected

        inst, sock, map = self._makeOneWithMap()
        inst.connected = False
        self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff"))

    def test_write_soon_disconnected_while_over_watermark(self):
        from waitress.channel import ClientDisconnected

        inst, sock, map = self._makeOneWithMap()

        def dummy_flush():
            inst.connected = False

        inst._flush_outbufs_below_high_watermark = dummy_flush
        self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff"))

    def test_write_soon_rotates_outbuf_on_overflow(self):
        inst, sock, map = self._makeOneWithMap()
        inst.adj.outbuf_high_watermark = 3
        inst.current_outbuf_count = 4
        wrote = inst.write_soon(b"xyz")
        self.assertEqual(wrote, 3)
        self.assertEqual(len(inst.outbufs), 2)
        self.assertEqual(inst.outbufs[0].get(), b"")
        self.assertEqual(inst.outbufs[1].get(), b"xyz")

    def test_write_soon_waits_on_backpressure(self):
        inst, sock, map = self._makeOneWithMap()
        inst.adj.outbuf_high_watermark = 3
        inst.total_outbufs_len = 4
        inst.current_outbuf_count = 4

        class Lock(DummyLock):
            # Simulate the flusher draining the outbufs while we wait.
            def wait(self):
                inst.total_outbufs_len = 0
                super(Lock, self).wait()

        inst.outbuf_lock = Lock()
        wrote = inst.write_soon(b"xyz")
        self.assertEqual(wrote, 3)
        self.assertEqual(len(inst.outbufs), 2)
        self.assertEqual(inst.outbufs[0].get(), b"")
        self.assertEqual(inst.outbufs[1].get(), b"xyz")
        self.assertTrue(inst.outbuf_lock.waited)

    def test_handle_write_notify_after_flush(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = [True]
        inst.outbufs = [DummyBuffer(b"abc")]
        inst.total_outbufs_len = len(inst.outbufs[0])
        inst.adj.send_bytes = 1
        inst.adj.outbuf_high_watermark = 5
        inst.will_close = False
        inst.last_activity = 0
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.will_close, False)
        self.assertTrue(inst.outbuf_lock.acquired)
        self.assertTrue(inst.outbuf_lock.notified)
        self.assertEqual(sock.sent, b"abc")

    def test_handle_write_no_notify_after_flush(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = [True]
        inst.outbufs = [DummyBuffer(b"abc")]
        inst.total_outbufs_len = len(inst.outbufs[0])
        inst.adj.send_bytes = 1
        inst.adj.outbuf_high_watermark = 2
        sock.send = lambda x: False
        inst.will_close = False
        inst.last_activity = 0
        result = inst.handle_write()
        self.assertEqual(result, None)
        self.assertEqual(inst.will_close, False)
        self.assertTrue(inst.outbuf_lock.acquired)
        self.assertFalse(inst.outbuf_lock.notified)
        self.assertEqual(sock.sent, b"")

    def test__flush_some_empty_outbuf(self):
        inst, sock, map = self._makeOneWithMap()
        result = inst._flush_some()
        self.assertEqual(result, False)

    def test__flush_some_full_outbuf_socket_returns_nonzero(self):
        inst, sock, map = self._makeOneWithMap()
        inst.outbufs[0].append(b"abc")
        inst.total_outbufs_len = sum(len(x) for x in inst.outbufs)
        result = inst._flush_some()
        self.assertEqual(result, True)

    def test__flush_some_full_outbuf_socket_returns_zero(self):
        inst, sock, map = self._makeOneWithMap()
        sock.send = lambda x: False
        inst.outbufs[0].append(b"abc")
        inst.total_outbufs_len = sum(len(x) for x in inst.outbufs)
        result = inst._flush_some()
        self.assertEqual(result, False)

    def test_flush_some_multiple_buffers_first_empty(self):
        inst, sock, map = self._makeOneWithMap()
        sock.send = lambda x: len(x)
        buffer = DummyBuffer(b"abc")
        inst.outbufs.append(buffer)
        inst.total_outbufs_len = sum(len(x) for x in inst.outbufs)
        result = inst._flush_some()
        self.assertEqual(result, True)
        self.assertEqual(buffer.skipped, 3)
        # The drained leading buffer is discarded from the list.
        self.assertEqual(inst.outbufs, [buffer])

    def test_flush_some_multiple_buffers_close_raises(self):
        inst, sock, map = self._makeOneWithMap()
        sock.send = lambda x: len(x)
        buffer = DummyBuffer(b"abc")
        inst.outbufs.append(buffer)
        inst.total_outbufs_len = sum(len(x) for x in inst.outbufs)
        inst.logger = DummyLogger()

        def doraise():
            raise NotImplementedError

        inst.outbufs[0].close = doraise
        result = inst._flush_some()
        self.assertEqual(result, True)
        self.assertEqual(buffer.skipped, 3)
        self.assertEqual(inst.outbufs, [buffer])
        self.assertEqual(len(inst.logger.exceptions), 1)

    def test__flush_some_outbuf_len_gt_sys_maxint(self):
        from waitress.compat import MAXINT

        inst, sock, map = self._makeOneWithMap()

        class DummyHugeOutbuffer(object):
            def __init__(self):
                self.length = MAXINT + 1

            def __len__(self):
                return self.length

            def get(self, numbytes):
                self.length = 0
                return b"123"

        buf = DummyHugeOutbuffer()
        inst.outbufs = [buf]
        inst.send = lambda *arg: 0
        result = inst._flush_some()
        # we are testing that _flush_some doesn't raise an OverflowError
        # when one of its outbufs has a __len__ that returns gt sys.maxint
        self.assertEqual(result, False)

    def test_handle_close(self):
        inst, sock, map = self._makeOneWithMap()
        inst.handle_close()
        self.assertEqual(inst.connected, False)
        self.assertEqual(sock.closed, True)

    def test_handle_close_outbuf_raises_on_close(self):
        inst, sock, map = self._makeOneWithMap()

        def doraise():
            raise NotImplementedError

        inst.outbufs[0].close = doraise
        inst.logger = DummyLogger()
        inst.handle_close()
        self.assertEqual(inst.connected, False)
        self.assertEqual(sock.closed, True)
        self.assertEqual(len(inst.logger.exceptions), 1)

    def test_add_channel(self):
        # add_channel registers in both the async map and the server's
        # active_channels registry.
        inst, sock, map = self._makeOneWithMap()
        fileno = inst._fileno
        inst.add_channel(map)
        self.assertEqual(map[fileno], inst)
        self.assertEqual(inst.server.active_channels[fileno], inst)

    def test_del_channel(self):
        inst, sock, map = self._makeOneWithMap()
        fileno = inst._fileno
        inst.server.active_channels[fileno] = True
        inst.del_channel(map)
        self.assertEqual(map.get(fileno), None)
        self.assertEqual(inst.server.active_channels.get(fileno), None)

    def test_received(self):
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        inst.received(b"GET / HTTP/1.1\r\n\r\n")
        self.assertEqual(inst.server.tasks, [inst])
        self.assertTrue(inst.requests)

    def test_received_no_chunk(self):
        inst, sock, map = self._makeOneWithMap()
        self.assertEqual(inst.received(b""), False)

    def test_received_preq_not_completed(self):
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        preq = DummyParser()
        inst.request = preq
        preq.completed = False
        preq.empty = True
        inst.received(b"GET / HTTP/1.1\r\n\r\n")
        self.assertEqual(inst.requests, ())
        self.assertEqual(inst.server.tasks, [])

    def test_received_preq_completed_empty(self):
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        preq = DummyParser()
        inst.request = preq
        preq.completed = True
        preq.empty = True
        inst.received(b"GET / HTTP/1.1\r\n\r\n")
        self.assertEqual(inst.request, None)
        self.assertEqual(inst.server.tasks, [])

    def test_received_preq_error(self):
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        preq = DummyParser()
        inst.request = preq
        preq.completed = True
        preq.error = True
        inst.received(b"GET / HTTP/1.1\r\n\r\n")
        self.assertEqual(inst.request, None)
        self.assertEqual(len(inst.server.tasks), 1)
        self.assertTrue(inst.requests)

    def test_received_preq_completed_connection_close(self):
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        preq = DummyParser()
        inst.request = preq
        preq.completed = True
        preq.empty = True
        preq.connection_close = True
        # Trailing garbage after the request must be discarded when the
        # connection is closing.
        inst.received(b"GET / HTTP/1.1\r\n\r\n" + b"a" * 50000)
        self.assertEqual(inst.request, None)
        self.assertEqual(inst.server.tasks, [])

    def test_received_headers_finished_expect_continue_false(self):
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        preq = DummyParser()
        inst.request = preq
        preq.expect_continue = False
        preq.headers_finished = True
        preq.completed = False
        preq.empty = False
        preq.retval = 1
        inst.received(b"GET / HTTP/1.1\r\n\r\n")
        self.assertEqual(inst.request, preq)
        self.assertEqual(inst.server.tasks, [])
        self.assertEqual(inst.outbufs[0].get(100), b"")

    def test_received_headers_finished_expect_continue_true(self):
        # An Expect: 100-continue request gets an interim response.
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        preq = DummyParser()
        inst.request = preq
        preq.expect_continue = True
        preq.headers_finished = True
        preq.completed = False
        preq.empty = False
        inst.received(b"GET / HTTP/1.1\r\n\r\n")
        self.assertEqual(inst.request, preq)
        self.assertEqual(inst.server.tasks, [])
        self.assertEqual(sock.sent, b"HTTP/1.1 100 Continue\r\n\r\n")
        self.assertEqual(inst.sent_continue, True)
        self.assertEqual(preq.completed, False)

    def test_received_headers_finished_expect_continue_true_sent_true(self):
        # The 100 Continue must only ever be sent once per request.
        inst, sock, map = self._makeOneWithMap()
        inst.server = DummyServer()
        preq = DummyParser()
        inst.request = preq
        preq.expect_continue = True
        preq.headers_finished = True
        preq.completed = False
        preq.empty = False
        inst.sent_continue = True
        inst.received(b"GET / HTTP/1.1\r\n\r\n")
        self.assertEqual(inst.request, preq)
        self.assertEqual(inst.server.tasks, [])
        self.assertEqual(sock.sent, b"")
        self.assertEqual(inst.sent_continue, True)
        self.assertEqual(preq.completed, False)

    def test_service_no_requests(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = []
        inst.service()
        self.assertEqual(inst.requests, [])
        self.assertTrue(inst.server.trigger_pulled)
        self.assertTrue(inst.last_activity)

    def test_service_with_one_request(self):
        inst, sock, map = self._makeOneWithMap()
        request = DummyRequest()
        inst.task_class = DummyTaskClass()
        inst.requests = [request]
        inst.service()
        self.assertEqual(inst.requests, [])
        self.assertTrue(request.serviced)
        self.assertTrue(request.closed)

    def test_service_with_one_error_request(self):
        inst, sock, map = self._makeOneWithMap()
        request = DummyRequest()
        request.error = DummyError()
        inst.error_task_class = DummyTaskClass()
        inst.requests = [request]
        inst.service()
        self.assertEqual(inst.requests, [])
        self.assertTrue(request.serviced)
        self.assertTrue(request.closed)

    def test_service_with_multiple_requests(self):
        inst, sock, map = self._makeOneWithMap()
        request1 = DummyRequest()
        request2 = DummyRequest()
        inst.task_class = DummyTaskClass()
        inst.requests = [request1, request2]
        inst.service()
        self.assertEqual(inst.requests, [])
        self.assertTrue(request1.serviced)
        self.assertTrue(request2.serviced)
        self.assertTrue(request1.closed)
        self.assertTrue(request2.closed)

    def test_service_with_request_raises(self):
        inst, sock, map = self._makeOneWithMap()
        inst.adj.expose_tracebacks = False
        inst.server = DummyServer()
        request = DummyRequest()
        inst.requests = [request]
        inst.task_class = DummyTaskClass(ValueError)
        inst.task_class.wrote_header = False
        inst.error_task_class = DummyTaskClass()
        inst.logger = DummyLogger()
        inst.service()
        self.assertTrue(request.serviced)
        self.assertEqual(inst.requests, [])
        self.assertEqual(len(inst.logger.exceptions), 1)
        self.assertTrue(inst.server.trigger_pulled)
        self.assertTrue(inst.last_activity)
        self.assertFalse(inst.will_close)
        # Header not yet written, so an error response is still possible.
        self.assertEqual(inst.error_task_class.serviced, True)
        self.assertTrue(request.closed)

    def test_service_with_requests_raises_already_wrote_header(self):
        inst, sock, map = self._makeOneWithMap()
        inst.adj.expose_tracebacks = False
        inst.server = DummyServer()
        request = DummyRequest()
        inst.requests = [request]
        inst.task_class = DummyTaskClass(ValueError)
        inst.error_task_class = DummyTaskClass()
        inst.logger = DummyLogger()
        inst.service()
        self.assertTrue(request.serviced)
        self.assertEqual(inst.requests, [])
        self.assertEqual(len(inst.logger.exceptions), 1)
        self.assertTrue(inst.server.trigger_pulled)
        self.assertTrue(inst.last_activity)
        # Header already sent: no error task, connection is torn down.
        self.assertTrue(inst.close_when_flushed)
        self.assertEqual(inst.error_task_class.serviced, False)
        self.assertTrue(request.closed)

    def test_service_with_requests_raises_didnt_write_header_expose_tbs(self):
        inst, sock, map = self._makeOneWithMap()
        inst.adj.expose_tracebacks = True
        inst.server = DummyServer()
        request = DummyRequest()
        inst.requests = [request]
        inst.task_class = DummyTaskClass(ValueError)
        inst.task_class.wrote_header = False
        inst.error_task_class = DummyTaskClass()
        inst.logger = DummyLogger()
        inst.service()
        self.assertTrue(request.serviced)
        self.assertFalse(inst.will_close)
        self.assertEqual(inst.requests, [])
        self.assertEqual(len(inst.logger.exceptions), 1)
        self.assertTrue(inst.server.trigger_pulled)
        self.assertTrue(inst.last_activity)
        self.assertEqual(inst.error_task_class.serviced, True)
        self.assertTrue(request.closed)

    def test_service_with_requests_raises_didnt_write_header(self):
        inst, sock, map = self._makeOneWithMap()
        inst.adj.expose_tracebacks = False
        inst.server = DummyServer()
        request = DummyRequest()
        inst.requests = [request]
        inst.task_class = DummyTaskClass(ValueError)
        inst.task_class.wrote_header = False
        inst.logger = DummyLogger()
        inst.service()
        self.assertTrue(request.serviced)
        self.assertEqual(inst.requests, [])
        self.assertEqual(len(inst.logger.exceptions), 1)
        self.assertTrue(inst.server.trigger_pulled)
        self.assertTrue(inst.last_activity)
        self.assertTrue(inst.close_when_flushed)
        self.assertTrue(request.closed)

    def test_service_with_request_raises_disconnect(self):
        from waitress.channel import ClientDisconnected

        inst, sock, map = self._makeOneWithMap()
        inst.adj.expose_tracebacks = False
        inst.server = DummyServer()
        request = DummyRequest()
        inst.requests = [request]
        inst.task_class = DummyTaskClass(ClientDisconnected)
        inst.error_task_class = DummyTaskClass()
        inst.logger = DummyLogger()
        inst.service()
        self.assertTrue(request.serviced)
        self.assertEqual(inst.requests, [])
        # A client disconnect is informational, not an exception-level log.
        self.assertEqual(len(inst.logger.infos), 1)
        self.assertTrue(inst.server.trigger_pulled)
        self.assertTrue(inst.last_activity)
        self.assertFalse(inst.will_close)
        self.assertEqual(inst.error_task_class.serviced, False)
        self.assertTrue(request.closed)

    def test_service_with_request_error_raises_disconnect(self):
        from waitress.channel import ClientDisconnected

        inst, sock, map = self._makeOneWithMap()
        inst.adj.expose_tracebacks = False
        inst.server = DummyServer()
        request = DummyRequest()
        err_request = DummyRequest()
        inst.requests = [request]
        inst.parser_class = lambda x: err_request
        inst.task_class = DummyTaskClass(RuntimeError)
        inst.task_class.wrote_header = False
        inst.error_task_class = DummyTaskClass(ClientDisconnected)
        inst.logger = DummyLogger()
        inst.service()
        self.assertTrue(request.serviced)
        self.assertTrue(err_request.serviced)
        self.assertEqual(inst.requests, [])
        self.assertEqual(len(inst.logger.exceptions), 1)
        self.assertEqual(len(inst.logger.infos), 0)
        self.assertTrue(inst.server.trigger_pulled)
        self.assertTrue(inst.last_activity)
        self.assertFalse(inst.will_close)
        self.assertEqual(inst.task_class.serviced, True)
        self.assertEqual(inst.error_task_class.serviced, True)
        self.assertTrue(request.closed)

    def test_cancel_no_requests(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = ()
        inst.cancel()
        self.assertEqual(inst.requests, [])

    def test_cancel_with_requests(self):
        inst, sock, map = self._makeOneWithMap()
        inst.requests = [None]
        inst.cancel()
        self.assertEqual(inst.requests, [])
+
+
class DummySock(object):
    """Socket test double that records everything sent through it."""

    blocking = False
    closed = False

    def __init__(self):
        self.sent = b""

    def setblocking(self, *arg):
        self.blocking = True

    def fileno(self):
        # Fixed fileno so the channel lands at a predictable map key.
        return 100

    def getpeername(self):
        return "127.0.0.1"

    def getsockopt(self, level, option):
        # Report a canned 2048-byte value for any socket option queried.
        return 2048

    def send(self, data):
        # Accumulate outgoing bytes and report everything as written.
        self.sent = self.sent + data
        return len(data)

    def close(self):
        self.closed = True
+
+
class DummyLock(object):
    """Condition-variable test double that records each operation."""

    notified = False

    def __init__(self, acquirable=True):
        self.acquirable = acquirable

    def __enter__(self):
        pass

    def __exit__(self, type, val, traceback):
        # Mirror the original: leaving the context records an acquire(True).
        self.acquire(True)

    def acquire(self, val):
        self.val = val
        self.acquired = True
        return self.acquirable

    def release(self):
        self.released = True

    def notify(self):
        self.notified = True

    def wait(self):
        self.waited = True
+
+
class DummyBuffer(object):
    """Outbuf test double: hands out its data once, optionally raising."""

    closed = False

    def __init__(self, data, toraise=None):
        self.data = data
        self.toraise = toraise

    def get(self, *arg):
        # Raise the scripted exception if configured, otherwise drain.
        if self.toraise:
            raise self.toraise
        data, self.data = self.data, b""
        return data

    def skip(self, num, x):
        # Record how many bytes the caller asked to skip.
        self.skipped = num

    def __len__(self):
        return len(self.data)

    def close(self):
        self.closed = True
+
+
class DummyAdjustments(object):
    """Canned waitress adjustment values used by the channel tests."""

    # Identification
    url_scheme = "http"
    ident = "waitress"
    # Buffer thresholds
    outbuf_overflow = 1048576
    outbuf_high_watermark = 1048576
    inbuf_overflow = 512000
    # Timing
    cleanup_interval = 900
    channel_timeout = 300
    # I/O sizes
    recv_bytes = 8192
    send_bytes = 1
    max_request_header_size = 10000
    # Behavior flags
    log_socket_errors = True
    expose_tracebacks = True
+
+
class DummyServer(object):
    """Server test double collecting queued tasks and trigger pulls."""

    trigger_pulled = False
    # Shared adjustments instance, matching the real server interface.
    adj = DummyAdjustments()

    def __init__(self):
        self.tasks = []
        self.active_channels = {}

    def add_task(self, task):
        # Record instead of dispatching to a thread pool.
        self.tasks.append(task)

    def pull_trigger(self):
        self.trigger_pulled = True
+
+
class DummyParser(object):
    """HTTP parser test double driven entirely by class attributes."""

    version = 1
    data = None
    completed = True
    empty = False
    headers_finished = False
    expect_continue = False
    retval = None
    error = None
    connection_close = False

    def received(self, data):
        # Record the payload; honor a scripted return value when present,
        # otherwise report the whole chunk as consumed.
        self.data = data
        return len(data) if self.retval is None else self.retval
+
+
class DummyRequest(object):
    """Parsed-request test double with a recordable close()."""

    error = None
    path = "/"
    version = "1.0"
    closed = False

    def __init__(self):
        # Fresh mutable headers mapping per instance.
        self.headers = {}

    def close(self):
        self.closed = True
+
+
class DummyLogger(object):
    """Logger test double that records messages per severity level.

    ``infos``, ``warnings`` and ``exceptions`` collect the messages so
    tests can assert on exactly what was logged.
    """

    def __init__(self):
        self.exceptions = []
        self.infos = []
        self.warnings = []

    def info(self, msg):
        self.infos.append(msg)

    def warning(self, msg):
        # Bug fix: ``warnings`` was initialized but no ``warning`` method
        # existed, so any logger.warning(...) call from code under test
        # raised AttributeError instead of being recorded.
        self.warnings.append(msg)

    def exception(self, msg):
        self.exceptions.append(msg)
+
+
class DummyError(object):
    """Parser-error test double mirroring waitress error objects."""

    # Note: code is a string, matching how the channel formats it.
    code = "431"
    reason = "Bleh"
    body = "My body"
+
+
class DummyTaskClass(object):
    """Task factory/instance test double.

    Calling the instance with ``(channel, request)`` returns the instance
    itself, so the same object plays both the task class and the task.
    """

    wrote_header = True
    close_on_finish = False
    serviced = False

    def __init__(self, toraise=None):
        self.toraise = toraise

    def __call__(self, channel, request):
        self.request = request
        return self

    def service(self):
        # Mark both the task and its request as serviced, then raise the
        # scripted exception if one was configured.
        self.serviced = True
        self.request.serviced = True
        if self.toraise:
            raise self.toraise
diff --git a/libs/waitress/tests/test_compat.py b/libs/waitress/tests/test_compat.py
new file mode 100644
index 000000000..37c219303
--- /dev/null
+++ b/libs/waitress/tests/test_compat.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+
+import unittest
+
+
class Test_unquote_bytes_to_wsgi(unittest.TestCase):
    def _callFUT(self, v):
        from waitress.compat import unquote_bytes_to_wsgi

        return unquote_bytes_to_wsgi(v)

    def test_highorder(self):
        # Percent-encoded high-order bytes must round-trip per PEP 3333.
        from waitress.compat import PY3

        encoded = b"/a%C5%9B"
        decoded = self._callFUT(encoded)
        if PY3:  # pragma: no cover
            # PEP 3333 urlunquoted-latin1-decoded-bytes
            self.assertEqual(decoded, "/aÅ\x9b")
        else:  # pragma: no cover
            # sanity
            self.assertEqual(decoded, b"/a\xc5\x9b")
diff --git a/libs/waitress/tests/test_functional.py b/libs/waitress/tests/test_functional.py
new file mode 100644
index 000000000..8f4b262fe
--- /dev/null
+++ b/libs/waitress/tests/test_functional.py
@@ -0,0 +1,1667 @@
+import errno
+import logging
+import multiprocessing
+import os
+import signal
+import socket
+import string
+import subprocess
+import sys
+import time
+import unittest
+from waitress import server
+from waitress.compat import httplib, tobytes
+from waitress.utilities import cleanup_unix_socket
+
# Shorthand used below to locate the fixtureapps directory next to this module.
dn = os.path.dirname
here = dn(__file__)
+
+
class NullHandler(logging.Handler):  # pragma: no cover
    """A logging handler that silently discards every record it receives."""

    def emit(self, record):
        # Intentionally a no-op: the tests only need logging silenced.
        pass
+
+
def start_server(app, svr, queue, **kwargs):  # pragma: no cover
    """Run a fixture application in the spawned child process.

    Silences waitress logging, hooks coverage shutdown via a SIGTERM
    handler, then constructs the server ``svr`` around ``app`` and runs it.
    """
    logging.getLogger("waitress").addHandler(NullHandler())
    try_register_coverage()
    instance = svr(app, queue, **kwargs)
    instance.run()
+
+
def try_register_coverage():  # pragma: no cover
    """Give coverage a chance to flush data in multiprocessing children.

    multiprocessing workers can exit without unwinding normally, skipping
    coverage's atexit handler; when coverage is active we register a
    SIGTERM handler that exits cleanly so the handler fires.
    """
    if "COVERAGE_PROCESS_START" not in os.environ:
        return

    def sigterm(*args):
        sys.exit(0)

    signal.signal(signal.SIGTERM, sigterm)
+
+
class FixtureTcpWSGIServer(server.TcpWSGIServer):
    """A version of TcpWSGIServer that relays back what it's bound to.
    """

    family = socket.AF_INET  # Testing

    def __init__(self, application, queue, **kw):  # pragma: no cover
        # Coverage doesn't see this as it's ran in a separate process.
        kw["port"] = 0  # Bind to any available port.
        super(FixtureTcpWSGIServer, self).__init__(application, **kw)
        host, port = self.socket.getsockname()
        # NOTE(review): on Windows the wildcard-bound address is presumably
        # not connectable, so loopback is substituted — confirm.
        if os.name == "nt":
            host = "127.0.0.1"
        # Report the actual (host, port) back to the parent test process.
        queue.put((host, port))
+
+
class SubprocessTests(object):
    """Mixin that runs a fixture app in a child process and connects to it."""

    # For nose: all tests may be ran in separate processes.
    _multiprocess_can_split_ = True

    # Interpreter used to launch helper scripts (e.g. getline.py).
    exe = sys.executable

    # Server factory; concrete subclasses assign e.g. FixtureTcpWSGIServer.
    server = None

    def start_subprocess(self, target, **kw):
        # Spawn a server process.
        self.queue = multiprocessing.Queue()

        if "COVERAGE_RCFILE" in os.environ:
            os.environ["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"]

        self.proc = multiprocessing.Process(
            target=start_server, args=(target, self.server, self.queue), kwargs=kw,
        )
        self.proc.start()

        if self.proc.exitcode is not None:  # pragma: no cover
            raise RuntimeError("%s didn't start" % str(target))
        # Get the socket the server is listening on.
        self.bound_to = self.queue.get(timeout=5)
        self.sock = self.create_socket()

    def stop_subprocess(self):
        if self.proc.exitcode is None:
            self.proc.terminate()
        self.sock.close()
        # This gives us one FD back ...
        self.queue.close()
        self.proc.join()

    def assertline(self, line, status, reason, version):
        # Split an "HTTP/1.x 200 OK" status line and compare each part.
        v, s, r = (x.strip() for x in line.split(None, 2))
        self.assertEqual(s, tobytes(status))
        self.assertEqual(r, tobytes(reason))
        self.assertEqual(v, tobytes(version))

    def create_socket(self):
        return socket.socket(self.server.family, socket.SOCK_STREAM)

    def connect(self):
        self.sock.connect(self.bound_to)

    def make_http_connection(self):
        raise NotImplementedError  # pragma: no cover

    def send_check_error(self, to_send):
        # Used after an expected connection close; the send itself may still
        # succeed before the peer's close is observed.
        self.sock.send(to_send)
+
+
class TcpTests(SubprocessTests):
    """SubprocessTests flavour that talks to the fixture server over TCP."""

    server = FixtureTcpWSGIServer

    def make_http_connection(self):
        host, port = self.bound_to
        return httplib.HTTPConnection(host, port)
+
+
class SleepyThreadTests(TcpTests, unittest.TestCase):
    # Test that a slow ("sleepy") request doesn't block other requests.

    def setUp(self):
        from waitress.tests.fixtureapps import sleepy

        self.start_subprocess(sleepy.app)

    def tearDown(self):
        self.stop_subprocess()

    def test_it(self):
        getline = os.path.join(here, "fixtureapps", "getline.py")
        cmds = (
            [self.exe, getline, "http://%s:%d/sleepy" % self.bound_to],
            [self.exe, getline, "http://%s:%d/" % self.bound_to],
        )
        r, w = os.pipe()
        procs = []
        for cmd in cmds:
            procs.append(subprocess.Popen(cmd, stdout=w))
        time.sleep(3)
        for proc in procs:
            if proc.returncode is not None:  # pragma: no cover
                proc.terminate()
            proc.wait()
        # the sleepy request sleeps for 2 seconds before responding; the
        # notsleepy request should be processed in the meantime, so its
        # output is always written to the pipe first
        result = os.read(r, 10000)
        os.close(r)
        os.close(w)
        self.assertEqual(result, b"notsleepy returnedsleepy returned")
+
+
class EchoTests(object):
    """Functional tests against the echo fixture app.

    The subprocess server is configured to trust one proxy hop for the
    x-forwarded-for / x-forwarded-proto headers and to clear untrusted
    proxy headers.
    """

    def setUp(self):
        from waitress.tests.fixtureapps import echo

        self.start_subprocess(
            echo.app,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-for", "x-forwarded-proto"},
            clear_untrusted_proxy_headers=True,
        )

    def tearDown(self):
        self.stop_subprocess()

    def _read_echo(self, fp):
        # Read one HTTP response and decode the echoed request details.
        from waitress.tests.fixtureapps import echo

        line, headers, body = read_http(fp)
        return line, headers, echo.parse_response(body)

    def test_date_and_server(self):
        to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n"
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(headers.get("server"), "waitress")
        self.assertTrue(headers.get("date"))

    def test_bad_host_header(self):
        # https://corte.si/posts/code/pathod/pythonservers/index.html
        # Leading whitespace before "Host:" makes the request malformed.
        to_send = "GET / HTTP/1.0\r\n Host: 0\r\n\r\n"
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "400", "Bad Request", "HTTP/1.0")
        self.assertEqual(headers.get("server"), "waitress")
        self.assertTrue(headers.get("date"))

    def test_send_with_body(self):
        to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n"
        to_send += "hello"
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(echo.content_length, "5")
        self.assertEqual(echo.body, b"hello")

    def test_send_empty_body(self):
        to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n"
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(echo.content_length, "0")
        self.assertEqual(echo.body, b"")

    def test_multiple_requests_with_body(self):
        # Re-run the body test on fresh connections to check reusability.
        orig_sock = self.sock
        for x in range(3):
            self.sock = self.create_socket()
            self.test_send_with_body()
            self.sock.close()
        self.sock = orig_sock

    def test_multiple_requests_without_body(self):
        orig_sock = self.sock
        for x in range(3):
            self.sock = self.create_socket()
            self.test_send_empty_body()
            self.sock.close()
        self.sock = orig_sock

    def test_without_crlf(self):
        data = "Echo\r\nthis\r\nplease"
        s = tobytes(
            "GET / HTTP/1.0\r\n"
            "Connection: close\r\n"
            "Content-Length: %d\r\n"
            "\r\n"
            "%s" % (len(data), data)
        )
        self.connect()
        self.sock.send(s)
        fp = self.sock.makefile("rb", 0)
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(int(echo.content_length), len(data))
        self.assertEqual(len(echo.body), len(data))
        self.assertEqual(echo.body, tobytes(data))

    def test_large_body(self):
        # 1024 characters.
        body = "This string has 32 characters.\r\n" * 32
        s = tobytes(
            "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(body), body)
        )
        self.connect()
        self.sock.send(s)
        fp = self.sock.makefile("rb", 0)
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(echo.content_length, "1024")
        self.assertEqual(echo.body, tobytes(body))

    def test_many_clients(self):
        # Open 50 concurrent connections before reading any responses.
        conns = []
        for n in range(50):
            h = self.make_http_connection()
            h.request("GET", "/", headers={"Accept": "text/plain"})
            conns.append(h)
        responses = []
        for h in conns:
            response = h.getresponse()
            self.assertEqual(response.status, 200)
            responses.append(response)
        for response in responses:
            response.read()
        for h in conns:
            h.close()

    def test_chunking_request_without_content(self):
        header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n")
        self.connect()
        self.sock.send(header)
        self.sock.send(b"0\r\n\r\n")
        fp = self.sock.makefile("rb", 0)
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.1")
        self.assertEqual(echo.body, b"")
        self.assertEqual(echo.content_length, "0")
        self.assertFalse("transfer-encoding" in headers)

    def test_chunking_request_with_content(self):
        control_line = b"20;\r\n"  # 20 hex = 32 dec
        s = b"This string has 32 characters.\r\n"
        expected = s * 12
        header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n")
        self.connect()
        self.sock.send(header)
        fp = self.sock.makefile("rb", 0)
        for n in range(12):
            self.sock.send(control_line)
            self.sock.send(s)
            self.sock.send(b"\r\n")  # End the chunk
        self.sock.send(b"0\r\n\r\n")
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.1")
        self.assertEqual(echo.body, expected)
        self.assertEqual(echo.content_length, str(len(expected)))
        self.assertFalse("transfer-encoding" in headers)

    def test_broken_chunked_encoding(self):
        control_line = "20;\r\n"  # 20 hex = 32 dec
        s = "This string has 32 characters.\r\n"
        to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"
        to_send += control_line + s + "\r\n"
        # garbage in input
        to_send += "garbage\r\n"
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        # receiver caught garbage and turned it into a 400
        self.assertline(line, "400", "Bad Request", "HTTP/1.1")
        cl = int(headers["content-length"])
        self.assertEqual(cl, len(response_body))
        self.assertEqual(
            sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"]
        )
        self.assertEqual(headers["content-type"], "text/plain")
        # connection has been closed
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)

    def test_broken_chunked_encoding_missing_chunk_end(self):
        control_line = "20;\r\n"  # 20 hex = 32 dec
        s = "This string has 32 characters.\r\n"
        to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"
        to_send += control_line + s
        # garbage in input
        to_send += "garbage"
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        # receiver caught garbage and turned it into a 400
        self.assertline(line, "400", "Bad Request", "HTTP/1.1")
        cl = int(headers["content-length"])
        self.assertEqual(cl, len(response_body))
        self.assertTrue(b"Chunk not properly terminated" in response_body)
        self.assertEqual(
            sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"]
        )
        self.assertEqual(headers["content-type"], "text/plain")
        # connection has been closed
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)

    def test_keepalive_http_10(self):
        # Handling of Keep-Alive within HTTP 1.0
        data = "Default: Don't keep me alive"
        s = tobytes(
            "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data)
        )
        self.connect()
        self.sock.send(s)
        response = httplib.HTTPResponse(self.sock)
        response.begin()
        self.assertEqual(int(response.status), 200)
        connection = response.getheader("Connection", "")
        # We sent no Connection: Keep-Alive header
        # Connection: close (or no header) is default.
        self.assertTrue(connection != "Keep-Alive")

    def test_keepalive_http10_explicit(self):
        # If header Connection: Keep-Alive is explicitly sent,
        # we want to keep the connection open, we also need to return
        # the corresponding header
        data = "Keep me alive"
        s = tobytes(
            "GET / HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: %d\r\n"
            "\r\n"
            "%s" % (len(data), data)
        )
        self.connect()
        self.sock.send(s)
        response = httplib.HTTPResponse(self.sock)
        response.begin()
        self.assertEqual(int(response.status), 200)
        connection = response.getheader("Connection", "")
        self.assertEqual(connection, "Keep-Alive")

    def test_keepalive_http_11(self):
        # Handling of Keep-Alive within HTTP 1.1

        # All connections are kept alive, unless stated otherwise
        data = "Default: Keep me alive"
        s = tobytes(
            "GET / HTTP/1.1\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data)
        )
        self.connect()
        self.sock.send(s)
        response = httplib.HTTPResponse(self.sock)
        response.begin()
        self.assertEqual(int(response.status), 200)
        self.assertTrue(response.getheader("connection") != "close")

    def test_keepalive_http11_explicit(self):
        # Explicitly set keep-alive
        data = "Default: Keep me alive"
        s = tobytes(
            "GET / HTTP/1.1\r\n"
            "Connection: keep-alive\r\n"
            "Content-Length: %d\r\n"
            "\r\n"
            "%s" % (len(data), data)
        )
        self.connect()
        self.sock.send(s)
        response = httplib.HTTPResponse(self.sock)
        response.begin()
        self.assertEqual(int(response.status), 200)
        self.assertTrue(response.getheader("connection") != "close")

    def test_keepalive_http11_connclose(self):
        # specifying Connection: close explicitly
        data = "Don't keep me alive"
        s = tobytes(
            "GET / HTTP/1.1\r\n"
            "Connection: close\r\n"
            "Content-Length: %d\r\n"
            "\r\n"
            "%s" % (len(data), data)
        )
        self.connect()
        self.sock.send(s)
        response = httplib.HTTPResponse(self.sock)
        response.begin()
        self.assertEqual(int(response.status), 200)
        self.assertEqual(response.getheader("connection"), "close")

    def test_proxy_headers(self):
        # Trusted forwarded-for/proto headers are honored; the untrusted
        # X-Forwarded-Port header is cleared before reaching the app.
        to_send = (
            "GET / HTTP/1.0\r\n"
            "Content-Length: 0\r\n"
            "Host: www.google.com:8080\r\n"
            "X-Forwarded-For: 192.168.1.1\r\n"
            "X-Forwarded-Proto: https\r\n"
            "X-Forwarded-Port: 5000\r\n\r\n"
        )
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, echo = self._read_echo(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(headers.get("server"), "waitress")
        self.assertTrue(headers.get("date"))
        self.assertIsNone(echo.headers.get("X_FORWARDED_PORT"))
        self.assertEqual(echo.headers["HOST"], "www.google.com:8080")
        self.assertEqual(echo.scheme, "https")
        self.assertEqual(echo.remote_addr, "192.168.1.1")
        self.assertEqual(echo.remote_host, "192.168.1.1")
+
+
class PipeliningTests(object):
    """Check that pipelined requests each receive their own response in order."""

    def setUp(self):
        from waitress.tests.fixtureapps import echo

        self.start_subprocess(echo.app_body_only)

    def tearDown(self):
        self.stop_subprocess()

    def test_pipelining(self):
        s = (
            "GET / HTTP/1.0\r\n"
            "Connection: %s\r\n"
            "Content-Length: %d\r\n"
            "\r\n"
            "%s"
        )
        to_send = b""
        count = 25
        # Send 25 requests in one write; all keep-alive except the last.
        for n in range(count):
            body = "Response #%d\r\n" % (n + 1)
            if n + 1 < count:
                conn = "keep-alive"
            else:
                conn = "close"
            to_send += tobytes(s % (conn, len(body), body))

        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        for n in range(count):
            expect_body = tobytes("Response #%d\r\n" % (n + 1))
            line = fp.readline()  # status line
            version, status, reason = (x.strip() for x in line.split(None, 2))
            headers = parse_headers(fp)
            length = int(headers.get("content-length")) or None
            response_body = fp.read(length)
            self.assertEqual(int(status), 200)
            self.assertEqual(length, len(response_body))
            self.assertEqual(response_body, expect_body)
+
+
class ExpectContinueTests(object):
    """Verify the interim 100 Continue response for Expect: 100-continue."""

    def setUp(self):
        from waitress.tests.fixtureapps import echo

        self.start_subprocess(echo.app_body_only)

    def tearDown(self):
        self.stop_subprocess()

    def test_expect_continue(self):
        # specifying Connection: close explicitly
        data = "I have expectations"
        to_send = tobytes(
            "GET / HTTP/1.1\r\n"
            "Connection: close\r\n"
            "Content-Length: %d\r\n"
            "Expect: 100-continue\r\n"
            "\r\n"
            "%s" % (len(data), data)
        )
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line = fp.readline()  # continue status line
        version, status, reason = (x.strip() for x in line.split(None, 2))
        self.assertEqual(int(status), 100)
        self.assertEqual(reason, b"Continue")
        self.assertEqual(version, b"HTTP/1.1")
        fp.readline()  # blank line
        line = fp.readline()  # next status line
        version, status, reason = (x.strip() for x in line.split(None, 2))
        headers = parse_headers(fp)
        length = int(headers.get("content-length")) or None
        response_body = fp.read(length)
        self.assertEqual(int(status), 200)
        self.assertEqual(length, len(response_body))
        self.assertEqual(response_body, tobytes(data))
+
+
class BadContentLengthTests(object):
    """Behavior when the app's response body disagrees with its CL header."""

    def setUp(self):
        from waitress.tests.fixtureapps import badcl

        self.start_subprocess(badcl.app)

    def tearDown(self):
        self.stop_subprocess()

    def test_short_body(self):
        # check to see if server closes connection when body is too short
        # for cl header
        to_send = tobytes(
            "GET /short_body HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: 0\r\n"
            "\r\n"
        )
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line = fp.readline()  # status line
        version, status, reason = (x.strip() for x in line.split(None, 2))
        headers = parse_headers(fp)
        content_length = int(headers.get("content-length"))
        response_body = fp.read(content_length)
        self.assertEqual(int(status), 200)
        self.assertNotEqual(content_length, len(response_body))
        self.assertEqual(len(response_body), content_length - 1)
        self.assertEqual(response_body, tobytes("abcdefghi"))
        # remote closed connection (despite keepalive header); not sure why
        # first send succeeds
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)

    def test_long_body(self):
        # check server doesn't close connection when the app writes more
        # body than its cl header advertises (extra bytes are dropped)
        to_send = tobytes(
            "GET /long_body HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: 0\r\n"
            "\r\n"
        )
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line = fp.readline()  # status line
        version, status, reason = (x.strip() for x in line.split(None, 2))
        headers = parse_headers(fp)
        content_length = int(headers.get("content-length")) or None
        response_body = fp.read(content_length)
        self.assertEqual(int(status), 200)
        self.assertEqual(content_length, len(response_body))
        self.assertEqual(response_body, tobytes("abcdefgh"))
        # remote does not close connection (keepalive header)
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line = fp.readline()  # status line
        version, status, reason = (x.strip() for x in line.split(None, 2))
        headers = parse_headers(fp)
        content_length = int(headers.get("content-length")) or None
        response_body = fp.read(content_length)
        self.assertEqual(int(status), 200)
+
+
class NoContentLengthTests(object):
    """Responses without an app-supplied Content-Length header.

    Covers how the server divines (or cannot divine) a length from
    generators and lists, per HTTP version.
    """

    def setUp(self):
        from waitress.tests.fixtureapps import nocl

        self.start_subprocess(nocl.app)

    def tearDown(self):
        self.stop_subprocess()

    def test_http10_generator(self):
        body = string.ascii_letters
        to_send = (
            "GET / HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: %d\r\n\r\n" % len(body)
        )
        to_send += body
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(headers.get("content-length"), None)
        self.assertEqual(headers.get("connection"), "close")
        self.assertEqual(response_body, tobytes(body))
        # remote closed connection (despite keepalive header), because
        # generators cannot have a content-length divined
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)

    def test_http10_list(self):
        body = string.ascii_letters
        to_send = (
            "GET /list HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: %d\r\n\r\n" % len(body)
        )
        to_send += body
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(headers["content-length"], str(len(body)))
        self.assertEqual(headers.get("connection"), "Keep-Alive")
        self.assertEqual(response_body, tobytes(body))
        # remote keeps connection open because it divined the content length
        # from a length-1 list
        self.sock.send(to_send)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")

    def test_http10_listlentwo(self):
        body = string.ascii_letters
        to_send = (
            "GET /list_lentwo HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: %d\r\n\r\n" % len(body)
        )
        to_send += body
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(headers.get("content-length"), None)
        self.assertEqual(headers.get("connection"), "close")
        self.assertEqual(response_body, tobytes(body))
        # remote closed connection (despite keepalive header), because
        # lists of length > 1 cannot have their content length divined
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)

    def test_http11_generator(self):
        body = string.ascii_letters
        to_send = "GET / HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body)
        to_send += body
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb")
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.1")
        # HTTP/1.1 with unknown length => chunked transfer encoding.
        expected = b""
        for chunk in chunks(body, 10):
            expected += tobytes(
                "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk)
            )
        expected += b"0\r\n\r\n"
        self.assertEqual(response_body, expected)
        # connection is always closed at the end of a chunked response
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)

    def test_http11_list(self):
        body = string.ascii_letters
        to_send = "GET /list HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body)
        to_send += body
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.1")
        self.assertEqual(headers["content-length"], str(len(body)))
        self.assertEqual(response_body, tobytes(body))
        # remote keeps connection open because it divined the content length
        # from a length-1 list
        self.sock.send(to_send)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.1")

    def test_http11_listlentwo(self):
        body = string.ascii_letters
        to_send = "GET /list_lentwo HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body)
        to_send += body
        to_send = tobytes(to_send)
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb")
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.1")
        expected = b""
        for chunk in (body[0], body[1:]):
            expected += tobytes(
                "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk)
            )
        expected += b"0\r\n\r\n"
        self.assertEqual(response_body, expected)
        # connection is always closed at the end of a chunked response
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)
+
+
class WriteCallbackTests(object):
    """Same CL-mismatch scenarios as above, but via the WSGI write callback."""

    def setUp(self):
        from waitress.tests.fixtureapps import writecb

        self.start_subprocess(writecb.app)

    def tearDown(self):
        self.stop_subprocess()

    def test_short_body(self):
        # check to see if server closes connection when body is too short
        # for cl header
        to_send = tobytes(
            "GET /short_body HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: 0\r\n"
            "\r\n"
        )
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        # server trusts the content-length header (9), but the app only
        # writes 8 bytes
        self.assertline(line, "200", "OK", "HTTP/1.0")
        cl = int(headers["content-length"])
        self.assertEqual(cl, 9)
        self.assertNotEqual(cl, len(response_body))
        self.assertEqual(len(response_body), cl - 1)
        self.assertEqual(response_body, tobytes("abcdefgh"))
        # remote closed connection (despite keepalive header)
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)

    def test_long_body(self):
        # check server doesnt close connection when body is too long
        # for cl header
        to_send = tobytes(
            "GET /long_body HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: 0\r\n"
            "\r\n"
        )
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        content_length = int(headers.get("content-length")) or None
        self.assertEqual(content_length, 9)
        self.assertEqual(content_length, len(response_body))
        self.assertEqual(response_body, tobytes("abcdefghi"))
        # remote does not close connection (keepalive header)
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")

    def test_equal_body(self):
        # check server doesnt close connection when body is equal to
        # cl header
        to_send = tobytes(
            "GET /equal_body HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: 0\r\n"
            "\r\n"
        )
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        content_length = int(headers.get("content-length")) or None
        self.assertEqual(content_length, 9)
        self.assertline(line, "200", "OK", "HTTP/1.0")
        self.assertEqual(content_length, len(response_body))
        self.assertEqual(response_body, tobytes("abcdefghi"))
        # remote does not close connection (keepalive header)
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line, headers, response_body = read_http(fp)
        self.assertline(line, "200", "OK", "HTTP/1.0")

    def test_no_content_length(self):
        # wtf happens when there's no content-length
        to_send = tobytes(
            "GET /no_content_length HTTP/1.0\r\n"
            "Connection: Keep-Alive\r\n"
            "Content-Length: 0\r\n"
            "\r\n"
        )
        self.connect()
        self.sock.send(to_send)
        fp = self.sock.makefile("rb", 0)
        line = fp.readline()  # status line
        line, headers, response_body = read_http(fp)
        content_length = headers.get("content-length")
        self.assertEqual(content_length, None)
        self.assertEqual(response_body, tobytes("abcdefghi"))
        # remote closed connection (despite keepalive header)
        self.send_check_error(to_send)
        self.assertRaises(ConnectionClosed, read_http, fp)
+
+
+class TooLargeTests(object):
+
+ toobig = 1050
+
+ def setUp(self):
+ from waitress.tests.fixtureapps import toolarge
+
+ self.start_subprocess(
+ toolarge.app, max_request_header_size=1000, max_request_body_size=1000
+ )
+
+ def tearDown(self):
+ self.stop_subprocess()
+
+ def test_request_body_too_large_with_wrong_cl_http10(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb")
+ # first request succeeds (content-length 5)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # server trusts the content-length header; no pipelining,
+ # so request fulfilled, extra bytes are thrown away
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_with_wrong_cl_http10_keepalive(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb")
+ # first request succeeds (content-length 5)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_with_no_cl_http10(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.0\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # extra bytes are thrown away (no pipelining), connection closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_with_no_cl_http10_keepalive(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.0\r\nConnection: Keep-Alive\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ # server trusts the content-length header (assumed zero)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ line, headers, response_body = read_http(fp)
+ # next response overruns because the extra data appears to be
+ # header data
+ self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_with_wrong_cl_http11(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb")
+ # first request succeeds (content-length 5)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # second response is an error response
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_with_wrong_cl_http11_connclose(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\nConnection: close\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ # server trusts the content-length header (5)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_with_no_cl_http11(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.1\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb")
+ # server trusts the content-length header (assumed 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # server assumes pipelined requests due to http/1.1, and the first
+ # request was assumed c-l 0 because it had no content-length header,
+ # so entire body looks like the header of the subsequent request
+ # second response is an error response
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_with_no_cl_http11_connclose(self):
+ body = "a" * self.toobig
+ to_send = "GET / HTTP/1.1\r\nConnection: close\r\n\r\n"
+ to_send += body
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ # server trusts the content-length header (assumed 0)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_request_body_too_large_chunked_encoding(self):
+ control_line = "20;\r\n" # 20 hex = 32 dec
+ s = "This string has 32 characters.\r\n"
+ to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"
+ repeat = control_line + s
+ to_send += repeat * ((self.toobig // len(repeat)) + 1)
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ # body bytes counter caught a max_request_body_size overrun
+ self.assertline(line, "413", "Request Entity Too Large", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ self.assertEqual(headers["content-type"], "text/plain")
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+
+class InternalServerErrorTests(object):
+ def setUp(self):
+ from waitress.tests.fixtureapps import error
+
+ self.start_subprocess(error.app, expose_tracebacks=True)
+
+ def tearDown(self):
+ self.stop_subprocess()
+
+ def test_before_start_response_http_10(self):
+ to_send = "GET /before_start_response HTTP/1.0\r\n\r\n"
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "500", "Internal Server Error", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ self.assertTrue(response_body.startswith(b"Internal Server Error"))
+ self.assertEqual(headers["connection"], "close")
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_before_start_response_http_11(self):
+ to_send = "GET /before_start_response HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ self.assertTrue(response_body.startswith(b"Internal Server Error"))
+ self.assertEqual(
+ sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"]
+ )
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_before_start_response_http_11_close(self):
+ to_send = tobytes(
+ "GET /before_start_response HTTP/1.1\r\nConnection: close\r\n\r\n"
+ )
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ self.assertTrue(response_body.startswith(b"Internal Server Error"))
+ self.assertEqual(
+ sorted(headers.keys()),
+ ["connection", "content-length", "content-type", "date", "server"],
+ )
+ self.assertEqual(headers["connection"], "close")
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_after_start_response_http10(self):
+ to_send = "GET /after_start_response HTTP/1.0\r\n\r\n"
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "500", "Internal Server Error", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ self.assertTrue(response_body.startswith(b"Internal Server Error"))
+ self.assertEqual(
+ sorted(headers.keys()),
+ ["connection", "content-length", "content-type", "date", "server"],
+ )
+ self.assertEqual(headers["connection"], "close")
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_after_start_response_http11(self):
+ to_send = "GET /after_start_response HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ self.assertTrue(response_body.startswith(b"Internal Server Error"))
+ self.assertEqual(
+ sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"]
+ )
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_after_start_response_http11_close(self):
+ to_send = tobytes(
+ "GET /after_start_response HTTP/1.1\r\nConnection: close\r\n\r\n"
+ )
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ self.assertTrue(response_body.startswith(b"Internal Server Error"))
+ self.assertEqual(
+ sorted(headers.keys()),
+ ["connection", "content-length", "content-type", "date", "server"],
+ )
+ self.assertEqual(headers["connection"], "close")
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_after_write_cb(self):
+ to_send = "GET /after_write_cb HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ self.assertEqual(response_body, b"")
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_in_generator(self):
+ to_send = "GET /in_generator HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+ self.connect()
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ self.assertEqual(response_body, b"")
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+
+class FileWrapperTests(object):
+ def setUp(self):
+ from waitress.tests.fixtureapps import filewrapper
+
+ self.start_subprocess(filewrapper.app)
+
+ def tearDown(self):
+ self.stop_subprocess()
+
+ def test_filelike_http11(self):
+ to_send = "GET /filelike HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ for t in range(0, 2):
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has not been closed
+
+ def test_filelike_nocl_http11(self):
+ to_send = "GET /filelike_nocl HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ for t in range(0, 2):
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has not been closed
+
+ def test_filelike_shortcl_http11(self):
+ to_send = "GET /filelike_shortcl HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ for t in range(0, 2):
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, 1)
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377" in response_body)
+ # connection has not been closed
+
+ def test_filelike_longcl_http11(self):
+ to_send = "GET /filelike_longcl HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ for t in range(0, 2):
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has not been closed
+
+ def test_notfilelike_http11(self):
+ to_send = "GET /notfilelike HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ for t in range(0, 2):
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has not been closed
+
+ def test_notfilelike_iobase_http11(self):
+ to_send = "GET /notfilelike_iobase HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ for t in range(0, 2):
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has not been closed
+
+ def test_notfilelike_nocl_http11(self):
+ to_send = "GET /notfilelike_nocl HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has been closed (no content-length)
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_notfilelike_shortcl_http11(self):
+ to_send = "GET /notfilelike_shortcl HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ for t in range(0, 2):
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, 1)
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377" in response_body)
+ # connection has not been closed
+
+ def test_notfilelike_longcl_http11(self):
+ to_send = "GET /notfilelike_longcl HTTP/1.1\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.1")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body) + 10)
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_filelike_http10(self):
+ to_send = "GET /filelike HTTP/1.0\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_filelike_nocl_http10(self):
+ to_send = "GET /filelike_nocl HTTP/1.0\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_notfilelike_http10(self):
+ to_send = "GET /notfilelike HTTP/1.0\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ cl = int(headers["content-length"])
+ self.assertEqual(cl, len(response_body))
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has been closed
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+ def test_notfilelike_nocl_http10(self):
+ to_send = "GET /notfilelike_nocl HTTP/1.0\r\n\r\n"
+ to_send = tobytes(to_send)
+
+ self.connect()
+
+ self.sock.send(to_send)
+ fp = self.sock.makefile("rb", 0)
+ line, headers, response_body = read_http(fp)
+ self.assertline(line, "200", "OK", "HTTP/1.0")
+ ct = headers["content-type"]
+ self.assertEqual(ct, "image/jpeg")
+ self.assertTrue(b"\377\330\377" in response_body)
+ # connection has been closed (no content-length)
+ self.send_check_error(to_send)
+ self.assertRaises(ConnectionClosed, read_http, fp)
+
+
+class TcpEchoTests(EchoTests, TcpTests, unittest.TestCase):
+ pass
+
+
+class TcpPipeliningTests(PipeliningTests, TcpTests, unittest.TestCase):
+ pass
+
+
+class TcpExpectContinueTests(ExpectContinueTests, TcpTests, unittest.TestCase):
+ pass
+
+
+class TcpBadContentLengthTests(BadContentLengthTests, TcpTests, unittest.TestCase):
+ pass
+
+
+class TcpNoContentLengthTests(NoContentLengthTests, TcpTests, unittest.TestCase):
+ pass
+
+
+class TcpWriteCallbackTests(WriteCallbackTests, TcpTests, unittest.TestCase):
+ pass
+
+
+class TcpTooLargeTests(TooLargeTests, TcpTests, unittest.TestCase):
+ pass
+
+
+class TcpInternalServerErrorTests(
+ InternalServerErrorTests, TcpTests, unittest.TestCase
+):
+ pass
+
+
+class TcpFileWrapperTests(FileWrapperTests, TcpTests, unittest.TestCase):
+ pass
+
+
+if hasattr(socket, "AF_UNIX"):
+
+ class FixtureUnixWSGIServer(server.UnixWSGIServer):
+ """A version of UnixWSGIServer that relays back what it's bound to.
+ """
+
+ family = socket.AF_UNIX # Testing
+
+ def __init__(self, application, queue, **kw): # pragma: no cover
+ # Coverage doesn't see this as it's run in a separate process.
+ # To permit parallel testing, use a PID-dependent socket.
+ kw["unix_socket"] = "/tmp/waitress.test-%d.sock" % os.getpid()
+ super(FixtureUnixWSGIServer, self).__init__(application, **kw)
+ queue.put(self.socket.getsockname())
+
+ class UnixTests(SubprocessTests):
+
+ server = FixtureUnixWSGIServer
+
+ def make_http_connection(self):
+ return UnixHTTPConnection(self.bound_to)
+
+ def stop_subprocess(self):
+ super(UnixTests, self).stop_subprocess()
+ cleanup_unix_socket(self.bound_to)
+
+ def send_check_error(self, to_send):
+ # Unlike inet domain sockets, Unix domain sockets can trigger a
+ # 'Broken pipe' error when the socket is closed.
+ try:
+ self.sock.send(to_send)
+ except socket.error as exc:
+ self.assertEqual(get_errno(exc), errno.EPIPE)
+
+ class UnixEchoTests(EchoTests, UnixTests, unittest.TestCase):
+ pass
+
+ class UnixPipeliningTests(PipeliningTests, UnixTests, unittest.TestCase):
+ pass
+
+ class UnixExpectContinueTests(ExpectContinueTests, UnixTests, unittest.TestCase):
+ pass
+
+ class UnixBadContentLengthTests(
+ BadContentLengthTests, UnixTests, unittest.TestCase
+ ):
+ pass
+
+ class UnixNoContentLengthTests(NoContentLengthTests, UnixTests, unittest.TestCase):
+ pass
+
+ class UnixWriteCallbackTests(WriteCallbackTests, UnixTests, unittest.TestCase):
+ pass
+
+ class UnixTooLargeTests(TooLargeTests, UnixTests, unittest.TestCase):
+ pass
+
+ class UnixInternalServerErrorTests(
+ InternalServerErrorTests, UnixTests, unittest.TestCase
+ ):
+ pass
+
+ class UnixFileWrapperTests(FileWrapperTests, UnixTests, unittest.TestCase):
+ pass
+
+
+def parse_headers(fp):
+ """Parses only RFC2822 headers from a file pointer.
+ """
+ headers = {}
+ while True:
+ line = fp.readline()
+ if line in (b"\r\n", b"\n", b""):
+ break
+ line = line.decode("iso-8859-1")
+ name, value = line.strip().split(":", 1)
+ headers[name.lower().strip()] = value.lower().strip()
+ return headers
+
+
+class UnixHTTPConnection(httplib.HTTPConnection):
+ """Patched version of HTTPConnection that uses Unix domain sockets.
+ """
+
+ def __init__(self, path):
+ httplib.HTTPConnection.__init__(self, "localhost")
+ self.path = path
+
+ def connect(self):
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ sock.connect(self.path)
+ self.sock = sock
+
+
+class ConnectionClosed(Exception):
+ pass
+
+
+# stolen from gevent
+def read_http(fp): # pragma: no cover
+ try:
+ response_line = fp.readline()
+ except socket.error as exc:
+ fp.close()
+ # errno 104 is ECONNRESET on Linux; in WinSock 10054 is WSAECONNRESET
+ if get_errno(exc) in (errno.ECONNABORTED, errno.ECONNRESET, 104, 10054):
+ raise ConnectionClosed
+ raise
+ if not response_line:
+ raise ConnectionClosed
+
+ header_lines = []
+ while True:
+ line = fp.readline()
+ if line in (b"\r\n", b"\n", b""):
+ break
+ else:
+ header_lines.append(line)
+ headers = dict()
+ for x in header_lines:
+ x = x.strip()
+ if not x:
+ continue
+ key, value = x.split(b": ", 1)
+ key = key.decode("iso-8859-1").lower()
+ value = value.decode("iso-8859-1")
+ assert key not in headers, "%s header duplicated" % key
+ headers[key] = value
+
+ if "content-length" in headers:
+ num = int(headers["content-length"])
+ body = b""
+ left = num
+ while left > 0:
+ data = fp.read(left)
+ if not data:
+ break
+ body += data
+ left -= len(data)
+ else:
+ # read until EOF
+ body = fp.read()
+
+ return response_line, headers, body
+
+
+# stolen from gevent
+def get_errno(exc): # pragma: no cover
+ """ Get the error code out of socket.error objects.
+ socket.error in <2.5 does not have errno attribute
+ socket.error in 3.x does not allow indexing access
+ e.args[0] works for all.
+ There are cases when args[0] is not errno.
+ i.e. http://bugs.python.org/issue6471
+ Maybe there are cases when errno is set, but it is not the first argument?
+ """
+ try:
+ if exc.errno is not None:
+ return exc.errno
+ except AttributeError:
+ pass
+ try:
+ return exc.args[0]
+ except IndexError:
+ return None
+
+
+def chunks(l, n):
+ """ Yield successive n-sized chunks from l.
+ """
+ for i in range(0, len(l), n):
+ yield l[i : i + n]
diff --git a/libs/waitress/tests/test_init.py b/libs/waitress/tests/test_init.py
new file mode 100644
index 000000000..f9b91d762
--- /dev/null
+++ b/libs/waitress/tests/test_init.py
@@ -0,0 +1,51 @@
+import unittest
+
+
+class Test_serve(unittest.TestCase):
+ def _callFUT(self, app, **kw):
+ from waitress import serve
+
+ return serve(app, **kw)
+
+ def test_it(self):
+ server = DummyServerFactory()
+ app = object()
+ result = self._callFUT(app, _server=server, _quiet=True)
+ self.assertEqual(server.app, app)
+ self.assertEqual(result, None)
+ self.assertEqual(server.ran, True)
+
+
+class Test_serve_paste(unittest.TestCase):
+ def _callFUT(self, app, **kw):
+ from waitress import serve_paste
+
+ return serve_paste(app, None, **kw)
+
+ def test_it(self):
+ server = DummyServerFactory()
+ app = object()
+ result = self._callFUT(app, _server=server, _quiet=True)
+ self.assertEqual(server.app, app)
+ self.assertEqual(result, 0)
+ self.assertEqual(server.ran, True)
+
+
+class DummyServerFactory(object):
+ ran = False
+
+ def __call__(self, app, **kw):
+ self.adj = DummyAdj(kw)
+ self.app = app
+ self.kw = kw
+ return self
+
+ def run(self):
+ self.ran = True
+
+
+class DummyAdj(object):
+ verbose = False
+
+ def __init__(self, kw):
+ self.__dict__.update(kw)
diff --git a/libs/waitress/tests/test_parser.py b/libs/waitress/tests/test_parser.py
new file mode 100644
index 000000000..91837c7fc
--- /dev/null
+++ b/libs/waitress/tests/test_parser.py
@@ -0,0 +1,732 @@
+##############################################################################
+#
+# Copyright (c) 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""HTTP Request Parser tests
+"""
+import unittest
+
+from waitress.compat import text_, tobytes
+
+
+class TestHTTPRequestParser(unittest.TestCase):
+ def setUp(self):
+ from waitress.parser import HTTPRequestParser
+ from waitress.adjustments import Adjustments
+
+ my_adj = Adjustments()
+ self.parser = HTTPRequestParser(my_adj)
+
+ def test_get_body_stream_None(self):
+ self.parser.body_rcv = None
+ result = self.parser.get_body_stream()
+ self.assertEqual(result.getvalue(), b"")
+
+ def test_get_body_stream_nonNone(self):
+ body_rcv = DummyBodyStream()
+ self.parser.body_rcv = body_rcv
+ result = self.parser.get_body_stream()
+ self.assertEqual(result, body_rcv)
+
+ def test_received_get_no_headers(self):
+ data = b"HTTP/1.0 GET /foobar\r\n\r\n"
+ result = self.parser.received(data)
+ self.assertEqual(result, 24)
+ self.assertTrue(self.parser.completed)
+ self.assertEqual(self.parser.headers, {})
+
+ def test_received_bad_host_header(self):
+ from waitress.utilities import BadRequest
+
+ data = b"HTTP/1.0 GET /foobar\r\n Host: foo\r\n\r\n"
+ result = self.parser.received(data)
+ self.assertEqual(result, 36)
+ self.assertTrue(self.parser.completed)
+ self.assertEqual(self.parser.error.__class__, BadRequest)
+
+ def test_received_bad_transfer_encoding(self):
+ from waitress.utilities import ServerNotImplemented
+
+ data = (
+ b"GET /foobar HTTP/1.1\r\n"
+ b"Transfer-Encoding: foo\r\n"
+ b"\r\n"
+ b"1d;\r\n"
+ b"This string has 29 characters\r\n"
+ b"0\r\n\r\n"
+ )
+ result = self.parser.received(data)
+ self.assertEqual(result, 48)
+ self.assertTrue(self.parser.completed)
+ self.assertEqual(self.parser.error.__class__, ServerNotImplemented)
+
+ def test_received_nonsense_nothing(self):
+ data = b"\r\n\r\n"
+ result = self.parser.received(data)
+ self.assertEqual(result, 4)
+ self.assertTrue(self.parser.completed)
+ self.assertEqual(self.parser.headers, {})
+
+ def test_received_no_doublecr(self):
+ data = b"GET /foobar HTTP/8.4\r\n"
+ result = self.parser.received(data)
+ self.assertEqual(result, 22)
+ self.assertFalse(self.parser.completed)
+ self.assertEqual(self.parser.headers, {})
+
+ def test_received_already_completed(self):
+ self.parser.completed = True
+ result = self.parser.received(b"a")
+ self.assertEqual(result, 0)
+
+ def test_received_cl_too_large(self):
+ from waitress.utilities import RequestEntityTooLarge
+
+ self.parser.adj.max_request_body_size = 2
+ data = b"GET /foobar HTTP/8.4\r\nContent-Length: 10\r\n\r\n"
+ result = self.parser.received(data)
+ self.assertEqual(result, 44)
+ self.assertTrue(self.parser.completed)
+ self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge))
+
+ def test_received_headers_too_large(self):
+ from waitress.utilities import RequestHeaderFieldsTooLarge
+
+ self.parser.adj.max_request_header_size = 2
+ data = b"GET /foobar HTTP/8.4\r\nX-Foo: 1\r\n\r\n"
+ result = self.parser.received(data)
+ self.assertEqual(result, 34)
+ self.assertTrue(self.parser.completed)
+ self.assertTrue(isinstance(self.parser.error, RequestHeaderFieldsTooLarge))
+
+ def test_received_body_too_large(self):
+ from waitress.utilities import RequestEntityTooLarge
+
+ self.parser.adj.max_request_body_size = 2
+ data = (
+ b"GET /foobar HTTP/1.1\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"X-Foo: 1\r\n"
+ b"\r\n"
+ b"1d;\r\n"
+ b"This string has 29 characters\r\n"
+ b"0\r\n\r\n"
+ )
+
+ result = self.parser.received(data)
+ self.assertEqual(result, 62)
+ self.parser.received(data[result:])
+ self.assertTrue(self.parser.completed)
+ self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge))
+
+ def test_received_error_from_parser(self):
+ from waitress.utilities import BadRequest
+
+ data = (
+ b"GET /foobar HTTP/1.1\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"X-Foo: 1\r\n"
+ b"\r\n"
+ b"garbage\r\n"
+ )
+ # header
+ result = self.parser.received(data)
+ # body
+ result = self.parser.received(data[result:])
+ self.assertEqual(result, 9)
+ self.assertTrue(self.parser.completed)
+ self.assertTrue(isinstance(self.parser.error, BadRequest))
+
+ def test_received_chunked_completed_sets_content_length(self):
+ data = (
+ b"GET /foobar HTTP/1.1\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"X-Foo: 1\r\n"
+ b"\r\n"
+ b"1d;\r\n"
+ b"This string has 29 characters\r\n"
+ b"0\r\n\r\n"
+ )
+ result = self.parser.received(data)
+ self.assertEqual(result, 62)
+ data = data[result:]
+ result = self.parser.received(data)
+ self.assertTrue(self.parser.completed)
+ self.assertTrue(self.parser.error is None)
+ self.assertEqual(self.parser.headers["CONTENT_LENGTH"], "29")
+
+ def test_parse_header_gardenpath(self):
+ data = b"GET /foobar HTTP/8.4\r\nfoo: bar\r\n"
+ self.parser.parse_header(data)
+ self.assertEqual(self.parser.first_line, b"GET /foobar HTTP/8.4")
+ self.assertEqual(self.parser.headers["FOO"], "bar")
+
+ def test_parse_header_no_cr_in_headerplus(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/8.4"
+
+ try:
+ self.parser.parse_header(data)
+ except ParsingError:
+ pass
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_bad_content_length(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/8.4\r\ncontent-length: abc\r\n"
+
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Content-Length is invalid", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_multiple_content_length(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/8.4\r\ncontent-length: 10\r\ncontent-length: 20\r\n"
+
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Content-Length is invalid", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_11_te_chunked(self):
+ # NB: test that capitalization of header value is unimportant
+ data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: ChUnKed\r\n"
+ self.parser.parse_header(data)
+ self.assertEqual(self.parser.body_rcv.__class__.__name__, "ChunkedReceiver")
+
+ def test_parse_header_transfer_encoding_invalid(self):
+ from waitress.parser import TransferEncodingNotImplemented
+
+ data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\n"
+
+ try:
+ self.parser.parse_header(data)
+ except TransferEncodingNotImplemented as e:
+ self.assertIn("Transfer-Encoding requested is not supported.", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_transfer_encoding_invalid_multiple(self):
+ from waitress.parser import TransferEncodingNotImplemented
+
+ data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\ntransfer-encoding: chunked\r\n"
+
+ try:
+ self.parser.parse_header(data)
+ except TransferEncodingNotImplemented as e:
+ self.assertIn("Transfer-Encoding requested is not supported.", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_transfer_encoding_invalid_whitespace(self):
+ from waitress.parser import TransferEncodingNotImplemented
+
+ data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding:\x85chunked\r\n"
+
+ try:
+ self.parser.parse_header(data)
+ except TransferEncodingNotImplemented as e:
+ self.assertIn("Transfer-Encoding requested is not supported.", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_transfer_encoding_invalid_unicode(self):
+ from waitress.parser import TransferEncodingNotImplemented
+
+ # This is the binary encoding for the UTF-8 character
+ # https://www.compart.com/en/unicode/U+212A "unicode character "K""
+ # which if waitress were to accidentally do the wrong thing get
+ # lowercased to just the ascii "k" due to unicode collisions during
+ # transformation
+ data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding: chun\xe2\x84\xaaed\r\n"
+
+ try:
+ self.parser.parse_header(data)
+ except TransferEncodingNotImplemented as e:
+ self.assertIn("Transfer-Encoding requested is not supported.", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_11_expect_continue(self):
+ data = b"GET /foobar HTTP/1.1\r\nexpect: 100-continue\r\n"
+ self.parser.parse_header(data)
+ self.assertEqual(self.parser.expect_continue, True)
+
+ def test_parse_header_connection_close(self):
+ data = b"GET /foobar HTTP/1.1\r\nConnection: close\r\n"
+ self.parser.parse_header(data)
+ self.assertEqual(self.parser.connection_close, True)
+
+ def test_close_with_body_rcv(self):
+ body_rcv = DummyBodyStream()
+ self.parser.body_rcv = body_rcv
+ self.parser.close()
+ self.assertTrue(body_rcv.closed)
+
+ def test_close_with_no_body_rcv(self):
+ self.parser.body_rcv = None
+ self.parser.close() # doesn't raise
+
+ def test_parse_header_lf_only(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/8.4\nfoo: bar"
+
+ try:
+ self.parser.parse_header(data)
+ except ParsingError:
+ pass
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_cr_only(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/8.4\rfoo: bar"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError:
+ pass
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_extra_lf_in_header(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/8.4\r\nfoo: \nbar\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Bare CR or LF found in header line", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_extra_lf_in_first_line(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar\n HTTP/8.4\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Bare CR or LF found in HTTP message", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_invalid_whitespace(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/8.4\r\nfoo : bar\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Invalid header", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_invalid_whitespace_vtab(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo:\x0bbar\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Invalid header", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_invalid_no_colon(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nnotvalid\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Invalid header", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_invalid_folding_spacing(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\n\t\x0bbaz\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Invalid header", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_invalid_chars(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: \x0bbaz\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Invalid header", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_empty(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nempty:\r\n"
+ self.parser.parse_header(data)
+
+ self.assertIn("EMPTY", self.parser.headers)
+ self.assertIn("FOO", self.parser.headers)
+ self.assertEqual(self.parser.headers["EMPTY"], "")
+ self.assertEqual(self.parser.headers["FOO"], "bar")
+
+ def test_parse_header_multiple_values(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever, more, please, yes\r\n"
+ self.parser.parse_header(data)
+
+ self.assertIn("FOO", self.parser.headers)
+ self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes")
+
+ def test_parse_header_multiple_values_header_folded(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more, please, yes\r\n"
+ self.parser.parse_header(data)
+
+ self.assertIn("FOO", self.parser.headers)
+ self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes")
+
+ def test_parse_header_multiple_values_header_folded_multiple(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more\r\nfoo: please, yes\r\n"
+ self.parser.parse_header(data)
+
+ self.assertIn("FOO", self.parser.headers)
+ self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes")
+
+ def test_parse_header_multiple_values_extra_space(self):
+ # Tests errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: abrowser/0.001 (C O M M E N T)\r\n"
+ self.parser.parse_header(data)
+
+ self.assertIn("FOO", self.parser.headers)
+ self.assertEqual(self.parser.headers["FOO"], "abrowser/0.001 (C O M M E N T)")
+
+ def test_parse_header_invalid_backtrack_bad(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\x10\r\n"
+ try:
+ self.parser.parse_header(data)
+ except ParsingError as e:
+ self.assertIn("Invalid header", e.args[0])
+ else: # pragma: nocover
+ self.assertTrue(False)
+
+ def test_parse_header_short_values(self):
+ from waitress.parser import ParsingError
+
+ data = b"GET /foobar HTTP/1.1\r\none: 1\r\ntwo: 22\r\n"
+ self.parser.parse_header(data)
+
+ self.assertIn("ONE", self.parser.headers)
+ self.assertIn("TWO", self.parser.headers)
+ self.assertEqual(self.parser.headers["ONE"], "1")
+ self.assertEqual(self.parser.headers["TWO"], "22")
+
+
class Test_split_uri(unittest.TestCase):
    """Tests for waitress.parser.split_uri.

    _callFUT stores the five components returned by split_uri as instance
    attributes so each test can assert on them individually.
    """

    def _callFUT(self, uri):
        # Call the function under test and stash its result on self.
        from waitress.parser import split_uri

        (
            self.proxy_scheme,
            self.proxy_netloc,
            self.path,
            self.query,
            self.fragment,
        ) = split_uri(uri)

    def test_split_uri_unquoting_unneeded(self):
        # A literal space in the path is preserved as-is.
        self._callFUT(b"http://localhost:8080/abc def")
        self.assertEqual(self.path, "/abc def")

    def test_split_uri_unquoting_needed(self):
        # %20 must be percent-decoded to a space.
        self._callFUT(b"http://localhost:8080/abc%20def")
        self.assertEqual(self.path, "/abc def")

    def test_split_url_with_query(self):
        self._callFUT(b"http://localhost:8080/abc?a=1&b=2")
        self.assertEqual(self.path, "/abc")
        self.assertEqual(self.query, "a=1&b=2")

    def test_split_url_with_query_empty(self):
        # A trailing "?" yields an empty (but not None) query string.
        self._callFUT(b"http://localhost:8080/abc?")
        self.assertEqual(self.path, "/abc")
        self.assertEqual(self.query, "")

    def test_split_url_with_fragment(self):
        self._callFUT(b"http://localhost:8080/#foo")
        self.assertEqual(self.path, "/")
        self.assertEqual(self.fragment, "foo")

    def test_split_url_https(self):
        # An absolute URI exposes scheme and netloc for proxying.
        self._callFUT(b"https://localhost:8080/")
        self.assertEqual(self.path, "/")
        self.assertEqual(self.proxy_scheme, "https")
        self.assertEqual(self.proxy_netloc, "localhost:8080")

    def test_split_uri_unicode_error_raises_parsing_error(self):
        # See https://github.com/Pylons/waitress/issues/64
        from waitress.parser import ParsingError

        # Either pass or throw a ParsingError, just don't throw another type of
        # exception as that will cause the connection to close badly:
        try:
            self._callFUT(b"/\xd0")
        except ParsingError:
            pass

    def test_split_uri_path(self):
        # A scheme-less URI keeps its leading double slash in the path.
        self._callFUT(b"//testing/whatever")
        self.assertEqual(self.path, "//testing/whatever")
        self.assertEqual(self.proxy_scheme, "")
        self.assertEqual(self.proxy_netloc, "")
        self.assertEqual(self.query, "")
        self.assertEqual(self.fragment, "")

    def test_split_uri_path_query(self):
        self._callFUT(b"//testing/whatever?a=1&b=2")
        self.assertEqual(self.path, "//testing/whatever")
        self.assertEqual(self.proxy_scheme, "")
        self.assertEqual(self.proxy_netloc, "")
        self.assertEqual(self.query, "a=1&b=2")
        self.assertEqual(self.fragment, "")

    def test_split_uri_path_query_fragment(self):
        self._callFUT(b"//testing/whatever?a=1&b=2#fragment")
        self.assertEqual(self.path, "//testing/whatever")
        self.assertEqual(self.proxy_scheme, "")
        self.assertEqual(self.proxy_netloc, "")
        self.assertEqual(self.query, "a=1&b=2")
        self.assertEqual(self.fragment, "fragment")
+
+
class Test_get_header_lines(unittest.TestCase):
    """Tests for waitress.parser.get_header_lines (splitting + unfolding)."""

    def _callFUT(self, data):
        from waitress.parser import get_header_lines

        return get_header_lines(data)

    def test_get_header_lines(self):
        # Plain CRLF-separated lines split into one entry each.
        result = self._callFUT(b"slam\r\nslim")
        self.assertEqual(result, [b"slam", b"slim"])

    def test_get_header_lines_folded(self):
        # From RFC2616:
        # HTTP/1.1 header field values can be folded onto multiple lines if the
        # continuation line begins with a space or horizontal tab. All linear
        # white space, including folding, has the same semantics as SP. A
        # recipient MAY replace any linear white space with a single SP before
        # interpreting the field value or forwarding the message downstream.

        # We are just preserving the whitespace that indicates folding.
        result = self._callFUT(b"slim\r\n slam")
        self.assertEqual(result, [b"slim slam"])

    def test_get_header_lines_tabbed(self):
        # A tab is also valid folding whitespace and is preserved.
        result = self._callFUT(b"slam\r\n\tslim")
        self.assertEqual(result, [b"slam\tslim"])

    def test_get_header_lines_malformed(self):
        # A first line that begins with folding whitespace is malformed.
        # https://corte.si/posts/code/pathod/pythonservers/index.html
        from waitress.parser import ParsingError

        self.assertRaises(ParsingError, self._callFUT, b" Host: localhost\r\n\r\n")
+
+
class Test_crack_first_line(unittest.TestCase):
    """Tests for waitress.parser.crack_first_line (request-line splitting)."""

    def _callFUT(self, line):
        from waitress.parser import crack_first_line

        return crack_first_line(line)

    def test_crack_first_line_matchok(self):
        """A well-formed request line yields (method, uri, version)."""
        self.assertEqual(self._callFUT(b"GET / HTTP/1.0"), (b"GET", b"/", b"1.0"))

    def test_crack_first_line_lowercase_method(self):
        """A lower-case HTTP method is rejected with ParsingError."""
        from waitress.parser import ParsingError

        self.assertRaises(ParsingError, self._callFUT, b"get / HTTP/1.0")

    def test_crack_first_line_nomatch(self):
        """Non-HTTP request lines produce a triple of empty byte strings."""
        self.assertEqual(self._callFUT(b"GET / bleh"), (b"", b"", b""))
        self.assertEqual(
            self._callFUT(b"GET /info?txtAirPlay&txtRAOP RTSP/1.0"),
            (b"", b"", b""),
        )

    def test_crack_first_line_missing_version(self):
        """A missing HTTP version is returned as an empty bytes version."""
        self.assertEqual(self._callFUT(b"GET /"), (b"GET", b"/", b""))
+
+
class TestHTTPRequestParserIntegration(unittest.TestCase):
    """End-to-end tests feeding complete request bytes into HTTPRequestParser."""

    def setUp(self):
        from waitress.parser import HTTPRequestParser
        from waitress.adjustments import Adjustments

        my_adj = Adjustments()
        self.parser = HTTPRequestParser(my_adj)

    def feed(self, data):
        """Feed *data* to the parser in as many received() calls as needed."""
        parser = self.parser

        for n in range(100):  # make sure we never loop forever
            consumed = parser.received(data)
            data = data[consumed:]

            if parser.completed:
                return
        raise ValueError("Looping")  # pragma: no cover

    def testSimpleGET(self):
        # Headers are upcased and underscored; body honors content-length.
        data = (
            b"GET /foobar HTTP/8.4\r\n"
            b"FirstName: mickey\r\n"
            b"lastname: Mouse\r\n"
            b"content-length: 6\r\n"
            b"\r\n"
            b"Hello."
        )
        parser = self.parser
        self.feed(data)
        self.assertTrue(parser.completed)
        self.assertEqual(parser.version, "8.4")
        self.assertFalse(parser.empty)
        self.assertEqual(
            parser.headers,
            {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "6",},
        )
        self.assertEqual(parser.path, "/foobar")
        self.assertEqual(parser.command, "GET")
        self.assertEqual(parser.query, "")
        self.assertEqual(parser.proxy_scheme, "")
        self.assertEqual(parser.proxy_netloc, "")
        self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.")

    def testComplexGET(self):
        # Percent-encoded path is decoded; query string is left encoded;
        # the body is truncated to content-length (10 of 13 bytes).
        data = (
            b"GET /foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6 HTTP/8.4\r\n"
            b"FirstName: mickey\r\n"
            b"lastname: Mouse\r\n"
            b"content-length: 10\r\n"
            b"\r\n"
            b"Hello mickey."
        )
        parser = self.parser
        self.feed(data)
        self.assertEqual(parser.command, "GET")
        self.assertEqual(parser.version, "8.4")
        self.assertFalse(parser.empty)
        self.assertEqual(
            parser.headers,
            {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "10"},
        )
        # path should be utf-8 encoded
        self.assertEqual(
            tobytes(parser.path).decode("utf-8"),
            text_(b"/foo/a++/\xc3\xa4=&a:int", "utf-8"),
        )
        self.assertEqual(
            parser.query, "d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6"
        )
        self.assertEqual(parser.get_body_stream().getvalue(), b"Hello mick")

    def testProxyGET(self):
        # An absolute-URI request line populates proxy_scheme/proxy_netloc.
        data = (
            b"GET https://example.com:8080/foobar HTTP/8.4\r\n"
            b"content-length: 6\r\n"
            b"\r\n"
            b"Hello."
        )
        parser = self.parser
        self.feed(data)
        self.assertTrue(parser.completed)
        self.assertEqual(parser.version, "8.4")
        self.assertFalse(parser.empty)
        self.assertEqual(parser.headers, {"CONTENT_LENGTH": "6"})
        self.assertEqual(parser.path, "/foobar")
        self.assertEqual(parser.command, "GET")
        self.assertEqual(parser.proxy_scheme, "https")
        self.assertEqual(parser.proxy_netloc, "example.com:8080")
        # NOTE(review): the next assertion duplicates the command check above.
        self.assertEqual(parser.command, "GET")
        self.assertEqual(parser.query, "")
        self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.")

    def testDuplicateHeaders(self):
        # Ensure that headers with the same key get concatenated as per
        # RFC2616.  Note "X-Forwarded_for" (underscore) does NOT merge with
        # the dash-separated variants and is dropped as spoofed.
        data = (
            b"GET /foobar HTTP/8.4\r\n"
            b"x-forwarded-for: 10.11.12.13\r\n"
            b"x-forwarded-for: unknown,127.0.0.1\r\n"
            b"X-Forwarded_for: 255.255.255.255\r\n"
            b"content-length: 6\r\n"
            b"\r\n"
            b"Hello."
        )
        self.feed(data)
        self.assertTrue(self.parser.completed)
        self.assertEqual(
            self.parser.headers,
            {
                "CONTENT_LENGTH": "6",
                "X_FORWARDED_FOR": "10.11.12.13, unknown,127.0.0.1",
            },
        )

    def testSpoofedHeadersDropped(self):
        # An underscore in the raw header name would collide with the
        # CGI-style translated name, so it must be discarded entirely.
        data = (
            b"GET /foobar HTTP/8.4\r\n"
            b"x-auth_user: bob\r\n"
            b"content-length: 6\r\n"
            b"\r\n"
            b"Hello."
        )
        self.feed(data)
        self.assertTrue(self.parser.completed)
        self.assertEqual(self.parser.headers, {"CONTENT_LENGTH": "6",})
+
+
class DummyBodyStream(object):
    """Minimal stand-in for a request body stream used by parser tests."""

    # Class-level default so the attribute exists even before close() is
    # called; previously reading .closed prior to close() raised
    # AttributeError.
    closed = False

    def getfile(self):
        """Return self; callers only need a file-like handle."""
        return self

    def getbuf(self):
        """Return self; callers only need a buffer-like handle."""
        return self

    def close(self):
        """Record that the stream was closed."""
        self.closed = True
diff --git a/libs/waitress/tests/test_proxy_headers.py b/libs/waitress/tests/test_proxy_headers.py
new file mode 100644
index 000000000..15b4a0828
--- /dev/null
+++ b/libs/waitress/tests/test_proxy_headers.py
@@ -0,0 +1,724 @@
+import unittest
+
+from waitress.compat import tobytes
+
+
class TestProxyHeadersMiddleware(unittest.TestCase):
    """Tests for waitress.proxy_headers.proxy_headers_middleware.

    Each test wraps a DummyApp in the middleware with a given trust
    configuration, invokes it through _callFUT with a synthetic environ,
    and asserts on the status/body seen by the client and on the environ
    seen by the inner app.
    """

    def _makeOne(self, app, **kw):
        # Wrap *app* in the middleware under test with the given options.
        from waitress.proxy_headers import proxy_headers_middleware

        return proxy_headers_middleware(app, **kw)

    def _callFUT(self, app, **kw):
        # Drive *app* as a WSGI callable; kw is forwarded to DummyEnviron.
        response = DummyResponse()
        environ = DummyEnviron(**kw)

        def start_response(status, response_headers):
            response.status = status
            response.headers = response_headers

        response.steps = list(app(environ, start_response))
        response.body = b"".join(tobytes(s) for s in response.steps)
        return response

    def test_get_environment_values_w_scheme_override_untrusted(self):
        inner = DummyApp()
        app = self._makeOne(inner)
        response = self._callFUT(
            app, headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",}
        )
        self.assertEqual(response.status, "200 OK")
        self.assertEqual(inner.environ["wsgi.url_scheme"], "http")

    def test_get_environment_values_w_scheme_override_trusted(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="192.168.1.1",
            trusted_proxy_headers={"x-forwarded-proto"},
        )
        response = self._callFUT(
            app,
            addr=["192.168.1.1", 8080],
            headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",},
        )

        environ = inner.environ
        self.assertEqual(response.status, "200 OK")
        self.assertEqual(environ["SERVER_PORT"], "443")
        self.assertEqual(environ["SERVER_NAME"], "localhost")
        self.assertEqual(environ["REMOTE_ADDR"], "192.168.1.1")
        self.assertEqual(environ["HTTP_X_FOO"], "BAR")

    def test_get_environment_values_w_bogus_scheme_override(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="192.168.1.1",
            trusted_proxy_headers={"x-forwarded-proto"},
        )
        response = self._callFUT(
            app,
            addr=["192.168.1.1", 80],
            headers={
                "X_FOO": "BAR",
                "X_FORWARDED_PROTO": "http://p02n3e.com?url=http",
            },
        )
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body)

    def test_get_environment_warning_other_proxy_headers(self):
        inner = DummyApp()
        logger = DummyLogger()
        app = self._makeOne(
            inner,
            trusted_proxy="192.168.1.1",
            trusted_proxy_count=1,
            trusted_proxy_headers={"forwarded"},
            log_untrusted=True,
            logger=logger,
        )
        response = self._callFUT(
            app,
            addr=["192.168.1.1", 80],
            headers={
                "X_FORWARDED_FOR": "[2001:db8::1]",
                "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https",
            },
        )
        self.assertEqual(response.status, "200 OK")

        self.assertEqual(len(logger.logged), 1)

        environ = inner.environ
        self.assertNotIn("HTTP_X_FORWARDED_FOR", environ)
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")
        self.assertEqual(environ["wsgi.url_scheme"], "https")

    def test_get_environment_contains_all_headers_including_untrusted(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="192.168.1.1",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-by"},
            clear_untrusted=False,
        )
        headers_orig = {
            "X_FORWARDED_FOR": "198.51.100.2",
            "X_FORWARDED_BY": "Waitress",
            "X_FORWARDED_PROTO": "https",
            "X_FORWARDED_HOST": "example.org",
        }
        response = self._callFUT(
            app, addr=["192.168.1.1", 80], headers=headers_orig.copy(),
        )
        self.assertEqual(response.status, "200 OK")
        environ = inner.environ
        for k, expected in headers_orig.items():
            result = environ["HTTP_%s" % k]
            self.assertEqual(result, expected)

    def test_get_environment_contains_only_trusted_headers(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="192.168.1.1",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-by"},
            clear_untrusted=True,
        )
        response = self._callFUT(
            app,
            addr=["192.168.1.1", 80],
            headers={
                "X_FORWARDED_FOR": "198.51.100.2",
                "X_FORWARDED_BY": "Waitress",
                "X_FORWARDED_PROTO": "https",
                "X_FORWARDED_HOST": "example.org",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["HTTP_X_FORWARDED_BY"], "Waitress")
        self.assertNotIn("HTTP_X_FORWARDED_FOR", environ)
        self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ)
        self.assertNotIn("HTTP_X_FORWARDED_HOST", environ)

    def test_get_environment_clears_headers_if_untrusted_proxy(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="192.168.1.1",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-by"},
            clear_untrusted=True,
        )
        response = self._callFUT(
            app,
            addr=["192.168.1.255", 80],
            headers={
                "X_FORWARDED_FOR": "198.51.100.2",
                "X_FORWARDED_BY": "Waitress",
                "X_FORWARDED_PROTO": "https",
                "X_FORWARDED_HOST": "example.org",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertNotIn("HTTP_X_FORWARDED_BY", environ)
        self.assertNotIn("HTTP_X_FORWARDED_FOR", environ)
        self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ)
        self.assertNotIn("HTTP_X_FORWARDED_HOST", environ)

    def test_parse_proxy_headers_forwarded_for(self):
        inner = DummyApp()
        app = self._makeOne(
            inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"},
        )
        response = self._callFUT(app, headers={"X_FORWARDED_FOR": "192.0.2.1"})
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "192.0.2.1")

    def test_parse_proxy_headers_forwarded_for_v6_missing_brackets(self):
        inner = DummyApp()
        app = self._makeOne(
            inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"},
        )
        response = self._callFUT(app, headers={"X_FORWARDED_FOR": "2001:db8::0"})
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::0")

    def test_parse_proxy_headers_forwared_for_multiple(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={"x-forwarded-for"},
        )
        response = self._callFUT(
            app, headers={"X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1"}
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")

    def test_parse_forwarded_multiple_proxies_trust_only_two(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(
            app,
            headers={
                "FORWARDED": (
                    "For=192.0.2.1;host=fake.com, "
                    "For=198.51.100.2;host=example.com:8080, "
                    "For=203.0.113.1"
                ),
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")

    def test_parse_forwarded_multiple_proxies(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(
            app,
            headers={
                "FORWARDED": (
                    'for="[2001:db8::1]:3821";host="example.com:8443";proto="https", '
                    'for=192.0.2.1;host="example.internal:8080"'
                ),
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1")
        self.assertEqual(environ["REMOTE_PORT"], "3821")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8443")
        self.assertEqual(environ["SERVER_PORT"], "8443")
        self.assertEqual(environ["wsgi.url_scheme"], "https")

    def test_parse_forwarded_multiple_proxies_minimal(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(
            app,
            headers={
                "FORWARDED": (
                    'for="[2001:db8::1]";proto="https", '
                    'for=192.0.2.1;host="example.org"'
                ),
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1")
        self.assertEqual(environ["SERVER_NAME"], "example.org")
        self.assertEqual(environ["HTTP_HOST"], "example.org")
        self.assertEqual(environ["SERVER_PORT"], "443")
        self.assertEqual(environ["wsgi.url_scheme"], "https")

    def test_parse_proxy_headers_forwarded_host_with_port(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={
                "x-forwarded-for",
                "x-forwarded-proto",
                "x-forwarded-host",
            },
        )
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1",
                "X_FORWARDED_PROTO": "http",
                "X_FORWARDED_HOST": "example.com:8080",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")

    def test_parse_proxy_headers_forwarded_host_without_port(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={
                "x-forwarded-for",
                "x-forwarded-proto",
                "x-forwarded-host",
            },
        )
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1",
                "X_FORWARDED_PROTO": "http",
                "X_FORWARDED_HOST": "example.com",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com")
        self.assertEqual(environ["SERVER_PORT"], "80")

    def test_parse_proxy_headers_forwarded_host_with_forwarded_port(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={
                "x-forwarded-for",
                "x-forwarded-proto",
                "x-forwarded-host",
                "x-forwarded-port",
            },
        )
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1",
                "X_FORWARDED_PROTO": "http",
                "X_FORWARDED_HOST": "example.com",
                "X_FORWARDED_PORT": "8080",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")

    def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=2,
            trusted_proxy_headers={
                "x-forwarded-for",
                "x-forwarded-proto",
                "x-forwarded-host",
                "x-forwarded-port",
            },
        )
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1",
                "X_FORWARDED_PROTO": "http",
                "X_FORWARDED_HOST": "example.com, example.org",
                "X_FORWARDED_PORT": "8080",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")

    def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port_limit_one_trusted(
        self,
    ):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={
                "x-forwarded-for",
                "x-forwarded-proto",
                "x-forwarded-host",
                "x-forwarded-port",
            },
        )
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1",
                "X_FORWARDED_PROTO": "http",
                "X_FORWARDED_HOST": "example.com, example.org",
                "X_FORWARDED_PORT": "8080",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "203.0.113.1")
        self.assertEqual(environ["SERVER_NAME"], "example.org")
        self.assertEqual(environ["HTTP_HOST"], "example.org:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")

    def test_parse_forwarded(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(
            app,
            headers={
                "FORWARDED": "For=198.51.100.2:5858;host=example.com:8080;proto=https",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["REMOTE_PORT"], "5858")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")
        self.assertEqual(environ["wsgi.url_scheme"], "https")

    def test_parse_forwarded_empty_pair(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(
            app, headers={"FORWARDED": "For=198.51.100.2;;proto=https;by=_unused",}
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")

    def test_parse_forwarded_pair_token_whitespace(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(
            app, headers={"FORWARDED": "For=198.51.100.2; proto =https",}
        )
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "Forwarded" malformed', response.body)

    def test_parse_forwarded_pair_value_whitespace(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(
            app, headers={"FORWARDED": 'For= "198.51.100.2"; proto =https',}
        )
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "Forwarded" malformed', response.body)

    def test_parse_forwarded_pair_no_equals(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"forwarded"},
        )
        response = self._callFUT(app, headers={"FORWARDED": "For"})
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "Forwarded" malformed', response.body)

    def test_parse_forwarded_warning_unknown_token(self):
        inner = DummyApp()
        logger = DummyLogger()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"forwarded"},
            logger=logger,
        )
        response = self._callFUT(
            app,
            headers={
                "FORWARDED": (
                    "For=198.51.100.2;host=example.com:8080;proto=https;"
                    'unknown="yolo"'
                ),
            },
        )
        self.assertEqual(response.status, "200 OK")

        self.assertEqual(len(logger.logged), 1)
        self.assertIn("Unknown Forwarded token", logger.logged[0])

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2")
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:8080")
        self.assertEqual(environ["SERVER_PORT"], "8080")
        self.assertEqual(environ["wsgi.url_scheme"], "https")

    def test_parse_no_valid_proxy_headers(self):
        inner = DummyApp()
        app = self._makeOne(inner, trusted_proxy="*", trusted_proxy_count=1,)
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_FOR": "198.51.100.2",
                "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1")
        self.assertEqual(environ["SERVER_NAME"], "localhost")
        self.assertEqual(environ["HTTP_HOST"], "192.168.1.1:80")
        self.assertEqual(environ["SERVER_PORT"], "8080")
        self.assertEqual(environ["wsgi.url_scheme"], "http")

    def test_parse_multiple_x_forwarded_proto(self):
        inner = DummyApp()
        logger = DummyLogger()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-proto"},
            logger=logger,
        )
        response = self._callFUT(app, headers={"X_FORWARDED_PROTO": "http, https",})
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body)

    def test_parse_multiple_x_forwarded_port(self):
        inner = DummyApp()
        logger = DummyLogger()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-port"},
            logger=logger,
        )
        response = self._callFUT(app, headers={"X_FORWARDED_PORT": "443, 80",})
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "X-Forwarded-Port" malformed', response.body)

    def test_parse_forwarded_port_wrong_proto_port_80(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={
                "x-forwarded-port",
                "x-forwarded-host",
                "x-forwarded-proto",
            },
        )
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_PORT": "80",
                "X_FORWARDED_PROTO": "https",
                "X_FORWARDED_HOST": "example.com",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:80")
        self.assertEqual(environ["SERVER_PORT"], "80")
        self.assertEqual(environ["wsgi.url_scheme"], "https")

    def test_parse_forwarded_port_wrong_proto_port_443(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={
                "x-forwarded-port",
                "x-forwarded-host",
                "x-forwarded-proto",
            },
        )
        response = self._callFUT(
            app,
            headers={
                "X_FORWARDED_PORT": "443",
                "X_FORWARDED_PROTO": "http",
                "X_FORWARDED_HOST": "example.com",
            },
        )
        self.assertEqual(response.status, "200 OK")

        environ = inner.environ
        self.assertEqual(environ["SERVER_NAME"], "example.com")
        self.assertEqual(environ["HTTP_HOST"], "example.com:443")
        self.assertEqual(environ["SERVER_PORT"], "443")
        self.assertEqual(environ["wsgi.url_scheme"], "http")

    def test_parse_forwarded_for_bad_quote(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-for"},
        )
        response = self._callFUT(app, headers={"X_FORWARDED_FOR": '"foo'})
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "X-Forwarded-For" malformed', response.body)

    def test_parse_forwarded_host_bad_quote(self):
        inner = DummyApp()
        app = self._makeOne(
            inner,
            trusted_proxy="*",
            trusted_proxy_count=1,
            trusted_proxy_headers={"x-forwarded-host"},
        )
        response = self._callFUT(app, headers={"X_FORWARDED_HOST": '"foo'})
        self.assertEqual(response.status, "400 Bad Request")
        self.assertIn(b'Header "X-Forwarded-Host" malformed', response.body)
+
+
class DummyLogger(object):
    """Collects warning messages so tests can inspect what was logged."""

    def __init__(self):
        # Formatted messages, in emission order.
        self.logged = []

    def warning(self, msg, *args):
        """Mimic Logger.warning: %-format eagerly and record the result."""
        formatted = msg % args
        self.logged.append(formatted)
+
+
class DummyApp(object):
    """WSGI app stub: records the environ it saw and yields one body chunk."""

    def __call__(self, environ, start_response):
        # Keep the (possibly middleware-modified) environ for assertions.
        self.environ = environ
        headers = [("Content-Type", "text/plain")]
        start_response("200 OK", headers)
        yield "hello"
+
+
class DummyResponse(object):
    """Holder for the status, headers and body captured by the test harness."""

    # All three are populated by _callFUT after the app is invoked.
    body = None
    headers = None
    status = None
+
+
def DummyEnviron(
    addr=("127.0.0.1", 8080), scheme="http", server="localhost", headers=None,
):
    """Build a minimal WSGI environ dict for middleware tests.

    ``headers`` maps header names to values; each becomes an ``HTTP_*``
    environ key with dashes replaced by underscores and the name upcased.
    """
    environ = {
        "REMOTE_ADDR": addr[0],
        "REMOTE_HOST": addr[0],
        "REMOTE_PORT": addr[1],
        "SERVER_PORT": str(addr[1]),
        "SERVER_NAME": server,
        "wsgi.url_scheme": scheme,
        # Deliberately different from SERVER_NAME so tests can tell which
        # value the middleware picked.
        "HTTP_HOST": "192.168.1.1:80",
    }
    if headers:
        for name, value in headers.items():
            environ["HTTP_" + name.upper().replace("-", "_")] = value
    return environ
diff --git a/libs/waitress/tests/test_receiver.py b/libs/waitress/tests/test_receiver.py
new file mode 100644
index 000000000..b4910bba8
--- /dev/null
+++ b/libs/waitress/tests/test_receiver.py
@@ -0,0 +1,242 @@
+import unittest
+
+
+class TestFixedStreamReceiver(unittest.TestCase):
+ def _makeOne(self, cl, buf):
+ from waitress.receiver import FixedStreamReceiver
+
+ return FixedStreamReceiver(cl, buf)
+
+ def test_received_remain_lt_1(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(0, buf)
+ result = inst.received("a")
+ self.assertEqual(result, 0)
+ self.assertEqual(inst.completed, True)
+
+ def test_received_remain_lte_datalen(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(1, buf)
+ result = inst.received("aa")
+ self.assertEqual(result, 1)
+ self.assertEqual(inst.completed, True)
+ self.assertEqual(inst.completed, 1)
+ self.assertEqual(inst.remain, 0)
+ self.assertEqual(buf.data, ["a"])
+
+ def test_received_remain_gt_datalen(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(10, buf)
+ result = inst.received("aa")
+ self.assertEqual(result, 2)
+ self.assertEqual(inst.completed, False)
+ self.assertEqual(inst.remain, 8)
+ self.assertEqual(buf.data, ["aa"])
+
+ def test_getfile(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(10, buf)
+ self.assertEqual(inst.getfile(), buf)
+
+ def test_getbuf(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(10, buf)
+ self.assertEqual(inst.getbuf(), buf)
+
+ def test___len__(self):
+ buf = DummyBuffer(["1", "2"])
+ inst = self._makeOne(10, buf)
+ self.assertEqual(inst.__len__(), 2)
+
+
+class TestChunkedReceiver(unittest.TestCase):
+ def _makeOne(self, buf):
+ from waitress.receiver import ChunkedReceiver
+
+ return ChunkedReceiver(buf)
+
+ def test_alreadycompleted(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ inst.completed = True
+ result = inst.received(b"a")
+ self.assertEqual(result, 0)
+ self.assertEqual(inst.completed, True)
+
+ def test_received_remain_gt_zero(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ inst.chunk_remainder = 100
+ result = inst.received(b"a")
+ self.assertEqual(inst.chunk_remainder, 99)
+ self.assertEqual(result, 1)
+ self.assertEqual(inst.completed, False)
+
+ def test_received_control_line_notfinished(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ result = inst.received(b"a")
+ self.assertEqual(inst.control_line, b"a")
+ self.assertEqual(result, 1)
+ self.assertEqual(inst.completed, False)
+
+ def test_received_control_line_finished_garbage_in_input(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ result = inst.received(b"garbage\r\n")
+ self.assertEqual(result, 9)
+ self.assertTrue(inst.error)
+
+ def test_received_control_line_finished_all_chunks_not_received(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ result = inst.received(b"a;discard\r\n")
+ self.assertEqual(inst.control_line, b"")
+ self.assertEqual(inst.chunk_remainder, 10)
+ self.assertEqual(inst.all_chunks_received, False)
+ self.assertEqual(result, 11)
+ self.assertEqual(inst.completed, False)
+
+ def test_received_control_line_finished_all_chunks_received(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ result = inst.received(b"0;discard\r\n")
+ self.assertEqual(inst.control_line, b"")
+ self.assertEqual(inst.all_chunks_received, True)
+ self.assertEqual(result, 11)
+ self.assertEqual(inst.completed, False)
+
+ def test_received_trailer_startswith_crlf(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ inst.all_chunks_received = True
+ result = inst.received(b"\r\n")
+ self.assertEqual(result, 2)
+ self.assertEqual(inst.completed, True)
+
+ def test_received_trailer_startswith_lf(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ inst.all_chunks_received = True
+ result = inst.received(b"\n")
+ self.assertEqual(result, 1)
+ self.assertEqual(inst.completed, False)
+
+ def test_received_trailer_not_finished(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ inst.all_chunks_received = True
+ result = inst.received(b"a")
+ self.assertEqual(result, 1)
+ self.assertEqual(inst.completed, False)
+
+ def test_received_trailer_finished(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ inst.all_chunks_received = True
+ result = inst.received(b"abc\r\n\r\n")
+ self.assertEqual(inst.trailer, b"abc\r\n\r\n")
+ self.assertEqual(result, 7)
+ self.assertEqual(inst.completed, True)
+
+ def test_getfile(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ self.assertEqual(inst.getfile(), buf)
+
+ def test_getbuf(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ self.assertEqual(inst.getbuf(), buf)
+
+ def test___len__(self):
+ buf = DummyBuffer(["1", "2"])
+ inst = self._makeOne(buf)
+ self.assertEqual(inst.__len__(), 2)
+
+ def test_received_chunk_is_properly_terminated(self):
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ data = b"4\r\nWiki\r\n"
+ result = inst.received(data)
+ self.assertEqual(result, len(data))
+ self.assertEqual(inst.completed, False)
+ self.assertEqual(buf.data[0], b"Wiki")
+
+ def test_received_chunk_not_properly_terminated(self):
+ from waitress.utilities import BadRequest
+
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ data = b"4\r\nWikibadchunk\r\n"
+ result = inst.received(data)
+ self.assertEqual(result, len(data))
+ self.assertEqual(inst.completed, False)
+ self.assertEqual(buf.data[0], b"Wiki")
+ self.assertEqual(inst.error.__class__, BadRequest)
+
+ def test_received_multiple_chunks(self):
+ from waitress.utilities import BadRequest
+
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ data = (
+ b"4\r\n"
+ b"Wiki\r\n"
+ b"5\r\n"
+ b"pedia\r\n"
+ b"E\r\n"
+ b" in\r\n"
+ b"\r\n"
+ b"chunks.\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
+ result = inst.received(data)
+ self.assertEqual(result, len(data))
+ self.assertEqual(inst.completed, True)
+ self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.")
+ self.assertEqual(inst.error, None)
+
+ def test_received_multiple_chunks_split(self):
+ from waitress.utilities import BadRequest
+
+ buf = DummyBuffer()
+ inst = self._makeOne(buf)
+ data1 = b"4\r\nWiki\r"
+ result = inst.received(data1)
+ self.assertEqual(result, len(data1))
+
+ data2 = (
+ b"\n5\r\n"
+ b"pedia\r\n"
+ b"E\r\n"
+ b" in\r\n"
+ b"\r\n"
+ b"chunks.\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
+
+ result = inst.received(data2)
+ self.assertEqual(result, len(data2))
+
+ self.assertEqual(inst.completed, True)
+ self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.")
+ self.assertEqual(inst.error, None)
+
+
+class DummyBuffer(object):
+ def __init__(self, data=None):
+ if data is None:
+ data = []
+ self.data = data
+
+ def append(self, s):
+ self.data.append(s)
+
+ def getfile(self):
+ return self
+
+ def __len__(self):
+ return len(self.data)
diff --git a/libs/waitress/tests/test_regression.py b/libs/waitress/tests/test_regression.py
new file mode 100644
index 000000000..3c4c6c202
--- /dev/null
+++ b/libs/waitress/tests/test_regression.py
@@ -0,0 +1,147 @@
+##############################################################################
+#
+# Copyright (c) 2005 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""Tests for waitress.channel maintenance logic
+"""
+import doctest
+
+
+class FakeSocket: # pragma: no cover
+ data = ""
+ setblocking = lambda *_: None
+ close = lambda *_: None
+
+ def __init__(self, no):
+ self.no = no
+
+ def fileno(self):
+ return self.no
+
+ def getpeername(self):
+ return ("localhost", self.no)
+
+ def send(self, data):
+ self.data += data
+ return len(data)
+
+ def recv(self, data):
+ return "data"
+
+
+def zombies_test():
+ """Regression test for HTTPChannel.maintenance method
+
+ Bug: This method checks for channels that have been "inactive" for a
+ configured time. The bug was that last_activity is set at creation time
+ but never updated during async channel activity (reads and writes), so
+ any channel older than the configured timeout will be closed when a new
+ channel is created, regardless of activity.
+
+ >>> import time
+ >>> import waitress.adjustments
+ >>> config = waitress.adjustments.Adjustments()
+
+ >>> from waitress.server import HTTPServer
+ >>> class TestServer(HTTPServer):
+ ... def bind(self, addr):
+ ...     print("Listening on %s:%d" % (addr[0] or '*', addr[1]))
+ >>> sb = TestServer('127.0.0.1', 80, start=False, verbose=True)
+ Listening on 127.0.0.1:80
+
+ First we confirm the correct behavior, where a channel with no activity
+ for the timeout duration gets closed.
+
+ >>> from waitress.channel import HTTPChannel
+ >>> socket = FakeSocket(42)
+ >>> channel = HTTPChannel(sb, socket, ('localhost', 42))
+
+ >>> channel.connected
+ True
+
+ >>> channel.last_activity -= int(config.channel_timeout) + 1
+
+ >>> channel.next_channel_cleanup[0] = channel.creation_time - int(
+ ... config.cleanup_interval) - 1
+
+ >>> socket2 = FakeSocket(7)
+ >>> channel2 = HTTPChannel(sb, socket2, ('localhost', 7))
+
+ >>> channel.connected
+ False
+
+ Write Activity
+ --------------
+
+ Now we make sure that if there is activity the channel doesn't get closed
+ incorrectly.
+
+ >>> channel2.connected
+ True
+
+ >>> channel2.last_activity -= int(config.channel_timeout) + 1
+
+ >>> channel2.handle_write()
+
+ >>> channel2.next_channel_cleanup[0] = channel2.creation_time - int(
+ ... config.cleanup_interval) - 1
+
+ >>> socket3 = FakeSocket(3)
+ >>> channel3 = HTTPChannel(sb, socket3, ('localhost', 3))
+
+ >>> channel2.connected
+ True
+
+ Read Activity
+ --------------
+
+ We should test to see that read activity will update a channel as well.
+
+ >>> channel3.connected
+ True
+
+ >>> channel3.last_activity -= int(config.channel_timeout) + 1
+
+ >>> import waitress.parser
+ >>> channel3.parser_class = (
+ ... waitress.parser.HTTPRequestParser)
+ >>> channel3.handle_read()
+
+ >>> channel3.next_channel_cleanup[0] = channel3.creation_time - int(
+ ... config.cleanup_interval) - 1
+
+ >>> socket4 = FakeSocket(4)
+ >>> channel4 = HTTPChannel(sb, socket4, ('localhost', 4))
+
+ >>> channel3.connected
+ True
+
+ Main loop window
+ ----------------
+
+ There is also a corner case we'll do a shallow test for where a
+ channel can be closed waiting for the main loop.
+
+ >>> channel4.last_activity -= 1
+
+ >>> last_active = channel4.last_activity
+
+ >>> channel4.set_async()
+
+ >>> channel4.last_activity != last_active
+ True
+
+"""
+
+
+def test_suite():
+ return doctest.DocTestSuite()
diff --git a/libs/waitress/tests/test_runner.py b/libs/waitress/tests/test_runner.py
new file mode 100644
index 000000000..127757e15
--- /dev/null
+++ b/libs/waitress/tests/test_runner.py
@@ -0,0 +1,191 @@
+import contextlib
+import os
+import sys
+
+if sys.version_info[:2] == (2, 6): # pragma: no cover
+ import unittest2 as unittest
+else: # pragma: no cover
+ import unittest
+
+from waitress import runner
+
+
+class Test_match(unittest.TestCase):
+ def test_empty(self):
+ self.assertRaisesRegexp(
+ ValueError, "^Malformed application ''$", runner.match, ""
+ )
+
+ def test_module_only(self):
+ self.assertRaisesRegexp(
+ ValueError, r"^Malformed application 'foo\.bar'$", runner.match, "foo.bar"
+ )
+
+ def test_bad_module(self):
+ self.assertRaisesRegexp(
+ ValueError,
+ r"^Malformed application 'foo#bar:barney'$",
+ runner.match,
+ "foo#bar:barney",
+ )
+
+ def test_module_obj(self):
+ self.assertTupleEqual(
+ runner.match("foo.bar:fred.barney"), ("foo.bar", "fred.barney")
+ )
+
+
+class Test_resolve(unittest.TestCase):
+ def test_bad_module(self):
+ self.assertRaises(
+ ImportError, runner.resolve, "nonexistent", "nonexistent_function"
+ )
+
+ def test_nonexistent_function(self):
+ self.assertRaisesRegexp(
+ AttributeError,
+ r"has no attribute 'nonexistent_function'",
+ runner.resolve,
+ "os.path",
+ "nonexistent_function",
+ )
+
+ def test_simple_happy_path(self):
+ from os.path import exists
+
+ self.assertIs(runner.resolve("os.path", "exists"), exists)
+
+ def test_complex_happy_path(self):
+ # Ensure we can recursively resolve object attributes if necessary.
+ self.assertEquals(runner.resolve("os.path", "exists.__name__"), "exists")
+
+
+class Test_run(unittest.TestCase):
+ def match_output(self, argv, code, regex):
+ argv = ["waitress-serve"] + argv
+ with capture() as captured:
+ self.assertEqual(runner.run(argv=argv), code)
+ self.assertRegexpMatches(captured.getvalue(), regex)
+ captured.close()
+
+ def test_bad(self):
+ self.match_output(["--bad-opt"], 1, "^Error: option --bad-opt not recognized")
+
+ def test_help(self):
+ self.match_output(["--help"], 0, "^Usage:\n\n waitress-serve")
+
+ def test_no_app(self):
+ self.match_output([], 1, "^Error: Specify one application only")
+
+ def test_multiple_apps_app(self):
+ self.match_output(["a:a", "b:b"], 1, "^Error: Specify one application only")
+
+ def test_bad_apps_app(self):
+ self.match_output(["a"], 1, "^Error: Malformed application 'a'")
+
+ def test_bad_app_module(self):
+ self.match_output(["nonexistent:a"], 1, "^Error: Bad module 'nonexistent'")
+
+ self.match_output(
+ ["nonexistent:a"],
+ 1,
+ (
+ r"There was an exception \((ImportError|ModuleNotFoundError)\) "
+ "importing your module.\n\nIt had these arguments: \n"
+ "1. No module named '?nonexistent'?"
+ ),
+ )
+
+ def test_cwd_added_to_path(self):
+ def null_serve(app, **kw):
+ pass
+
+ sys_path = sys.path
+ current_dir = os.getcwd()
+ try:
+ os.chdir(os.path.dirname(__file__))
+ argv = [
+ "waitress-serve",
+ "fixtureapps.runner:app",
+ ]
+ self.assertEqual(runner.run(argv=argv, _serve=null_serve), 0)
+ finally:
+ sys.path = sys_path
+ os.chdir(current_dir)
+
+ def test_bad_app_object(self):
+ self.match_output(
+ ["waitress.tests.fixtureapps.runner:a"], 1, "^Error: Bad object name 'a'"
+ )
+
+ def test_simple_call(self):
+ import waitress.tests.fixtureapps.runner as _apps
+
+ def check_server(app, **kw):
+ self.assertIs(app, _apps.app)
+ self.assertDictEqual(kw, {"port": "80"})
+
+ argv = [
+ "waitress-serve",
+ "--port=80",
+ "waitress.tests.fixtureapps.runner:app",
+ ]
+ self.assertEqual(runner.run(argv=argv, _serve=check_server), 0)
+
+ def test_returned_app(self):
+ import waitress.tests.fixtureapps.runner as _apps
+
+ def check_server(app, **kw):
+ self.assertIs(app, _apps.app)
+ self.assertDictEqual(kw, {"port": "80"})
+
+ argv = [
+ "waitress-serve",
+ "--port=80",
+ "--call",
+ "waitress.tests.fixtureapps.runner:returns_app",
+ ]
+ self.assertEqual(runner.run(argv=argv, _serve=check_server), 0)
+
+
+class Test_helper(unittest.TestCase):
+ def test_exception_logging(self):
+ from waitress.runner import show_exception
+
+ regex = (
+ r"There was an exception \(ImportError\) importing your module."
+ r"\n\nIt had these arguments: \n1. My reason"
+ )
+
+ with capture() as captured:
+ try:
+ raise ImportError("My reason")
+ except ImportError:
+ self.assertEqual(show_exception(sys.stderr), None)
+ self.assertRegexpMatches(captured.getvalue(), regex)
+ captured.close()
+
+ regex = (
+ r"There was an exception \(ImportError\) importing your module."
+ r"\n\nIt had no arguments."
+ )
+
+ with capture() as captured:
+ try:
+ raise ImportError
+ except ImportError:
+ self.assertEqual(show_exception(sys.stderr), None)
+ self.assertRegexpMatches(captured.getvalue(), regex)
+ captured.close()
+
+
+@contextlib.contextmanager
+def capture():
+ from waitress.compat import NativeIO
+
+ fd = NativeIO()
+ sys.stdout = fd
+ sys.stderr = fd
+ yield fd
+ sys.stdout = sys.__stdout__
+ sys.stderr = sys.__stderr__
diff --git a/libs/waitress/tests/test_server.py b/libs/waitress/tests/test_server.py
new file mode 100644
index 000000000..9134fb8c1
--- /dev/null
+++ b/libs/waitress/tests/test_server.py
@@ -0,0 +1,533 @@
+import errno
+import socket
+import unittest
+
+dummy_app = object()
+
+
+class TestWSGIServer(unittest.TestCase):
+ def _makeOne(
+ self,
+ application=dummy_app,
+ host="127.0.0.1",
+ port=0,
+ _dispatcher=None,
+ adj=None,
+ map=None,
+ _start=True,
+ _sock=None,
+ _server=None,
+ ):
+ from waitress.server import create_server
+
+ self.inst = create_server(
+ application,
+ host=host,
+ port=port,
+ map=map,
+ _dispatcher=_dispatcher,
+ _start=_start,
+ _sock=_sock,
+ )
+ return self.inst
+
+ def _makeOneWithMap(
+ self, adj=None, _start=True, host="127.0.0.1", port=0, app=dummy_app
+ ):
+ sock = DummySock()
+ task_dispatcher = DummyTaskDispatcher()
+ map = {}
+ return self._makeOne(
+ app,
+ host=host,
+ port=port,
+ map=map,
+ _sock=sock,
+ _dispatcher=task_dispatcher,
+ _start=_start,
+ )
+
+ def _makeOneWithMulti(
+ self, adj=None, _start=True, app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0"
+ ):
+ sock = DummySock()
+ task_dispatcher = DummyTaskDispatcher()
+ map = {}
+ from waitress.server import create_server
+
+ self.inst = create_server(
+ app,
+ listen=listen,
+ map=map,
+ _dispatcher=task_dispatcher,
+ _start=_start,
+ _sock=sock,
+ )
+ return self.inst
+
+ def _makeWithSockets(
+ self,
+ application=dummy_app,
+ _dispatcher=None,
+ map=None,
+ _start=True,
+ _sock=None,
+ _server=None,
+ sockets=None,
+ ):
+ from waitress.server import create_server
+
+ _sockets = []
+ if sockets is not None:
+ _sockets = sockets
+ self.inst = create_server(
+ application,
+ map=map,
+ _dispatcher=_dispatcher,
+ _start=_start,
+ _sock=_sock,
+ sockets=_sockets,
+ )
+ return self.inst
+
+ def tearDown(self):
+ if self.inst is not None:
+ self.inst.close()
+
+ def test_ctor_app_is_None(self):
+ self.inst = None
+ self.assertRaises(ValueError, self._makeOneWithMap, app=None)
+
+ def test_ctor_start_true(self):
+ inst = self._makeOneWithMap(_start=True)
+ self.assertEqual(inst.accepting, True)
+ self.assertEqual(inst.socket.listened, 1024)
+
+ def test_ctor_makes_dispatcher(self):
+ inst = self._makeOne(_start=False, map={})
+ self.assertEqual(
+ inst.task_dispatcher.__class__.__name__, "ThreadedTaskDispatcher"
+ )
+
+ def test_ctor_start_false(self):
+ inst = self._makeOneWithMap(_start=False)
+ self.assertEqual(inst.accepting, False)
+
+ def test_get_server_name_empty(self):
+ inst = self._makeOneWithMap(_start=False)
+ self.assertRaises(ValueError, inst.get_server_name, "")
+
+ def test_get_server_name_with_ip(self):
+ inst = self._makeOneWithMap(_start=False)
+ result = inst.get_server_name("127.0.0.1")
+ self.assertTrue(result)
+
+ def test_get_server_name_with_hostname(self):
+ inst = self._makeOneWithMap(_start=False)
+ result = inst.get_server_name("fred.flintstone.com")
+ self.assertEqual(result, "fred.flintstone.com")
+
+ def test_get_server_name_0000(self):
+ inst = self._makeOneWithMap(_start=False)
+ result = inst.get_server_name("0.0.0.0")
+ self.assertTrue(len(result) != 0)
+
+ def test_get_server_name_double_colon(self):
+ inst = self._makeOneWithMap(_start=False)
+ result = inst.get_server_name("::")
+ self.assertTrue(len(result) != 0)
+
+ def test_get_server_name_ipv6(self):
+ inst = self._makeOneWithMap(_start=False)
+ result = inst.get_server_name("2001:DB8::ffff")
+ self.assertEqual("[2001:DB8::ffff]", result)
+
+ def test_get_server_multi(self):
+ inst = self._makeOneWithMulti()
+ self.assertEqual(inst.__class__.__name__, "MultiSocketServer")
+
+ def test_run(self):
+ inst = self._makeOneWithMap(_start=False)
+ inst.asyncore = DummyAsyncore()
+ inst.task_dispatcher = DummyTaskDispatcher()
+ inst.run()
+ self.assertTrue(inst.task_dispatcher.was_shutdown)
+
+ def test_run_base_server(self):
+ inst = self._makeOneWithMulti(_start=False)
+ inst.asyncore = DummyAsyncore()
+ inst.task_dispatcher = DummyTaskDispatcher()
+ inst.run()
+ self.assertTrue(inst.task_dispatcher.was_shutdown)
+
+ def test_pull_trigger(self):
+ inst = self._makeOneWithMap(_start=False)
+ inst.trigger.close()
+ inst.trigger = DummyTrigger()
+ inst.pull_trigger()
+ self.assertEqual(inst.trigger.pulled, True)
+
+ def test_add_task(self):
+ task = DummyTask()
+ inst = self._makeOneWithMap()
+ inst.add_task(task)
+ self.assertEqual(inst.task_dispatcher.tasks, [task])
+ self.assertFalse(task.serviced)
+
+ def test_readable_not_accepting(self):
+ inst = self._makeOneWithMap()
+ inst.accepting = False
+ self.assertFalse(inst.readable())
+
+ def test_readable_maplen_gt_connection_limit(self):
+ inst = self._makeOneWithMap()
+ inst.accepting = True
+ inst.adj = DummyAdj
+ inst._map = {"a": 1, "b": 2}
+ self.assertFalse(inst.readable())
+
+ def test_readable_maplen_lt_connection_limit(self):
+ inst = self._makeOneWithMap()
+ inst.accepting = True
+ inst.adj = DummyAdj
+ inst._map = {}
+ self.assertTrue(inst.readable())
+
+ def test_readable_maintenance_false(self):
+ import time
+
+ inst = self._makeOneWithMap()
+ then = time.time() + 1000
+ inst.next_channel_cleanup = then
+ L = []
+ inst.maintenance = lambda t: L.append(t)
+ inst.readable()
+ self.assertEqual(L, [])
+ self.assertEqual(inst.next_channel_cleanup, then)
+
+ def test_readable_maintenance_true(self):
+ inst = self._makeOneWithMap()
+ inst.next_channel_cleanup = 0
+ L = []
+ inst.maintenance = lambda t: L.append(t)
+ inst.readable()
+ self.assertEqual(len(L), 1)
+ self.assertNotEqual(inst.next_channel_cleanup, 0)
+
+ def test_writable(self):
+ inst = self._makeOneWithMap()
+ self.assertFalse(inst.writable())
+
+ def test_handle_read(self):
+ inst = self._makeOneWithMap()
+ self.assertEqual(inst.handle_read(), None)
+
+ def test_handle_connect(self):
+ inst = self._makeOneWithMap()
+ self.assertEqual(inst.handle_connect(), None)
+
+ def test_handle_accept_wouldblock_socket_error(self):
+ inst = self._makeOneWithMap()
+ ewouldblock = socket.error(errno.EWOULDBLOCK)
+ inst.socket = DummySock(toraise=ewouldblock)
+ inst.handle_accept()
+ self.assertEqual(inst.socket.accepted, False)
+
+ def test_handle_accept_other_socket_error(self):
+ inst = self._makeOneWithMap()
+ eaborted = socket.error(errno.ECONNABORTED)
+ inst.socket = DummySock(toraise=eaborted)
+ inst.adj = DummyAdj
+
+ def foo():
+ raise socket.error
+
+ inst.accept = foo
+ inst.logger = DummyLogger()
+ inst.handle_accept()
+ self.assertEqual(inst.socket.accepted, False)
+ self.assertEqual(len(inst.logger.logged), 1)
+
+ def test_handle_accept_noerror(self):
+ inst = self._makeOneWithMap()
+ innersock = DummySock()
+ inst.socket = DummySock(acceptresult=(innersock, None))
+ inst.adj = DummyAdj
+ L = []
+ inst.channel_class = lambda *arg, **kw: L.append(arg)
+ inst.handle_accept()
+ self.assertEqual(inst.socket.accepted, True)
+ self.assertEqual(innersock.opts, [("level", "optname", "value")])
+ self.assertEqual(L, [(inst, innersock, None, inst.adj)])
+
+ def test_maintenance(self):
+ inst = self._makeOneWithMap()
+
+ class DummyChannel(object):
+ requests = []
+
+ zombie = DummyChannel()
+ zombie.last_activity = 0
+ zombie.running_tasks = False
+ inst.active_channels[100] = zombie
+ inst.maintenance(10000)
+ self.assertEqual(zombie.will_close, True)
+
+ def test_backward_compatibility(self):
+ from waitress.server import WSGIServer, TcpWSGIServer
+ from waitress.adjustments import Adjustments
+
+ self.assertTrue(WSGIServer is TcpWSGIServer)
+ self.inst = WSGIServer(None, _start=False, port=1234)
+ # Ensure the adjustment was actually applied.
+ self.assertNotEqual(Adjustments.port, 1234)
+ self.assertEqual(self.inst.adj.port, 1234)
+
+ def test_create_with_one_tcp_socket(self):
+ from waitress.server import TcpWSGIServer
+
+ sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)]
+ sockets[0].bind(("127.0.0.1", 0))
+ inst = self._makeWithSockets(_start=False, sockets=sockets)
+ self.assertTrue(isinstance(inst, TcpWSGIServer))
+
+ def test_create_with_multiple_tcp_sockets(self):
+ from waitress.server import MultiSocketServer
+
+ sockets = [
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM),
+ socket.socket(socket.AF_INET, socket.SOCK_STREAM),
+ ]
+ sockets[0].bind(("127.0.0.1", 0))
+ sockets[1].bind(("127.0.0.1", 0))
+ inst = self._makeWithSockets(_start=False, sockets=sockets)
+ self.assertTrue(isinstance(inst, MultiSocketServer))
+ self.assertEqual(len(inst.effective_listen), 2)
+
+ def test_create_with_one_socket_should_not_bind_socket(self):
+ innersock = DummySock()
+ sockets = [DummySock(acceptresult=(innersock, None))]
+ sockets[0].bind(("127.0.0.1", 80))
+ sockets[0].bind_called = False
+ inst = self._makeWithSockets(_start=False, sockets=sockets)
+ self.assertEqual(inst.socket.bound, ("127.0.0.1", 80))
+ self.assertFalse(inst.socket.bind_called)
+
+ def test_create_with_one_socket_handle_accept_noerror(self):
+ innersock = DummySock()
+ sockets = [DummySock(acceptresult=(innersock, None))]
+ sockets[0].bind(("127.0.0.1", 80))
+ inst = self._makeWithSockets(sockets=sockets)
+ L = []
+ inst.channel_class = lambda *arg, **kw: L.append(arg)
+ inst.adj = DummyAdj
+ inst.handle_accept()
+ self.assertEqual(sockets[0].accepted, True)
+ self.assertEqual(innersock.opts, [("level", "optname", "value")])
+ self.assertEqual(L, [(inst, innersock, None, inst.adj)])
+
+
+if hasattr(socket, "AF_UNIX"):
+
+ class TestUnixWSGIServer(unittest.TestCase):
+ unix_socket = "/tmp/waitress.test.sock"
+
+ def _makeOne(self, _start=True, _sock=None):
+ from waitress.server import create_server
+
+ self.inst = create_server(
+ dummy_app,
+ map={},
+ _start=_start,
+ _sock=_sock,
+ _dispatcher=DummyTaskDispatcher(),
+ unix_socket=self.unix_socket,
+ unix_socket_perms="600",
+ )
+ return self.inst
+
+ def _makeWithSockets(
+ self,
+ application=dummy_app,
+ _dispatcher=None,
+ map=None,
+ _start=True,
+ _sock=None,
+ _server=None,
+ sockets=None,
+ ):
+ from waitress.server import create_server
+
+ _sockets = []
+ if sockets is not None:
+ _sockets = sockets
+ self.inst = create_server(
+ application,
+ map=map,
+ _dispatcher=_dispatcher,
+ _start=_start,
+ _sock=_sock,
+ sockets=_sockets,
+ )
+ return self.inst
+
+ def tearDown(self):
+ self.inst.close()
+
+ def _makeDummy(self, *args, **kwargs):
+ sock = DummySock(*args, **kwargs)
+ sock.family = socket.AF_UNIX
+ return sock
+
+ def test_unix(self):
+ inst = self._makeOne(_start=False)
+ self.assertEqual(inst.socket.family, socket.AF_UNIX)
+ self.assertEqual(inst.socket.getsockname(), self.unix_socket)
+
+ def test_handle_accept(self):
+ # Working on the assumption that we only have to test the happy path
+ # for Unix domain sockets as the other paths should've been covered
+ # by inet sockets.
+ client = self._makeDummy()
+ listen = self._makeDummy(acceptresult=(client, None))
+ inst = self._makeOne(_sock=listen)
+ self.assertEqual(inst.accepting, True)
+ self.assertEqual(inst.socket.listened, 1024)
+ L = []
+ inst.channel_class = lambda *arg, **kw: L.append(arg)
+ inst.handle_accept()
+ self.assertEqual(inst.socket.accepted, True)
+ self.assertEqual(client.opts, [])
+ self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)])
+
+ def test_creates_new_sockinfo(self):
+ from waitress.server import UnixWSGIServer
+
+ self.inst = UnixWSGIServer(
+ dummy_app, unix_socket=self.unix_socket, unix_socket_perms="600"
+ )
+
+ self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX)
+
+ def test_create_with_unix_socket(self):
+ from waitress.server import (
+ MultiSocketServer,
+ BaseWSGIServer,
+ TcpWSGIServer,
+ UnixWSGIServer,
+ )
+
+ sockets = [
+ socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
+ socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
+ ]
+ inst = self._makeWithSockets(sockets=sockets, _start=False)
+ self.assertTrue(isinstance(inst, MultiSocketServer))
+ server = list(
+ filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values())
+ )
+ self.assertTrue(isinstance(server[0], UnixWSGIServer))
+ self.assertTrue(isinstance(server[1], UnixWSGIServer))
+
+
+class DummySock(socket.socket):
+ accepted = False
+ blocking = False
+ family = socket.AF_INET
+ type = socket.SOCK_STREAM
+ proto = 0
+
+ def __init__(self, toraise=None, acceptresult=(None, None)):
+ self.toraise = toraise
+ self.acceptresult = acceptresult
+ self.bound = None
+ self.opts = []
+ self.bind_called = False
+
+ def bind(self, addr):
+ self.bind_called = True
+ self.bound = addr
+
+ def accept(self):
+ if self.toraise:
+ raise self.toraise
+ self.accepted = True
+ return self.acceptresult
+
+ def setblocking(self, x):
+ self.blocking = True
+
+ def fileno(self):
+ return 10
+
+ def getpeername(self):
+ return "127.0.0.1"
+
+ def setsockopt(self, *arg):
+ self.opts.append(arg)
+
+ def getsockopt(self, *arg):
+ return 1
+
+ def listen(self, num):
+ self.listened = num
+
+ def getsockname(self):
+ return self.bound
+
+ def close(self):
+ pass
+
+
+class DummyTaskDispatcher(object):
+ def __init__(self):
+ self.tasks = []
+
+ def add_task(self, task):
+ self.tasks.append(task)
+
+ def shutdown(self):
+ self.was_shutdown = True
+
+
+class DummyTask(object):
+ serviced = False
+ start_response_called = False
+ wrote_header = False
+ status = "200 OK"
+
+ def __init__(self):
+ self.response_headers = {}
+ self.written = ""
+
+ def service(self): # pragma: no cover
+ self.serviced = True
+
+
+class DummyAdj:
+ connection_limit = 1
+ log_socket_errors = True
+ socket_options = [("level", "optname", "value")]
+ cleanup_interval = 900
+ channel_timeout = 300
+
+
+class DummyAsyncore(object):
+ def loop(self, timeout=30.0, use_poll=False, map=None, count=None):
+ raise SystemExit
+
+
+class DummyTrigger(object):
+ def pull_trigger(self):
+ self.pulled = True
+
+ def close(self):
+ pass
+
+
+class DummyLogger(object):
+ def __init__(self):
+ self.logged = []
+
+ def warning(self, msg, **kw):
+ self.logged.append(msg)
diff --git a/libs/waitress/tests/test_task.py b/libs/waitress/tests/test_task.py
new file mode 100644
index 000000000..1a86245ab
--- /dev/null
+++ b/libs/waitress/tests/test_task.py
@@ -0,0 +1,1001 @@
+import unittest
+import io
+
+
class TestThreadedTaskDispatcher(unittest.TestCase):
    """Unit tests for waitress.task.ThreadedTaskDispatcher.

    The dispatcher's thread pool is never really started: tests stub
    ``start_new_thread`` or pre-populate ``threads``/``queue`` and then
    inspect the bookkeeping counters directly.
    """

    def _makeOne(self):
        from waitress.task import ThreadedTaskDispatcher

        return ThreadedTaskDispatcher()

    def test_handler_thread_task_raises(self):
        inst = self._makeOne()
        inst.threads.add(0)
        inst.logger = DummyLogger()

        # Task that bumps stop_count and then blows up inside service().
        class BadDummyTask(DummyTask):
            def service(self):
                super(BadDummyTask, self).service()
                inst.stop_count += 1
                raise Exception

        task = BadDummyTask()
        inst.logger = DummyLogger()
        inst.queue.append(task)
        inst.active_count += 1
        inst.handler_thread(0)
        # The handler must recover: counters reset, thread removed,
        # and exactly one message logged for the failed task.
        self.assertEqual(inst.stop_count, 0)
        self.assertEqual(inst.active_count, 0)
        self.assertEqual(inst.threads, set())
        self.assertEqual(len(inst.logger.logged), 1)

    def test_set_thread_count_increase(self):
        inst = self._makeOne()
        L = []
        inst.start_new_thread = lambda *x: L.append(x)
        inst.set_thread_count(1)
        self.assertEqual(L, [(inst.handler_thread, (0,))])

    def test_set_thread_count_increase_with_existing(self):
        inst = self._makeOne()
        L = []
        inst.threads = {0}
        inst.start_new_thread = lambda *x: L.append(x)
        inst.set_thread_count(2)
        # Only the missing thread (id 1) is spawned.
        self.assertEqual(L, [(inst.handler_thread, (1,))])

    def test_set_thread_count_decrease(self):
        inst = self._makeOne()
        inst.threads = {0, 1}
        inst.set_thread_count(1)
        self.assertEqual(inst.stop_count, 1)

    def test_set_thread_count_same(self):
        inst = self._makeOne()
        L = []
        inst.start_new_thread = lambda *x: L.append(x)
        inst.threads = {0}
        inst.set_thread_count(1)
        self.assertEqual(L, [])

    def test_add_task_with_idle_threads(self):
        task = DummyTask()
        inst = self._makeOne()
        inst.threads.add(0)
        inst.queue_logger = DummyLogger()
        inst.add_task(task)
        self.assertEqual(len(inst.queue), 1)
        self.assertEqual(len(inst.queue_logger.logged), 0)

    def test_add_task_with_all_busy_threads(self):
        task = DummyTask()
        inst = self._makeOne()
        inst.queue_logger = DummyLogger()
        # With no idle threads, each add_task logs a queue-depth warning.
        inst.add_task(task)
        self.assertEqual(len(inst.queue_logger.logged), 1)
        inst.add_task(task)
        self.assertEqual(len(inst.queue_logger.logged), 2)

    def test_shutdown_one_thread(self):
        inst = self._makeOne()
        inst.threads.add(0)
        inst.logger = DummyLogger()
        task = DummyTask()
        inst.queue.append(task)
        self.assertEqual(inst.shutdown(timeout=0.01), True)
        self.assertEqual(
            inst.logger.logged,
            ["1 thread(s) still running", "Canceling 1 pending task(s)",],
        )
        self.assertEqual(task.cancelled, True)

    def test_shutdown_no_threads(self):
        inst = self._makeOne()
        self.assertEqual(inst.shutdown(timeout=0.01), True)

    def test_shutdown_no_cancel_pending(self):
        inst = self._makeOne()
        self.assertEqual(inst.shutdown(cancel_pending=False, timeout=0.01), False)
+
+
class TestTask(unittest.TestCase):
    """Unit tests for waitress.task.Task.

    Covers response-header construction for HTTP/1.0 and HTTP/1.1
    (keep-alive vs. close, Content-Length vs. chunked), plus the
    write()/finish() plumbing against a DummyChannel.
    """

    def _makeOne(self, channel=None, request=None):
        if channel is None:
            channel = DummyChannel()
        if request is None:
            request = DummyParser()
        from waitress.task import Task

        return Task(channel, request)

    def test_ctor_version_not_in_known(self):
        request = DummyParser()
        request.version = "8.4"
        inst = self._makeOne(request=request)
        # Unknown HTTP versions are coerced to 1.0.
        self.assertEqual(inst.version, "1.0")

    def test_build_response_header_bad_http_version(self):
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "8.4"
        self.assertRaises(AssertionError, inst.build_response_header)

    def test_build_response_header_v10_keepalive_no_content_length(self):
        # HTTP/1.0 keep-alive cannot be honored without a Content-Length.
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.request.headers["CONNECTION"] = "keep-alive"
        inst.version = "1.0"
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 4)
        self.assertEqual(lines[0], b"HTTP/1.0 200 OK")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")
        self.assertEqual(inst.close_on_finish, True)
        self.assertTrue(("Connection", "close") in inst.response_headers)

    def test_build_response_header_v10_keepalive_with_content_length(self):
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.request.headers["CONNECTION"] = "keep-alive"
        inst.response_headers = [("Content-Length", "10")]
        inst.version = "1.0"
        inst.content_length = 0
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 5)
        self.assertEqual(lines[0], b"HTTP/1.0 200 OK")
        self.assertEqual(lines[1], b"Connection: Keep-Alive")
        self.assertEqual(lines[2], b"Content-Length: 10")
        self.assertTrue(lines[3].startswith(b"Date:"))
        self.assertEqual(lines[4], b"Server: waitress")
        self.assertEqual(inst.close_on_finish, False)

    def test_build_response_header_v11_connection_closed_by_client(self):
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.1"
        inst.request.headers["CONNECTION"] = "close"
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 5)
        self.assertEqual(lines[0], b"HTTP/1.1 200 OK")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")
        self.assertEqual(lines[4], b"Transfer-Encoding: chunked")
        self.assertTrue(("Connection", "close") in inst.response_headers)
        self.assertEqual(inst.close_on_finish, True)

    def test_build_response_header_v11_connection_keepalive_by_client(self):
        # No Content-Length: the connection must close even if the
        # client asked for keep-alive.
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.request.headers["CONNECTION"] = "keep-alive"
        inst.version = "1.1"
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 5)
        self.assertEqual(lines[0], b"HTTP/1.1 200 OK")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")
        self.assertEqual(lines[4], b"Transfer-Encoding: chunked")
        self.assertTrue(("Connection", "close") in inst.response_headers)
        self.assertEqual(inst.close_on_finish, True)

    def test_build_response_header_v11_200_no_content_length(self):
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.1"
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 5)
        self.assertEqual(lines[0], b"HTTP/1.1 200 OK")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")
        self.assertEqual(lines[4], b"Transfer-Encoding: chunked")
        self.assertEqual(inst.close_on_finish, True)
        self.assertTrue(("Connection", "close") in inst.response_headers)

    def test_build_response_header_v11_204_no_content_length_or_transfer_encoding(self):
        # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
        # for any response with a status code of 1xx or 204.
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.1"
        inst.status = "204 No Content"
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 4)
        self.assertEqual(lines[0], b"HTTP/1.1 204 No Content")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")
        self.assertEqual(inst.close_on_finish, True)
        self.assertTrue(("Connection", "close") in inst.response_headers)

    def test_build_response_header_v11_1xx_no_content_length_or_transfer_encoding(self):
        # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
        # for any response with a status code of 1xx or 204.
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.1"
        inst.status = "100 Continue"
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 4)
        self.assertEqual(lines[0], b"HTTP/1.1 100 Continue")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")
        self.assertEqual(inst.close_on_finish, True)
        self.assertTrue(("Connection", "close") in inst.response_headers)

    def test_build_response_header_v11_304_no_content_length_or_transfer_encoding(self):
        # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
        # for any response with a status code of 1xx, 204 or 304.
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.1"
        inst.status = "304 Not Modified"
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 4)
        self.assertEqual(lines[0], b"HTTP/1.1 304 Not Modified")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")
        self.assertEqual(inst.close_on_finish, True)
        self.assertTrue(("Connection", "close") in inst.response_headers)

    def test_build_response_header_via_added(self):
        # When the app supplies its own Server header, waitress
        # identifies itself via a Via header instead.
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.0"
        inst.response_headers = [("Server", "abc")]
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 5)
        self.assertEqual(lines[0], b"HTTP/1.0 200 OK")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: abc")
        self.assertEqual(lines[4], b"Via: waitress")

    def test_build_response_header_date_exists(self):
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.0"
        inst.response_headers = [("Date", "date")]
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 4)
        self.assertEqual(lines[0], b"HTTP/1.0 200 OK")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")

    def test_build_response_header_preexisting_content_length(self):
        inst = self._makeOne()
        inst.request = DummyParser()
        inst.version = "1.1"
        inst.content_length = 100
        result = inst.build_response_header()
        lines = filter_lines(result)
        self.assertEqual(len(lines), 4)
        self.assertEqual(lines[0], b"HTTP/1.1 200 OK")
        self.assertEqual(lines[1], b"Content-Length: 100")
        self.assertTrue(lines[2].startswith(b"Date:"))
        self.assertEqual(lines[3], b"Server: waitress")

    def test_remove_content_length_header(self):
        inst = self._makeOne()
        inst.response_headers = [("Content-Length", "70")]
        inst.remove_content_length_header()
        self.assertEqual(inst.response_headers, [])

    def test_remove_content_length_header_with_other(self):
        inst = self._makeOne()
        inst.response_headers = [
            ("Content-Length", "70"),
            ("Content-Type", "text/html"),
        ]
        inst.remove_content_length_header()
        self.assertEqual(inst.response_headers, [("Content-Type", "text/html")])

    def test_start(self):
        inst = self._makeOne()
        inst.start()
        self.assertTrue(inst.start_time)

    def test_finish_didnt_write_header(self):
        inst = self._makeOne()
        inst.wrote_header = False
        inst.complete = True
        inst.finish()
        self.assertTrue(inst.channel.written)

    def test_finish_wrote_header(self):
        inst = self._makeOne()
        inst.wrote_header = True
        inst.finish()
        self.assertFalse(inst.channel.written)

    def test_finish_chunked_response(self):
        inst = self._makeOne()
        inst.wrote_header = True
        inst.chunked_response = True
        inst.finish()
        # Chunked responses end with the zero-length terminating chunk.
        self.assertEqual(inst.channel.written, b"0\r\n\r\n")

    def test_write_wrote_header(self):
        inst = self._makeOne()
        inst.wrote_header = True
        inst.complete = True
        inst.content_length = 3
        inst.write(b"abc")
        self.assertEqual(inst.channel.written, b"abc")

    def test_write_header_not_written(self):
        inst = self._makeOne()
        inst.wrote_header = False
        inst.complete = True
        inst.write(b"abc")
        self.assertTrue(inst.channel.written)
        self.assertEqual(inst.wrote_header, True)

    def test_write_start_response_uncalled(self):
        inst = self._makeOne()
        self.assertRaises(RuntimeError, inst.write, b"")

    def test_write_chunked_response(self):
        inst = self._makeOne()
        inst.wrote_header = True
        inst.chunked_response = True
        inst.complete = True
        inst.write(b"abc")
        # Chunk framing: hex length, CRLF, payload, CRLF.
        self.assertEqual(inst.channel.written, b"3\r\nabc\r\n")

    def test_write_preexisting_content_length(self):
        # Writing more bytes than the declared Content-Length logs once.
        inst = self._makeOne()
        inst.wrote_header = True
        inst.complete = True
        inst.content_length = 1
        inst.logger = DummyLogger()
        inst.write(b"abc")
        self.assertTrue(inst.channel.written)
        self.assertEqual(inst.logged_write_excess, True)
        self.assertEqual(len(inst.logger.logged), 1)
+
+
class TestWSGITask(unittest.TestCase):
    """Unit tests for waitress.task.WSGITask.

    Exercises service()/execute() against small in-test WSGI apps
    (installed as ``inst.channel.server.application``) and the WSGI
    environ built by get_environment().
    """

    def _makeOne(self, channel=None, request=None):
        if channel is None:
            channel = DummyChannel()
        if request is None:
            request = DummyParser()
        from waitress.task import WSGITask

        return WSGITask(channel, request)

    def test_service(self):
        inst = self._makeOne()

        def execute():
            inst.executed = True

        inst.execute = execute
        inst.complete = True
        inst.service()
        self.assertTrue(inst.start_time)
        self.assertTrue(inst.close_on_finish)
        self.assertTrue(inst.channel.written)
        self.assertEqual(inst.executed, True)

    def test_service_server_raises_socket_error(self):
        import socket

        inst = self._makeOne()

        def execute():
            raise socket.error

        inst.execute = execute
        self.assertRaises(socket.error, inst.service)
        self.assertTrue(inst.start_time)
        self.assertTrue(inst.close_on_finish)
        self.assertFalse(inst.channel.written)

    def test_execute_app_calls_start_response_twice_wo_exc_info(self):
        # PEP 3333: a second start_response call without exc_info is an error.
        def app(environ, start_response):
            start_response("200 OK", [])
            start_response("200 OK", [])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(AssertionError, inst.execute)

    def test_execute_app_calls_start_response_w_exc_info_complete(self):
        def app(environ, start_response):
            start_response("200 OK", [], [ValueError, ValueError(), None])
            return [b"a"]

        inst = self._makeOne()
        inst.complete = True
        inst.channel.server.application = app
        inst.execute()
        self.assertTrue(inst.complete)
        self.assertEqual(inst.status, "200 OK")
        self.assertTrue(inst.channel.written)

    def test_execute_app_calls_start_response_w_excinf_headers_unwritten(self):
        def app(environ, start_response):
            start_response("200 OK", [], [ValueError, None, None])
            return [b"a"]

        inst = self._makeOne()
        inst.wrote_header = False
        inst.channel.server.application = app
        inst.response_headers = [("a", "b")]
        inst.execute()
        self.assertTrue(inst.complete)
        self.assertEqual(inst.status, "200 OK")
        self.assertTrue(inst.channel.written)
        # Headers set before the exc_info restart must be discarded.
        self.assertFalse(("a", "b") in inst.response_headers)

    def test_execute_app_calls_start_response_w_excinf_headers_written(self):
        # PEP 3333: with headers already sent, exc_info must be re-raised.
        def app(environ, start_response):
            start_response("200 OK", [], [ValueError, ValueError(), None])

        inst = self._makeOne()
        inst.complete = True
        inst.wrote_header = True
        inst.channel.server.application = app
        self.assertRaises(ValueError, inst.execute)

    def test_execute_bad_header_key(self):
        def app(environ, start_response):
            start_response("200 OK", [(None, "a")])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(AssertionError, inst.execute)

    def test_execute_bad_header_value(self):
        def app(environ, start_response):
            start_response("200 OK", [("a", None)])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(AssertionError, inst.execute)

    def test_execute_hopbyhop_header(self):
        def app(environ, start_response):
            start_response("200 OK", [("Connection", "close")])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(AssertionError, inst.execute)

    def test_execute_bad_header_value_control_characters(self):
        def app(environ, start_response):
            start_response("200 OK", [("a", "\n")])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(ValueError, inst.execute)

    def test_execute_bad_header_name_control_characters(self):
        def app(environ, start_response):
            start_response("200 OK", [("a\r", "value")])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(ValueError, inst.execute)

    def test_execute_bad_status_control_characters(self):
        def app(environ, start_response):
            start_response("200 OK\r", [])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(ValueError, inst.execute)

    def test_preserve_header_value_order(self):
        def app(environ, start_response):
            write = start_response("200 OK", [("C", "b"), ("A", "b"), ("A", "a")])
            write(b"abc")
            return []

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertTrue(b"A: b\r\nA: a\r\nC: b\r\n" in inst.channel.written)

    def test_execute_bad_status_value(self):
        def app(environ, start_response):
            start_response(None, [])

        inst = self._makeOne()
        inst.channel.server.application = app
        self.assertRaises(AssertionError, inst.execute)

    def test_execute_with_content_length_header(self):
        def app(environ, start_response):
            start_response("200 OK", [("Content-Length", "1")])
            return [b"a"]

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertEqual(inst.content_length, 1)

    def test_execute_app_calls_write(self):
        def app(environ, start_response):
            write = start_response("200 OK", [("Content-Length", "3")])
            write(b"abc")
            return []

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertEqual(inst.channel.written[-3:], b"abc")

    def test_execute_app_returns_len1_chunk_without_cl(self):
        # A single-chunk body lets waitress infer the Content-Length.
        def app(environ, start_response):
            start_response("200 OK", [])
            return [b"abc"]

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertEqual(inst.content_length, 3)

    def test_execute_app_returns_empty_chunk_as_first(self):
        def app(environ, start_response):
            start_response("200 OK", [])
            return ["", b"abc"]

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertEqual(inst.content_length, None)

    def test_execute_app_returns_too_many_bytes(self):
        def app(environ, start_response):
            start_response("200 OK", [("Content-Length", "1")])
            return [b"abc"]

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.logger = DummyLogger()
        inst.execute()
        self.assertEqual(inst.close_on_finish, True)
        self.assertEqual(len(inst.logger.logged), 1)

    def test_execute_app_returns_too_few_bytes(self):
        def app(environ, start_response):
            start_response("200 OK", [("Content-Length", "3")])
            return [b"a"]

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.logger = DummyLogger()
        inst.execute()
        self.assertEqual(inst.close_on_finish, True)
        self.assertEqual(len(inst.logger.logged), 1)

    def test_execute_app_do_not_warn_on_head(self):
        # HEAD responses legitimately omit the body; no warning expected.
        def app(environ, start_response):
            start_response("200 OK", [("Content-Length", "3")])
            return [b""]

        inst = self._makeOne()
        inst.request.command = "HEAD"
        inst.channel.server.application = app
        inst.logger = DummyLogger()
        inst.execute()
        self.assertEqual(inst.close_on_finish, True)
        self.assertEqual(len(inst.logger.logged), 0)

    def test_execute_app_without_body_204_logged(self):
        def app(environ, start_response):
            start_response("204 No Content", [("Content-Length", "3")])
            return [b"abc"]

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.logger = DummyLogger()
        inst.execute()
        self.assertEqual(inst.close_on_finish, True)
        self.assertNotIn(b"abc", inst.channel.written)
        self.assertNotIn(b"Content-Length", inst.channel.written)
        self.assertNotIn(b"Transfer-Encoding", inst.channel.written)
        self.assertEqual(len(inst.logger.logged), 1)

    def test_execute_app_without_body_304_logged(self):
        def app(environ, start_response):
            start_response("304 Not Modified", [("Content-Length", "3")])
            return [b"abc"]

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.logger = DummyLogger()
        inst.execute()
        self.assertEqual(inst.close_on_finish, True)
        self.assertNotIn(b"abc", inst.channel.written)
        self.assertNotIn(b"Content-Length", inst.channel.written)
        self.assertNotIn(b"Transfer-Encoding", inst.channel.written)
        self.assertEqual(len(inst.logger.logged), 1)

    def test_execute_app_returns_closeable(self):
        # PEP 3333: app_iter.close() must be called when present.
        class closeable(list):
            def close(self):
                self.closed = True

        foo = closeable([b"abc"])

        def app(environ, start_response):
            start_response("200 OK", [("Content-Length", "3")])
            return foo

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertEqual(foo.closed, True)

    def test_execute_app_returns_filewrapper_prepare_returns_True(self):
        from waitress.buffers import ReadOnlyFileBasedBuffer

        f = io.BytesIO(b"abc")
        app_iter = ReadOnlyFileBasedBuffer(f, 8192)

        def app(environ, start_response):
            start_response("200 OK", [("Content-Length", "3")])
            return app_iter

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertTrue(inst.channel.written)  # header
        # The buffer itself is handed to the channel, not its bytes.
        self.assertEqual(inst.channel.otherdata, [app_iter])

    def test_execute_app_returns_filewrapper_prepare_returns_True_nocl(self):
        from waitress.buffers import ReadOnlyFileBasedBuffer

        f = io.BytesIO(b"abc")
        app_iter = ReadOnlyFileBasedBuffer(f, 8192)

        def app(environ, start_response):
            start_response("200 OK", [])
            return app_iter

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.execute()
        self.assertTrue(inst.channel.written)  # header
        self.assertEqual(inst.channel.otherdata, [app_iter])
        self.assertEqual(inst.content_length, 3)

    def test_execute_app_returns_filewrapper_prepare_returns_True_badcl(self):
        from waitress.buffers import ReadOnlyFileBasedBuffer

        f = io.BytesIO(b"abc")
        app_iter = ReadOnlyFileBasedBuffer(f, 8192)

        def app(environ, start_response):
            start_response("200 OK", [])
            return app_iter

        inst = self._makeOne()
        inst.channel.server.application = app
        inst.content_length = 10
        inst.response_headers = [("Content-Length", "10")]
        inst.execute()
        self.assertTrue(inst.channel.written)  # header
        self.assertEqual(inst.channel.otherdata, [app_iter])
        # A wrong preexisting Content-Length is corrected to the real size.
        self.assertEqual(inst.content_length, 3)
        self.assertEqual(dict(inst.response_headers)["Content-Length"], "3")

    def test_get_environment_already_cached(self):
        inst = self._makeOne()
        inst.environ = object()
        self.assertEqual(inst.get_environment(), inst.environ)

    def test_get_environment_path_startswith_more_than_one_slash(self):
        inst = self._makeOne()
        request = DummyParser()
        request.path = "///abc"
        inst.request = request
        environ = inst.get_environment()
        self.assertEqual(environ["PATH_INFO"], "/abc")

    def test_get_environment_path_empty(self):
        inst = self._makeOne()
        request = DummyParser()
        request.path = ""
        inst.request = request
        environ = inst.get_environment()
        self.assertEqual(environ["PATH_INFO"], "")

    def test_get_environment_no_query(self):
        inst = self._makeOne()
        request = DummyParser()
        inst.request = request
        environ = inst.get_environment()
        self.assertEqual(environ["QUERY_STRING"], "")

    def test_get_environment_with_query(self):
        inst = self._makeOne()
        request = DummyParser()
        request.query = "abc"
        inst.request = request
        environ = inst.get_environment()
        self.assertEqual(environ["QUERY_STRING"], "abc")

    def test_get_environ_with_url_prefix_miss(self):
        inst = self._makeOne()
        inst.channel.server.adj.url_prefix = "/foo"
        request = DummyParser()
        request.path = "/bar"
        inst.request = request
        environ = inst.get_environment()
        self.assertEqual(environ["PATH_INFO"], "/bar")
        self.assertEqual(environ["SCRIPT_NAME"], "/foo")

    def test_get_environ_with_url_prefix_hit(self):
        inst = self._makeOne()
        inst.channel.server.adj.url_prefix = "/foo"
        request = DummyParser()
        request.path = "/foo/fuz"
        inst.request = request
        environ = inst.get_environment()
        self.assertEqual(environ["PATH_INFO"], "/fuz")
        self.assertEqual(environ["SCRIPT_NAME"], "/foo")

    def test_get_environ_with_url_prefix_empty_path(self):
        inst = self._makeOne()
        inst.channel.server.adj.url_prefix = "/foo"
        request = DummyParser()
        request.path = "/foo"
        inst.request = request
        environ = inst.get_environment()
        self.assertEqual(environ["PATH_INFO"], "")
        self.assertEqual(environ["SCRIPT_NAME"], "/foo")

    def test_get_environment_values(self):
        import sys

        inst = self._makeOne()
        request = DummyParser()
        request.headers = {
            "CONTENT_TYPE": "abc",
            "CONTENT_LENGTH": "10",
            "X_FOO": "BAR",
            "CONNECTION": "close",
        }
        request.query = "abc"
        inst.request = request
        environ = inst.get_environment()

        # nail the keys of environ
        self.assertEqual(
            sorted(environ.keys()),
            [
                "CONTENT_LENGTH",
                "CONTENT_TYPE",
                "HTTP_CONNECTION",
                "HTTP_X_FOO",
                "PATH_INFO",
                "QUERY_STRING",
                "REMOTE_ADDR",
                "REMOTE_HOST",
                "REMOTE_PORT",
                "REQUEST_METHOD",
                "SCRIPT_NAME",
                "SERVER_NAME",
                "SERVER_PORT",
                "SERVER_PROTOCOL",
                "SERVER_SOFTWARE",
                "wsgi.errors",
                "wsgi.file_wrapper",
                "wsgi.input",
                "wsgi.input_terminated",
                "wsgi.multiprocess",
                "wsgi.multithread",
                "wsgi.run_once",
                "wsgi.url_scheme",
                "wsgi.version",
            ],
        )

        self.assertEqual(environ["REQUEST_METHOD"], "GET")
        self.assertEqual(environ["SERVER_PORT"], "80")
        self.assertEqual(environ["SERVER_NAME"], "localhost")
        self.assertEqual(environ["SERVER_SOFTWARE"], "waitress")
        self.assertEqual(environ["SERVER_PROTOCOL"], "HTTP/1.0")
        self.assertEqual(environ["SCRIPT_NAME"], "")
        self.assertEqual(environ["HTTP_CONNECTION"], "close")
        self.assertEqual(environ["PATH_INFO"], "/")
        self.assertEqual(environ["QUERY_STRING"], "abc")
        self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1")
        self.assertEqual(environ["REMOTE_HOST"], "127.0.0.1")
        self.assertEqual(environ["REMOTE_PORT"], "39830")
        self.assertEqual(environ["CONTENT_TYPE"], "abc")
        self.assertEqual(environ["CONTENT_LENGTH"], "10")
        self.assertEqual(environ["HTTP_X_FOO"], "BAR")
        self.assertEqual(environ["wsgi.version"], (1, 0))
        self.assertEqual(environ["wsgi.url_scheme"], "http")
        self.assertEqual(environ["wsgi.errors"], sys.stderr)
        self.assertEqual(environ["wsgi.multithread"], True)
        self.assertEqual(environ["wsgi.multiprocess"], False)
        self.assertEqual(environ["wsgi.run_once"], False)
        self.assertEqual(environ["wsgi.input"], "stream")
        self.assertEqual(environ["wsgi.input_terminated"], True)
        self.assertEqual(inst.environ, environ)
+
+
class TestErrorTask(unittest.TestCase):
    """Unit tests for waitress.task.ErrorTask.

    An ErrorTask renders the request's attached Error (code/reason/body)
    as a plain-text response; every variant here must close the
    connection regardless of what the client's Connection header said.
    """

    def _makeOne(self, channel=None, request=None):
        if channel is None:
            channel = DummyChannel()
        if request is None:
            request = DummyParser()
            request.error = self._makeDummyError()
        from waitress.task import ErrorTask

        return ErrorTask(channel, request)

    def _makeDummyError(self):
        from waitress.utilities import Error

        e = Error("body")
        e.code = 432
        e.reason = "Too Ugly"
        return e

    def test_execute_http_10(self):
        inst = self._makeOne()
        inst.execute()
        lines = filter_lines(inst.channel.written)
        self.assertEqual(len(lines), 9)
        self.assertEqual(lines[0], b"HTTP/1.0 432 Too Ugly")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertEqual(lines[2], b"Content-Length: 43")
        self.assertEqual(lines[3], b"Content-Type: text/plain")
        self.assertTrue(lines[4])
        self.assertEqual(lines[5], b"Server: waitress")
        self.assertEqual(lines[6], b"Too Ugly")
        self.assertEqual(lines[7], b"body")
        self.assertEqual(lines[8], b"(generated by waitress)")

    def test_execute_http_11(self):
        inst = self._makeOne()
        inst.version = "1.1"
        inst.execute()
        lines = filter_lines(inst.channel.written)
        self.assertEqual(len(lines), 9)
        self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertEqual(lines[2], b"Content-Length: 43")
        self.assertEqual(lines[3], b"Content-Type: text/plain")
        self.assertTrue(lines[4])
        self.assertEqual(lines[5], b"Server: waitress")
        self.assertEqual(lines[6], b"Too Ugly")
        self.assertEqual(lines[7], b"body")
        self.assertEqual(lines[8], b"(generated by waitress)")

    def test_execute_http_11_close(self):
        inst = self._makeOne()
        inst.version = "1.1"
        inst.request.headers["CONNECTION"] = "close"
        inst.execute()
        lines = filter_lines(inst.channel.written)
        self.assertEqual(len(lines), 9)
        self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertEqual(lines[2], b"Content-Length: 43")
        self.assertEqual(lines[3], b"Content-Type: text/plain")
        self.assertTrue(lines[4])
        self.assertEqual(lines[5], b"Server: waitress")
        self.assertEqual(lines[6], b"Too Ugly")
        self.assertEqual(lines[7], b"body")
        self.assertEqual(lines[8], b"(generated by waitress)")

    def test_execute_http_11_keep_forces_close(self):
        # Even a keep-alive request gets Connection: close on errors.
        inst = self._makeOne()
        inst.version = "1.1"
        inst.request.headers["CONNECTION"] = "keep-alive"
        inst.execute()
        lines = filter_lines(inst.channel.written)
        self.assertEqual(len(lines), 9)
        self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly")
        self.assertEqual(lines[1], b"Connection: close")
        self.assertEqual(lines[2], b"Content-Length: 43")
        self.assertEqual(lines[3], b"Content-Type: text/plain")
        self.assertTrue(lines[4])
        self.assertEqual(lines[5], b"Server: waitress")
        self.assertEqual(lines[6], b"Too Ugly")
        self.assertEqual(lines[7], b"body")
        self.assertEqual(lines[8], b"(generated by waitress)")
+
+
class DummyTask(object):
    """Task double that records service() and cancel() invocations."""

    serviced = False
    cancelled = False

    def cancel(self):
        self.cancelled = True

    def service(self):
        self.serviced = True
+
+
class DummyAdj(object):
    # Adjustments double: only the settings the task tests read.
    log_socket_errors = True
    ident = "waitress"
    host = "127.0.0.1"
    port = 80
    url_prefix = ""
+
+
class DummyServer(object):
    # Server double: supplies the name, port and adjustments object that
    # tasks consult when building responses and the WSGI environ.
    server_name = "localhost"
    effective_port = 80

    def __init__(self):
        self.adj = DummyAdj()
+
+
class DummyChannel(object):
    """Channel double that captures output instead of sending it.

    Byte strings handed to ``write_soon`` accumulate in ``written``;
    any non-bytes payload (e.g. a file-based buffer object) is
    collected in ``otherdata`` instead.
    """

    closed_when_done = False
    adj = DummyAdj()
    creation_time = 0
    addr = ("127.0.0.1", 39830)

    def __init__(self, server=None):
        self.server = DummyServer() if server is None else server
        self.written = b""
        self.otherdata = []

    def write_soon(self, data):
        if not isinstance(data, bytes):
            self.otherdata.append(data)
        else:
            self.written += data
        return len(data)
+
+
class DummyParser(object):
    """Parsed-request double exposing the attributes tasks read."""

    version = "1.0"
    command = "GET"
    path = "/"
    query = ""
    url_scheme = "http"
    expect_continue = False
    headers_finished = False

    def __init__(self):
        self.headers = dict()

    def get_body_stream(self):
        # Placeholder object standing in for the request body stream.
        return "stream"
+
+
def filter_lines(s):
    """Split a byte response on CRLF and drop the empty lines."""
    return [line for line in s.split(b"\r\n") if line]
+
+
class DummyLogger(object):
    """Logger double: %-formats messages and stores them in ``logged``."""

    def __init__(self):
        self.logged = []

    def _record(self, msg, args):
        # Shared recorder: apply lazy %-style formatting like logging does.
        self.logged.append(msg % args)

    def warning(self, msg, *args):
        self._record(msg, args)

    def exception(self, msg, *args):
        self._record(msg, args)
diff --git a/libs/waitress/tests/test_trigger.py b/libs/waitress/tests/test_trigger.py
new file mode 100644
index 000000000..af740f68d
--- /dev/null
+++ b/libs/waitress/tests/test_trigger.py
@@ -0,0 +1,111 @@
+import unittest
+import os
+import sys
+
+if not sys.platform.startswith("win"):
+
+ class Test_trigger(unittest.TestCase):
+ def _makeOne(self, map):
+ from waitress.trigger import trigger
+
+ self.inst = trigger(map)
+ return self.inst
+
+ def tearDown(self):
+ self.inst.close() # prevent __del__ warning from file_dispatcher
+
+ def test__close(self):
+ map = {}
+ inst = self._makeOne(map)
+ fd1, fd2 = inst._fds
+ inst.close()
+ self.assertRaises(OSError, os.read, fd1, 1)
+ self.assertRaises(OSError, os.read, fd2, 1)
+
+ def test__physical_pull(self):
+ map = {}
+ inst = self._makeOne(map)
+ inst._physical_pull()
+ r = os.read(inst._fds[0], 1)
+ self.assertEqual(r, b"x")
+
+ def test_readable(self):
+ map = {}
+ inst = self._makeOne(map)
+ self.assertEqual(inst.readable(), True)
+
+ def test_writable(self):
+ map = {}
+ inst = self._makeOne(map)
+ self.assertEqual(inst.writable(), False)
+
+ def test_handle_connect(self):
+ map = {}
+ inst = self._makeOne(map)
+ self.assertEqual(inst.handle_connect(), None)
+
+ def test_close(self):
+ map = {}
+ inst = self._makeOne(map)
+ self.assertEqual(inst.close(), None)
+ self.assertEqual(inst._closed, True)
+
+ def test_handle_close(self):
+ map = {}
+ inst = self._makeOne(map)
+ self.assertEqual(inst.handle_close(), None)
+ self.assertEqual(inst._closed, True)
+
+ def test_pull_trigger_nothunk(self):
+ map = {}
+ inst = self._makeOne(map)
+ self.assertEqual(inst.pull_trigger(), None)
+ r = os.read(inst._fds[0], 1)
+ self.assertEqual(r, b"x")
+
+ def test_pull_trigger_thunk(self):
+ map = {}
+ inst = self._makeOne(map)
+ self.assertEqual(inst.pull_trigger(True), None)
+ self.assertEqual(len(inst.thunks), 1)
+ r = os.read(inst._fds[0], 1)
+ self.assertEqual(r, b"x")
+
+ def test_handle_read_socket_error(self):
+ map = {}
+ inst = self._makeOne(map)
+ result = inst.handle_read()
+ self.assertEqual(result, None)
+
+ def test_handle_read_no_socket_error(self):
+ map = {}
+ inst = self._makeOne(map)
+ inst.pull_trigger()
+ result = inst.handle_read()
+ self.assertEqual(result, None)
+
+ def test_handle_read_thunk(self):
+ map = {}
+ inst = self._makeOne(map)
+ inst.pull_trigger()
+ L = []
+ inst.thunks = [lambda: L.append(True)]
+ result = inst.handle_read()
+ self.assertEqual(result, None)
+ self.assertEqual(L, [True])
+ self.assertEqual(inst.thunks, [])
+
+ def test_handle_read_thunk_error(self):
+ map = {}
+ inst = self._makeOne(map)
+
+ def errorthunk():
+ raise ValueError
+
+ inst.pull_trigger(errorthunk)
+ L = []
+ inst.log_info = lambda *arg: L.append(arg)
+ result = inst.handle_read()
+ self.assertEqual(result, None)
+ self.assertEqual(len(L), 1)
+ self.assertEqual(inst.thunks, [])
diff --git a/libs/waitress/tests/test_utilities.py b/libs/waitress/tests/test_utilities.py
new file mode 100644
index 000000000..15cd24f5a
--- /dev/null
+++ b/libs/waitress/tests/test_utilities.py
@@ -0,0 +1,140 @@
+##############################################################################
+#
+# Copyright (c) 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+
+import unittest
+
+
+class Test_parse_http_date(unittest.TestCase):
+ def _callFUT(self, v):
+ from waitress.utilities import parse_http_date
+
+ return parse_http_date(v)
+
+ def test_rfc850(self):
+ val = "Tuesday, 08-Feb-94 14:15:29 GMT"
+ result = self._callFUT(val)
+ self.assertEqual(result, 760716929)
+
+ def test_rfc822(self):
+ val = "Sun, 08 Feb 1994 14:15:29 GMT"
+ result = self._callFUT(val)
+ self.assertEqual(result, 760716929)
+
+ def test_neither(self):
+ val = ""
+ result = self._callFUT(val)
+ self.assertEqual(result, 0)
+
+
+class Test_build_http_date(unittest.TestCase):
+ def test_rountdrip(self):
+ from waitress.utilities import build_http_date, parse_http_date
+ from time import time
+
+ t = int(time())
+ self.assertEqual(t, parse_http_date(build_http_date(t)))
+
+
+class Test_unpack_rfc850(unittest.TestCase):
+ def _callFUT(self, val):
+ from waitress.utilities import unpack_rfc850, rfc850_reg
+
+ return unpack_rfc850(rfc850_reg.match(val.lower()))
+
+ def test_it(self):
+ val = "Tuesday, 08-Feb-94 14:15:29 GMT"
+ result = self._callFUT(val)
+ self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0))
+
+
+class Test_unpack_rfc_822(unittest.TestCase):
+ def _callFUT(self, val):
+ from waitress.utilities import unpack_rfc822, rfc822_reg
+
+ return unpack_rfc822(rfc822_reg.match(val.lower()))
+
+ def test_it(self):
+ val = "Sun, 08 Feb 1994 14:15:29 GMT"
+ result = self._callFUT(val)
+ self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0))
+
+
+class Test_find_double_newline(unittest.TestCase):
+ def _callFUT(self, val):
+ from waitress.utilities import find_double_newline
+
+ return find_double_newline(val)
+
+ def test_empty(self):
+ self.assertEqual(self._callFUT(b""), -1)
+
+ def test_one_linefeed(self):
+ self.assertEqual(self._callFUT(b"\n"), -1)
+
+ def test_double_linefeed(self):
+ self.assertEqual(self._callFUT(b"\n\n"), -1)
+
+ def test_one_crlf(self):
+ self.assertEqual(self._callFUT(b"\r\n"), -1)
+
+ def test_double_crfl(self):
+ self.assertEqual(self._callFUT(b"\r\n\r\n"), 4)
+
+ def test_mixed(self):
+ self.assertEqual(self._callFUT(b"\n\n00\r\n\r\n"), 8)
+
+
+class TestBadRequest(unittest.TestCase):
+ def _makeOne(self):
+ from waitress.utilities import BadRequest
+
+ return BadRequest(1)
+
+ def test_it(self):
+ inst = self._makeOne()
+ self.assertEqual(inst.body, 1)
+
+
+class Test_undquote(unittest.TestCase):
+ def _callFUT(self, value):
+ from waitress.utilities import undquote
+
+ return undquote(value)
+
+ def test_empty(self):
+ self.assertEqual(self._callFUT(""), "")
+
+ def test_quoted(self):
+ self.assertEqual(self._callFUT('"test"'), "test")
+
+ def test_unquoted(self):
+ self.assertEqual(self._callFUT("test"), "test")
+
+ def test_quoted_backslash_quote(self):
+ self.assertEqual(self._callFUT('"\\""'), '"')
+
+ def test_quoted_htab(self):
+ self.assertEqual(self._callFUT('"\t"'), "\t")
+
+ def test_quoted_backslash_htab(self):
+ self.assertEqual(self._callFUT('"\\\t"'), "\t")
+
+ def test_quoted_backslash_invalid(self):
+ self.assertRaises(ValueError, self._callFUT, '"\\"')
+
+ def test_invalid_quoting(self):
+ self.assertRaises(ValueError, self._callFUT, '"test')
+
+ def test_invalid_quoting_single_quote(self):
+ self.assertRaises(ValueError, self._callFUT, '"')
diff --git a/libs/waitress/tests/test_wasyncore.py b/libs/waitress/tests/test_wasyncore.py
new file mode 100644
index 000000000..9c235092f
--- /dev/null
+++ b/libs/waitress/tests/test_wasyncore.py
@@ -0,0 +1,1761 @@
+from waitress import wasyncore as asyncore
+from waitress import compat
+import contextlib
+import functools
+import gc
+import unittest
+import select
+import os
+import socket
+import sys
+import time
+import errno
+import re
+import struct
+import threading
+import warnings
+
+from io import BytesIO
+
+TIMEOUT = 3
+HAS_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
+HOST = "localhost"
+HOSTv4 = "127.0.0.1"
+HOSTv6 = "::1"
+
+# Filename used for testing
+if os.name == "java": # pragma: no cover
+ # Jython disallows @ in module names
+ TESTFN = "$test"
+else:
+ TESTFN = "@test"
+
+TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid())
+
+
+class DummyLogger(object): # pragma: no cover
+ def __init__(self):
+ self.messages = []
+
+ def log(self, severity, message):
+ self.messages.append((severity, message))
+
+
+class WarningsRecorder(object): # pragma: no cover
+ """Convenience wrapper for the warnings list returned on
+ entry to the warnings.catch_warnings() context manager.
+ """
+
+ def __init__(self, warnings_list):
+ self._warnings = warnings_list
+ self._last = 0
+
+ @property
+ def warnings(self):
+ return self._warnings[self._last :]
+
+ def reset(self):
+ self._last = len(self._warnings)
+
+
+def _filterwarnings(filters, quiet=False): # pragma: no cover
+ """Catch the warnings, then check if all the expected
+ warnings have been raised and re-raise unexpected warnings.
+ If 'quiet' is True, only re-raise the unexpected warnings.
+ """
+ # Clear the warning registry of the calling module
+ # in order to re-raise the warnings.
+ frame = sys._getframe(2)
+ registry = frame.f_globals.get("__warningregistry__")
+ if registry:
+ registry.clear()
+ with warnings.catch_warnings(record=True) as w:
+ # Set filter "always" to record all warnings. Because
+ # test_warnings swap the module, we need to look up in
+ # the sys.modules dictionary.
+ sys.modules["warnings"].simplefilter("always")
+ yield WarningsRecorder(w)
+ # Filter the recorded warnings
+ reraise = list(w)
+ missing = []
+ for msg, cat in filters:
+ seen = False
+ for w in reraise[:]:
+ warning = w.message
+ # Filter out the matching messages
+ if re.match(msg, str(warning), re.I) and issubclass(warning.__class__, cat):
+ seen = True
+ reraise.remove(w)
+ if not seen and not quiet:
+ # This filter caught nothing
+ missing.append((msg, cat.__name__))
+ if reraise:
+ raise AssertionError("unhandled warning %s" % reraise[0])
+ if missing:
+ raise AssertionError("filter (%r, %s) did not catch any warning" % missing[0])
+
+
+def check_warnings(*filters, **kwargs): # pragma: no cover
+ """Context manager to silence warnings.
+
+ Accept 2-tuples as positional arguments:
+ ("message regexp", WarningCategory)
+
+ Optional argument:
+ - if 'quiet' is True, it does not fail if a filter catches nothing
+ (default True without argument,
+ default False if some filters are defined)
+
+ Without argument, it defaults to:
+ check_warnings(("", Warning), quiet=True)
+ """
+ quiet = kwargs.get("quiet")
+ if not filters:
+ filters = (("", Warning),)
+ # Preserve backward compatibility
+ if quiet is None:
+ quiet = True
+ return _filterwarnings(filters, quiet)
+
+
+def gc_collect(): # pragma: no cover
+ """Force as many objects as possible to be collected.
+
+ In non-CPython implementations of Python, this is needed because timely
+ deallocation is not guaranteed by the garbage collector. (Even in CPython
+ this can be the case in case of reference cycles.) This means that __del__
+ methods may be called later than expected and weakrefs may remain alive for
+ longer than expected. This function tries its best to force all garbage
+ objects to disappear.
+ """
+ gc.collect()
+ if sys.platform.startswith("java"):
+ time.sleep(0.1)
+ gc.collect()
+ gc.collect()
+
+
+def threading_setup(): # pragma: no cover
+ return (compat.thread._count(), None)
+
+
+def threading_cleanup(*original_values): # pragma: no cover
+ global environment_altered
+
+ _MAX_COUNT = 100
+
+ for count in range(_MAX_COUNT):
+ values = (compat.thread._count(), None)
+ if values == original_values:
+ break
+
+ if not count:
+ # Display a warning at the first iteration
+ environment_altered = True
+ sys.stderr.write(
+ "Warning -- threading_cleanup() failed to cleanup "
+ "%s threads" % (values[0] - original_values[0])
+ )
+ sys.stderr.flush()
+
+ values = None
+
+ time.sleep(0.01)
+ gc_collect()
+
+
+def reap_threads(func): # pragma: no cover
+ """Use this function when threads are being used. This will
+ ensure that the threads are cleaned up even when the test fails.
+ """
+
+ @functools.wraps(func)
+ def decorator(*args):
+ key = threading_setup()
+ try:
+ return func(*args)
+ finally:
+ threading_cleanup(*key)
+
+ return decorator
+
+
+def join_thread(thread, timeout=30.0): # pragma: no cover
+ """Join a thread. Raise an AssertionError if the thread is still alive
+ after timeout seconds.
+ """
+ thread.join(timeout)
+ if thread.is_alive():
+ msg = "failed to join the thread in %.1f seconds" % timeout
+ raise AssertionError(msg)
+
+
+def bind_port(sock, host=HOST): # pragma: no cover
+ """Bind the socket to a free port and return the port number. Relies on
+ ephemeral ports in order to ensure we are using an unbound port. This is
+ important as many tests may be running simultaneously, especially in a
+ buildbot environment. This method raises an exception if the sock.family
+ is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR
+ or SO_REUSEPORT set on it. Tests should *never* set these socket options
+ for TCP/IP sockets. The only case for setting these options is testing
+ multicasting via multiple UDP sockets.
+
+ Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e.
+ on Windows), it will be set on the socket. This will prevent anyone else
+ from bind()'ing to our host/port for the duration of the test.
+ """
+
+ if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM:
+ if hasattr(socket, "SO_REUSEADDR"):
+ if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1:
+ raise RuntimeError(
+ "tests should never set the SO_REUSEADDR "
+ "socket option on TCP/IP sockets!"
+ )
+ if hasattr(socket, "SO_REUSEPORT"):
+ try:
+ if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1:
+ raise RuntimeError(
+ "tests should never set the SO_REUSEPORT "
+ "socket option on TCP/IP sockets!"
+ )
+ except OSError:
+ # Python's socket module was compiled using modern headers
+ # thus defining SO_REUSEPORT but this process is running
+ # under an older kernel that does not support SO_REUSEPORT.
+ pass
+ if hasattr(socket, "SO_EXCLUSIVEADDRUSE"):
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
+
+ sock.bind((host, 0))
+ port = sock.getsockname()[1]
+ return port
+
+
+def closewrapper(sock): # pragma: no cover
+ try:
+ yield sock
+ finally:
+ sock.close()
+
+
+class dummysocket: # pragma: no cover
+ def __init__(self):
+ self.closed = False
+
+ def close(self):
+ self.closed = True
+
+ def fileno(self):
+ return 42
+
+ def setblocking(self, yesno):
+ self.isblocking = yesno
+
+ def getpeername(self):
+ return "peername"
+
+
+class dummychannel: # pragma: no cover
+ def __init__(self):
+ self.socket = dummysocket()
+
+ def close(self):
+ self.socket.close()
+
+
+class exitingdummy: # pragma: no cover
+ def __init__(self):
+ pass
+
+ def handle_read_event(self):
+ raise asyncore.ExitNow()
+
+ handle_write_event = handle_read_event
+ handle_close = handle_read_event
+ handle_expt_event = handle_read_event
+
+
+class crashingdummy:
+ def __init__(self):
+ self.error_handled = False
+
+ def handle_read_event(self):
+ raise Exception()
+
+ handle_write_event = handle_read_event
+ handle_close = handle_read_event
+ handle_expt_event = handle_read_event
+
+ def handle_error(self):
+ self.error_handled = True
+
+
+# used when testing senders; just collects what it gets until newline is sent
+def capture_server(evt, buf, serv):  # pragma: no cover
+ try:
+ serv.listen(0)
+ conn, addr = serv.accept()
+ except socket.timeout:
+ pass
+ else:
+ n = 200
+ start = time.time()
+ while n > 0 and time.time() - start < 3.0:
+ r, w, e = select.select([conn], [], [], 0.1)
+ if r:
+ n -= 1
+ data = conn.recv(10)
+ # keep everything except for the newline terminator
+ buf.write(data.replace(b"\n", b""))
+ if b"\n" in data:
+ break
+ time.sleep(0.01)
+
+ conn.close()
+ finally:
+ serv.close()
+ evt.set()
+
+
+def bind_unix_socket(sock, addr): # pragma: no cover
+ """Bind a unix socket, raising SkipTest if PermissionError is raised."""
+ assert sock.family == socket.AF_UNIX
+ try:
+ sock.bind(addr)
+ except PermissionError:
+ sock.close()
+ raise unittest.SkipTest("cannot bind AF_UNIX sockets")
+
+
+def bind_af_aware(sock, addr):
+ """Helper function to bind a socket according to its family."""
+ if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX:
+ # Make sure the path doesn't exist.
+ unlink(addr)
+ bind_unix_socket(sock, addr)
+ else:
+ sock.bind(addr)
+
+
+if sys.platform.startswith("win"): # pragma: no cover
+
+ def _waitfor(func, pathname, waitall=False):
+ # Perform the operation
+ func(pathname)
+ # Now setup the wait loop
+ if waitall:
+ dirname = pathname
+ else:
+ dirname, name = os.path.split(pathname)
+ dirname = dirname or "."
+ # Check for `pathname` to be removed from the filesystem.
+ # The exponential backoff of the timeout amounts to a total
+ # of ~1 second after which the deletion is probably an error
+ # anyway.
+            # Testing on an i7 @ 4.3GHz shows that usually only 1 iteration is
+ # required when contention occurs.
+ timeout = 0.001
+ while timeout < 1.0:
+ # Note we are only testing for the existence of the file(s) in
+ # the contents of the directory regardless of any security or
+ # access rights. If we have made it this far, we have sufficient
+ # permissions to do that much using Python's equivalent of the
+ # Windows API FindFirstFile.
+ # Other Windows APIs can fail or give incorrect results when
+ # dealing with files that are pending deletion.
+ L = os.listdir(dirname)
+ if not (L if waitall else name in L):
+ return
+ # Increase the timeout and try again
+ time.sleep(timeout)
+ timeout *= 2
+ warnings.warn(
+ "tests may fail, delete still pending for " + pathname,
+ RuntimeWarning,
+ stacklevel=4,
+ )
+
+ def _unlink(filename):
+ _waitfor(os.unlink, filename)
+
+
+else:
+ _unlink = os.unlink
+
+
+def unlink(filename):
+ try:
+ _unlink(filename)
+ except OSError:
+ pass
+
+
+def _is_ipv6_enabled(): # pragma: no cover
+ """Check whether IPv6 is enabled on this host."""
+ if compat.HAS_IPV6:
+ sock = None
+ try:
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ sock.bind(("::1", 0))
+ return True
+ except socket.error:
+ pass
+ finally:
+ if sock:
+ sock.close()
+ return False
+
+
+IPV6_ENABLED = _is_ipv6_enabled()
+
+
+class HelperFunctionTests(unittest.TestCase):
+ def test_readwriteexc(self):
+ # Check exception handling behavior of read, write and _exception
+
+ # check that ExitNow exceptions in the object handler method
+ # bubbles all the way up through asyncore read/write/_exception calls
+ tr1 = exitingdummy()
+ self.assertRaises(asyncore.ExitNow, asyncore.read, tr1)
+ self.assertRaises(asyncore.ExitNow, asyncore.write, tr1)
+ self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1)
+
+ # check that an exception other than ExitNow in the object handler
+ # method causes the handle_error method to get called
+ tr2 = crashingdummy()
+ asyncore.read(tr2)
+ self.assertEqual(tr2.error_handled, True)
+
+ tr2 = crashingdummy()
+ asyncore.write(tr2)
+ self.assertEqual(tr2.error_handled, True)
+
+ tr2 = crashingdummy()
+ asyncore._exception(tr2)
+ self.assertEqual(tr2.error_handled, True)
+
+ # asyncore.readwrite uses constants in the select module that
+ # are not present in Windows systems (see this thread:
+ # http://mail.python.org/pipermail/python-list/2001-October/109973.html)
+ # These constants should be present as long as poll is available
+
+ @unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
+ def test_readwrite(self):
+ # Check that correct methods are called by readwrite()
+
+ attributes = ("read", "expt", "write", "closed", "error_handled")
+
+ expected = (
+ (select.POLLIN, "read"),
+ (select.POLLPRI, "expt"),
+ (select.POLLOUT, "write"),
+ (select.POLLERR, "closed"),
+ (select.POLLHUP, "closed"),
+ (select.POLLNVAL, "closed"),
+ )
+
+ class testobj:
+ def __init__(self):
+ self.read = False
+ self.write = False
+ self.closed = False
+ self.expt = False
+ self.error_handled = False
+
+ def handle_read_event(self):
+ self.read = True
+
+ def handle_write_event(self):
+ self.write = True
+
+ def handle_close(self):
+ self.closed = True
+
+ def handle_expt_event(self):
+ self.expt = True
+
+ # def handle_error(self):
+ # self.error_handled = True
+
+ for flag, expectedattr in expected:
+ tobj = testobj()
+ self.assertEqual(getattr(tobj, expectedattr), False)
+ asyncore.readwrite(tobj, flag)
+
+ # Only the attribute modified by the routine we expect to be
+ # called should be True.
+ for attr in attributes:
+ self.assertEqual(getattr(tobj, attr), attr == expectedattr)
+
+ # check that ExitNow exceptions in the object handler method
+ # bubbles all the way up through asyncore readwrite call
+ tr1 = exitingdummy()
+ self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag)
+
+ # check that an exception other than ExitNow in the object handler
+ # method causes the handle_error method to get called
+ tr2 = crashingdummy()
+ self.assertEqual(tr2.error_handled, False)
+ asyncore.readwrite(tr2, flag)
+ self.assertEqual(tr2.error_handled, True)
+
+ def test_closeall(self):
+ self.closeall_check(False)
+
+ def test_closeall_default(self):
+ self.closeall_check(True)
+
+ def closeall_check(self, usedefault):
+ # Check that close_all() closes everything in a given map
+
+ l = []
+ testmap = {}
+ for i in range(10):
+ c = dummychannel()
+ l.append(c)
+ self.assertEqual(c.socket.closed, False)
+ testmap[i] = c
+
+ if usedefault:
+ socketmap = asyncore.socket_map
+ try:
+ asyncore.socket_map = testmap
+ asyncore.close_all()
+ finally:
+ testmap, asyncore.socket_map = asyncore.socket_map, socketmap
+ else:
+ asyncore.close_all(testmap)
+
+ self.assertEqual(len(testmap), 0)
+
+ for c in l:
+ self.assertEqual(c.socket.closed, True)
+
+ def test_compact_traceback(self):
+ try:
+ raise Exception("I don't like spam!")
+ except:
+ real_t, real_v, real_tb = sys.exc_info()
+ r = asyncore.compact_traceback()
+
+ (f, function, line), t, v, info = r
+ self.assertEqual(os.path.split(f)[-1], "test_wasyncore.py")
+ self.assertEqual(function, "test_compact_traceback")
+ self.assertEqual(t, real_t)
+ self.assertEqual(v, real_v)
+ self.assertEqual(info, "[%s|%s|%s]" % (f, function, line))
+
+
+class DispatcherTests(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ asyncore.close_all()
+
+ def test_basic(self):
+ d = asyncore.dispatcher()
+ self.assertEqual(d.readable(), True)
+ self.assertEqual(d.writable(), True)
+
+ def test_repr(self):
+ d = asyncore.dispatcher()
+ self.assertEqual(repr(d), "<waitress.wasyncore.dispatcher at %#x>" % id(d))
+
+ def test_log_info(self):
+ import logging
+
+ inst = asyncore.dispatcher(map={})
+ logger = DummyLogger()
+ inst.logger = logger
+ inst.log_info("message", "warning")
+ self.assertEqual(logger.messages, [(logging.WARN, "message")])
+
+ def test_log(self):
+ import logging
+
+ inst = asyncore.dispatcher()
+ logger = DummyLogger()
+ inst.logger = logger
+ inst.log("message")
+ self.assertEqual(logger.messages, [(logging.DEBUG, "message")])
+
+ def test_unhandled(self):
+ import logging
+
+ inst = asyncore.dispatcher()
+ logger = DummyLogger()
+ inst.logger = logger
+
+ inst.handle_expt()
+ inst.handle_read()
+ inst.handle_write()
+ inst.handle_connect()
+
+ expected = [
+ (logging.WARN, "unhandled incoming priority event"),
+ (logging.WARN, "unhandled read event"),
+ (logging.WARN, "unhandled write event"),
+ (logging.WARN, "unhandled connect event"),
+ ]
+ self.assertEqual(logger.messages, expected)
+
+ def test_strerror(self):
+ # refers to bug #8573
+ err = asyncore._strerror(errno.EPERM)
+ if hasattr(os, "strerror"):
+ self.assertEqual(err, os.strerror(errno.EPERM))
+ err = asyncore._strerror(-1)
+ self.assertTrue(err != "")
+
+
+class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover
+ def readable(self):
+ return False
+
+ def handle_connect(self):
+ pass
+
+
+class DispatcherWithSendTests(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ asyncore.close_all()
+
+ @reap_threads
+ def test_send(self):
+ evt = threading.Event()
+ sock = socket.socket()
+ sock.settimeout(3)
+ port = bind_port(sock)
+
+ cap = BytesIO()
+ args = (evt, cap, sock)
+ t = threading.Thread(target=capture_server, args=args)
+ t.start()
+ try:
+ # wait a little longer for the server to initialize (it sometimes
+ # refuses connections on slow machines without this wait)
+ time.sleep(0.2)
+
+ data = b"Suppose there isn't a 16-ton weight?"
+ d = dispatcherwithsend_noread()
+ d.create_socket()
+ d.connect((HOST, port))
+
+ # give time for socket to connect
+ time.sleep(0.1)
+
+ d.send(data)
+ d.send(data)
+ d.send(b"\n")
+
+ n = 1000
+ while d.out_buffer and n > 0: # pragma: no cover
+ asyncore.poll()
+ n -= 1
+
+ evt.wait()
+
+ self.assertEqual(cap.getvalue(), data * 2)
+ finally:
+ join_thread(t, timeout=TIMEOUT)
+
+
+@unittest.skipUnless(
+    hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required"
+)
+class FileWrapperTest(unittest.TestCase):
+ def setUp(self):
+ self.d = b"It's not dead, it's sleeping!"
+ with open(TESTFN, "wb") as file:
+ file.write(self.d)
+
+ def tearDown(self):
+ unlink(TESTFN)
+
+ def test_recv(self):
+ fd = os.open(TESTFN, os.O_RDONLY)
+ w = asyncore.file_wrapper(fd)
+ os.close(fd)
+
+ self.assertNotEqual(w.fd, fd)
+ self.assertNotEqual(w.fileno(), fd)
+ self.assertEqual(w.recv(13), b"It's not dead")
+ self.assertEqual(w.read(6), b", it's")
+ w.close()
+ self.assertRaises(OSError, w.read, 1)
+
+ def test_send(self):
+ d1 = b"Come again?"
+ d2 = b"I want to buy some cheese."
+ fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND)
+ w = asyncore.file_wrapper(fd)
+ os.close(fd)
+
+ w.write(d1)
+ w.send(d2)
+ w.close()
+ with open(TESTFN, "rb") as file:
+ self.assertEqual(file.read(), self.d + d1 + d2)
+
+ @unittest.skipUnless(
+ hasattr(asyncore, "file_dispatcher"), "asyncore.file_dispatcher required"
+ )
+ def test_dispatcher(self):
+ fd = os.open(TESTFN, os.O_RDONLY)
+ data = []
+
+ class FileDispatcher(asyncore.file_dispatcher):
+ def handle_read(self):
+ data.append(self.recv(29))
+
+ FileDispatcher(fd)
+ os.close(fd)
+ asyncore.loop(timeout=0.01, use_poll=True, count=2)
+ self.assertEqual(b"".join(data), self.d)
+
+ def test_resource_warning(self):
+ # Issue #11453
+ got_warning = False
+ while got_warning is False:
+ # we try until we get the outcome we want because this
+            # test is not deterministic (gc_collect() may not collect the
+            # wrapper and emit the ResourceWarning on the first attempt)
+ fd = os.open(TESTFN, os.O_RDONLY)
+ f = asyncore.file_wrapper(fd)
+
+ os.close(fd)
+
+ try:
+ with check_warnings(("", compat.ResourceWarning)):
+ f = None
+ gc_collect()
+ except AssertionError: # pragma: no cover
+ pass
+ else:
+ got_warning = True
+
+ def test_close_twice(self):
+ fd = os.open(TESTFN, os.O_RDONLY)
+ f = asyncore.file_wrapper(fd)
+ os.close(fd)
+
+ os.close(f.fd) # file_wrapper dupped fd
+ with self.assertRaises(OSError):
+ f.close()
+
+ self.assertEqual(f.fd, -1)
+ # calling close twice should not fail
+ f.close()
+
+
+class BaseTestHandler(asyncore.dispatcher): # pragma: no cover
+ def __init__(self, sock=None):
+ asyncore.dispatcher.__init__(self, sock)
+ self.flag = False
+
+ def handle_accept(self):
+ raise Exception("handle_accept not supposed to be called")
+
+ def handle_accepted(self):
+ raise Exception("handle_accepted not supposed to be called")
+
+ def handle_connect(self):
+ raise Exception("handle_connect not supposed to be called")
+
+ def handle_expt(self):
+ raise Exception("handle_expt not supposed to be called")
+
+ def handle_close(self):
+ raise Exception("handle_close not supposed to be called")
+
+ def handle_error(self):
+ raise
+
+
+class BaseServer(asyncore.dispatcher):
+ """A server which listens on an address and dispatches the
+ connection to a handler.
+ """
+
+ def __init__(self, family, addr, handler=BaseTestHandler):
+ asyncore.dispatcher.__init__(self)
+ self.create_socket(family)
+ self.set_reuse_addr()
+ bind_af_aware(self.socket, addr)
+ self.listen(5)
+ self.handler = handler
+
+ @property
+ def address(self):
+ return self.socket.getsockname()
+
+ def handle_accepted(self, sock, addr):
+ self.handler(sock)
+
+ def handle_error(self): # pragma: no cover
+ raise
+
+
+class BaseClient(BaseTestHandler):
+ def __init__(self, family, address):
+ BaseTestHandler.__init__(self)
+ self.create_socket(family)
+ self.connect(address)
+
+ def handle_connect(self):
+ pass
+
+
+class BaseTestAPI:
+ def tearDown(self):
+ asyncore.close_all(ignore_all=True)
+
+ def loop_waiting_for_flag(self, instance, timeout=5): # pragma: no cover
+ timeout = float(timeout) / 100
+ count = 100
+ while asyncore.socket_map and count > 0:
+ asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll)
+ if instance.flag:
+ return
+ count -= 1
+ time.sleep(timeout)
+ self.fail("flag not set")
+
+ def test_handle_connect(self):
+ # make sure handle_connect is called on connect()
+
+ class TestClient(BaseClient):
+ def handle_connect(self):
+ self.flag = True
+
+ server = BaseServer(self.family, self.addr)
+ client = TestClient(self.family, server.address)
+ self.loop_waiting_for_flag(client)
+
+ def test_handle_accept(self):
+ # make sure handle_accept() is called when a client connects
+
+ class TestListener(BaseTestHandler):
+ def __init__(self, family, addr):
+ BaseTestHandler.__init__(self)
+ self.create_socket(family)
+ bind_af_aware(self.socket, addr)
+ self.listen(5)
+ self.address = self.socket.getsockname()
+
+ def handle_accept(self):
+ self.flag = True
+
+ server = TestListener(self.family, self.addr)
+ client = BaseClient(self.family, server.address)
+ self.loop_waiting_for_flag(server)
+
+ def test_handle_accepted(self):
+ # make sure handle_accepted() is called when a client connects
+
+ class TestListener(BaseTestHandler):
+ def __init__(self, family, addr):
+ BaseTestHandler.__init__(self)
+ self.create_socket(family)
+ bind_af_aware(self.socket, addr)
+ self.listen(5)
+ self.address = self.socket.getsockname()
+
+ def handle_accept(self):
+ asyncore.dispatcher.handle_accept(self)
+
+ def handle_accepted(self, sock, addr):
+ sock.close()
+ self.flag = True
+
+ server = TestListener(self.family, self.addr)
+ client = BaseClient(self.family, server.address)
+ self.loop_waiting_for_flag(server)
+
+ def test_handle_read(self):
+ # make sure handle_read is called on data received
+
+ class TestClient(BaseClient):
+ def handle_read(self):
+ self.flag = True
+
+ class TestHandler(BaseTestHandler):
+ def __init__(self, conn):
+ BaseTestHandler.__init__(self, conn)
+ self.send(b"x" * 1024)
+
+ server = BaseServer(self.family, self.addr, TestHandler)
+ client = TestClient(self.family, server.address)
+ self.loop_waiting_for_flag(client)
+
+ def test_handle_write(self):
+ # make sure handle_write is called
+
+ class TestClient(BaseClient):
+ def handle_write(self):
+ self.flag = True
+
+ server = BaseServer(self.family, self.addr)
+ client = TestClient(self.family, server.address)
+ self.loop_waiting_for_flag(client)
+
+ def test_handle_close(self):
+ # make sure handle_close is called when the other end closes
+ # the connection
+
+ class TestClient(BaseClient):
+ def handle_read(self):
+ # in order to make handle_close be called we are supposed
+ # to make at least one recv() call
+ self.recv(1024)
+
+ def handle_close(self):
+ self.flag = True
+ self.close()
+
+ class TestHandler(BaseTestHandler):
+ def __init__(self, conn):
+ BaseTestHandler.__init__(self, conn)
+ self.close()
+
+ server = BaseServer(self.family, self.addr, TestHandler)
+ client = TestClient(self.family, server.address)
+ self.loop_waiting_for_flag(client)
+
+ def test_handle_close_after_conn_broken(self):
+ # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and
+ # #11265).
+
+ data = b"\0" * 128
+
+ class TestClient(BaseClient):
+ def handle_write(self):
+ self.send(data)
+
+ def handle_close(self):
+ self.flag = True
+ self.close()
+
+ def handle_expt(self): # pragma: no cover
+ # needs to exist for MacOS testing
+ self.flag = True
+ self.close()
+
+ class TestHandler(BaseTestHandler):
+ def handle_read(self):
+ self.recv(len(data))
+ self.close()
+
+ def writable(self):
+ return False
+
+ server = BaseServer(self.family, self.addr, TestHandler)
+ client = TestClient(self.family, server.address)
+ self.loop_waiting_for_flag(client)
+
+    @unittest.skipIf(
+        sys.platform.startswith("sunos"), "OOB support is broken on Solaris"
+    )
+    def test_handle_expt(self):
+        # Make sure handle_expt is called on OOB data received.
+        # Note: this might fail on some platforms as OOB data is
+        # tenuously supported and rarely used.
+        if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
+            self.skipTest("Not applicable to AF_UNIX sockets.")
+
+        if sys.platform == "darwin" and self.use_poll:  # pragma: no cover
+            self.skipTest("poll may fail on macOS; see issue #28087")
+
+        class TestClient(BaseClient):
+            def handle_expt(self):
+                # drain the OOB byte so the condition doesn't re-trigger
+                self.socket.recv(1024, socket.MSG_OOB)
+                self.flag = True
+
+        class TestHandler(BaseTestHandler):
+            def __init__(self, conn):
+                BaseTestHandler.__init__(self, conn)
+                # compat.tobytes keeps this working on both Python 2 and 3
+                self.socket.send(compat.tobytes(chr(244)), socket.MSG_OOB)
+
+        server = BaseServer(self.family, self.addr, TestHandler)
+        client = TestClient(self.family, server.address)
+        self.loop_waiting_for_flag(client)
+
+    def test_handle_error(self):
+        # handle_error must be invoked with the original exception still
+        # active, so a bare `raise` inside it re-raises ZeroDivisionError.
+        class TestClient(BaseClient):
+            def handle_write(self):
+                1.0 / 0
+
+            def handle_error(self):
+                self.flag = True
+                try:
+                    raise
+                except ZeroDivisionError:
+                    pass
+                else:  # pragma: no cover
+                    raise Exception("exception not raised")
+
+        server = BaseServer(self.family, self.addr)
+        client = TestClient(self.family, server.address)
+        self.loop_waiting_for_flag(client)
+
+    def test_connection_attributes(self):
+        # verify the connected/accepting flags across the full lifecycle:
+        # start-up, connect, client close, server close
+        server = BaseServer(self.family, self.addr)
+        client = BaseClient(self.family, server.address)
+
+        # we start disconnected
+        self.assertFalse(server.connected)
+        self.assertTrue(server.accepting)
+        # this can't be taken for granted across all platforms
+        # self.assertFalse(client.connected)
+        self.assertFalse(client.accepting)
+
+        # execute some loops so that client connects to server
+        asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100)
+        self.assertFalse(server.connected)
+        self.assertTrue(server.accepting)
+        self.assertTrue(client.connected)
+        self.assertFalse(client.accepting)
+
+        # disconnect the client
+        client.close()
+        self.assertFalse(server.connected)
+        self.assertTrue(server.accepting)
+        self.assertFalse(client.connected)
+        self.assertFalse(client.accepting)
+
+        # stop serving
+        server.close()
+        self.assertFalse(server.connected)
+        self.assertFalse(server.accepting)
+
+    def test_create_socket(self):
+        # create_socket() must yield a non-blocking socket of the
+        # requested family
+        s = asyncore.dispatcher()
+        s.create_socket(self.family)
+        # self.assertEqual(s.socket.type, socket.SOCK_STREAM)
+        self.assertEqual(s.socket.family, self.family)
+        self.assertEqual(s.socket.gettimeout(), 0)
+        # self.assertFalse(s.socket.get_inheritable())
+
+    def test_bind(self):
+        # binding a second dispatcher to an already-bound port must fail
+        if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
+            self.skipTest("Not applicable to AF_UNIX sockets.")
+        s1 = asyncore.dispatcher()
+        s1.create_socket(self.family)
+        s1.bind(self.addr)
+        s1.listen(5)
+        port = s1.socket.getsockname()[1]
+
+        s2 = asyncore.dispatcher()
+        s2.create_socket(self.family)
+        # EADDRINUSE indicates the socket was correctly bound
+        self.assertRaises(socket.error, s2.bind, (self.addr[0], port))
+
+ def test_set_reuse_addr(self): # pragma: no cover
+ if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
+ self.skipTest("Not applicable to AF_UNIX sockets.")
+
+ with closewrapper(socket.socket(self.family)) as sock:
+ try:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ except OSError:
+ unittest.skip("SO_REUSEADDR not supported on this platform")
+ else:
+ # if SO_REUSEADDR succeeded for sock we expect asyncore
+ # to do the same
+ s = asyncore.dispatcher(socket.socket(self.family))
+ self.assertFalse(
+ s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
+ )
+ s.socket.close()
+ s.create_socket(self.family)
+ s.set_reuse_addr()
+ self.assertTrue(
+ s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
+ )
+
+    @reap_threads
+    def test_quick_connect(self):  # pragma: no cover
+        # see: http://bugs.python.org/issue10340
+        if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())):
+            self.skipTest("test specific to AF_INET and AF_INET6")
+
+        server = BaseServer(self.family, self.addr)
+        # run the thread 500 ms: the socket should be connected in 200 ms
+        t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=5))
+        t.start()
+        try:
+            sock = socket.socket(self.family, socket.SOCK_STREAM)
+            with closewrapper(sock) as s:
+                s.settimeout(0.2)
+                # SO_LINGER with zero timeout forces a hard RST-style close
+                s.setsockopt(
+                    socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
+                )
+
+                try:
+                    s.connect(server.address)
+                except OSError:
+                    pass
+        finally:
+            join_thread(t, timeout=TIMEOUT)
+
+
+# Concrete parameterizations of BaseTestAPI: one subclass per address family.
+class TestAPI_UseIPv4Sockets(BaseTestAPI):
+    family = socket.AF_INET
+    addr = (HOST, 0)
+
+
[email protected](IPV6_ENABLED, "IPv6 support required")
+class TestAPI_UseIPv6Sockets(BaseTestAPI):
+    # IPv6 variant; only collected when the host has IPv6 enabled.
+    family = socket.AF_INET6
+    addr = (HOSTv6, 0)
+
+
[email protected](HAS_UNIX_SOCKETS, "Unix sockets required")
+class TestAPI_UseUnixSockets(BaseTestAPI):
+    # AF_UNIX variant; the guard keeps the class attribute lookup from
+    # failing on platforms without unix-domain sockets.
+    if HAS_UNIX_SOCKETS:
+        family = socket.AF_UNIX
+    addr = TESTFN
+
+    def tearDown(self):
+        # remove the filesystem socket node left behind by bind()
+        unlink(self.addr)
+        BaseTestAPI.tearDown(self)
+
+
+# IPv4 over the select()-based loop.
+class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase):
+    use_poll = False
+
+
+# IPv4 over the poll()-based loop (when the platform provides poll()).
[email protected](hasattr(select, "poll"), "select.poll required")
+class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase):
+    use_poll = True
+
+
+# IPv6 over the select()-based loop.
+class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase):
+    use_poll = False
+
+
+# IPv6 over the poll()-based loop.
[email protected](hasattr(select, "poll"), "select.poll required")
+class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase):
+    use_poll = True
+
+
+# AF_UNIX over the select()-based loop.
+class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase):
+    use_poll = False
+
+
+# AF_UNIX over the poll()-based loop.
[email protected](hasattr(select, "poll"), "select.poll required")
+class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase):
+    use_poll = True
+
+
+class Test__strerror(unittest.TestCase):
+    # Unit tests for waitress.wasyncore._strerror (errno -> message).
+    def _callFUT(self, err):
+        from waitress.wasyncore import _strerror
+
+        return _strerror(err)
+
+    def test_gardenpath(self):
+        self.assertEqual(self._callFUT(1), "Operation not permitted")
+
+    def test_unknown(self):
+        # non-integer / unknown errnos fall back to a generic message
+        self.assertEqual(self._callFUT("wut"), "Unknown error wut")
+
+
+class Test_read(unittest.TestCase):
+    # waitress.wasyncore.read: dispatches handle_read_event and routes
+    # exceptions to handle_error, re-raising only ExitNow & friends.
+    def _callFUT(self, dispatcher):
+        from waitress.wasyncore import read
+
+        return read(dispatcher)
+
+    def test_gardenpath(self):
+        inst = DummyDispatcher()
+        self._callFUT(inst)
+        self.assertTrue(inst.read_event_handled)
+        self.assertFalse(inst.error_handled)
+
+    def test_reraised(self):
+        from waitress.wasyncore import ExitNow
+
+        inst = DummyDispatcher(ExitNow)
+        self.assertRaises(ExitNow, self._callFUT, inst)
+        self.assertTrue(inst.read_event_handled)
+        self.assertFalse(inst.error_handled)
+
+    def test_non_reraised(self):
+        inst = DummyDispatcher(OSError)
+        self._callFUT(inst)
+        self.assertTrue(inst.read_event_handled)
+        self.assertTrue(inst.error_handled)
+
+
+class Test_write(unittest.TestCase):
+    # waitress.wasyncore.write: same contract as read(), for write events.
+    def _callFUT(self, dispatcher):
+        from waitress.wasyncore import write
+
+        return write(dispatcher)
+
+    def test_gardenpath(self):
+        inst = DummyDispatcher()
+        self._callFUT(inst)
+        self.assertTrue(inst.write_event_handled)
+        self.assertFalse(inst.error_handled)
+
+    def test_reraised(self):
+        from waitress.wasyncore import ExitNow
+
+        inst = DummyDispatcher(ExitNow)
+        self.assertRaises(ExitNow, self._callFUT, inst)
+        self.assertTrue(inst.write_event_handled)
+        self.assertFalse(inst.error_handled)
+
+    def test_non_reraised(self):
+        inst = DummyDispatcher(OSError)
+        self._callFUT(inst)
+        self.assertTrue(inst.write_event_handled)
+        self.assertTrue(inst.error_handled)
+
+
+class Test__exception(unittest.TestCase):
+    # waitress.wasyncore._exception: same contract as read(), for
+    # exceptional-condition (OOB) events.
+    def _callFUT(self, dispatcher):
+        from waitress.wasyncore import _exception
+
+        return _exception(dispatcher)
+
+    def test_gardenpath(self):
+        inst = DummyDispatcher()
+        self._callFUT(inst)
+        self.assertTrue(inst.expt_event_handled)
+        self.assertFalse(inst.error_handled)
+
+    def test_reraised(self):
+        from waitress.wasyncore import ExitNow
+
+        inst = DummyDispatcher(ExitNow)
+        self.assertRaises(ExitNow, self._callFUT, inst)
+        self.assertTrue(inst.expt_event_handled)
+        self.assertFalse(inst.error_handled)
+
+    def test_non_reraised(self):
+        inst = DummyDispatcher(OSError)
+        self._callFUT(inst)
+        self.assertTrue(inst.expt_event_handled)
+        self.assertTrue(inst.error_handled)
+
+
[email protected](hasattr(select, "poll"), "select.poll required")
+class Test_readwrite(unittest.TestCase):
+    # waitress.wasyncore.readwrite: maps poll() flag bits onto the
+    # dispatcher's handle_*_event/handle_close callbacks.
+    def _callFUT(self, obj, flags):
+        from waitress.wasyncore import readwrite
+
+        return readwrite(obj, flags)
+
+    def test_handle_read_event(self):
+        flags = 0
+        flags |= select.POLLIN
+        inst = DummyDispatcher()
+        self._callFUT(inst, flags)
+        self.assertTrue(inst.read_event_handled)
+
+    def test_handle_write_event(self):
+        flags = 0
+        flags |= select.POLLOUT
+        inst = DummyDispatcher()
+        self._callFUT(inst, flags)
+        self.assertTrue(inst.write_event_handled)
+
+    def test_handle_expt_event(self):
+        flags = 0
+        flags |= select.POLLPRI
+        inst = DummyDispatcher()
+        self._callFUT(inst, flags)
+        self.assertTrue(inst.expt_event_handled)
+
+    def test_handle_close(self):
+        flags = 0
+        flags |= select.POLLHUP
+        inst = DummyDispatcher()
+        self._callFUT(inst, flags)
+        self.assertTrue(inst.close_handled)
+
+    def test_socketerror_not_in_disconnected(self):
+        # errors outside the "disconnected" errno set go to handle_error
+        flags = 0
+        flags |= select.POLLIN
+        inst = DummyDispatcher(socket.error(errno.EALREADY, "EALREADY"))
+        self._callFUT(inst, flags)
+        self.assertTrue(inst.read_event_handled)
+        self.assertTrue(inst.error_handled)
+
+    def test_socketerror_in_disconnected(self):
+        # disconnect-class errors (e.g. ECONNRESET) go to handle_close
+        flags = 0
+        flags |= select.POLLIN
+        inst = DummyDispatcher(socket.error(errno.ECONNRESET, "ECONNRESET"))
+        self._callFUT(inst, flags)
+        self.assertTrue(inst.read_event_handled)
+        self.assertTrue(inst.close_handled)
+
+    def test_exception_in_reraised(self):
+        from waitress import wasyncore
+
+        flags = 0
+        flags |= select.POLLIN
+        inst = DummyDispatcher(wasyncore.ExitNow)
+        self.assertRaises(wasyncore.ExitNow, self._callFUT, inst, flags)
+        self.assertTrue(inst.read_event_handled)
+
+    def test_exception_not_in_reraised(self):
+        flags = 0
+        flags |= select.POLLIN
+        inst = DummyDispatcher(ValueError)
+        self._callFUT(inst, flags)
+        self.assertTrue(inst.error_handled)
+
+
+class Test_poll(unittest.TestCase):
+    # waitress.wasyncore.poll: exercised by monkeypatching the module's
+    # `time` and `select` references, restored in `finally`.
+    def _callFUT(self, timeout=0.0, map=None):
+        from waitress.wasyncore import poll
+
+        return poll(timeout, map)
+
+    def test_nothing_writable_nothing_readable_but_map_not_empty(self):
+        # i read the mock.patch docs. nerp.
+        dummy_time = DummyTime()
+        map = {0: DummyDispatcher()}
+        try:
+            from waitress import wasyncore
+
+            old_time = wasyncore.time
+            wasyncore.time = dummy_time
+            result = self._callFUT(map=map)
+        finally:
+            wasyncore.time = old_time
+        self.assertEqual(result, None)
+        # with nothing selectable, poll() just sleeps out the timeout
+        self.assertEqual(dummy_time.sleepvals, [0.0])
+
+    def test_select_raises_EINTR(self):
+        # i read the mock.patch docs. nerp.
+        # EINTR (interrupted system call) must be swallowed and retried
+        dummy_select = DummySelect(select.error(errno.EINTR))
+        disp = DummyDispatcher()
+        disp.readable = lambda: True
+        map = {0: disp}
+        try:
+            from waitress import wasyncore
+
+            old_select = wasyncore.select
+            wasyncore.select = dummy_select
+            result = self._callFUT(map=map)
+        finally:
+            wasyncore.select = old_select
+        self.assertEqual(result, None)
+        self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)])
+
+    def test_select_raises_non_EINTR(self):
+        # i read the mock.patch docs. nerp.
+        # any other select error must propagate to the caller
+        dummy_select = DummySelect(select.error(errno.EBADF))
+        disp = DummyDispatcher()
+        disp.readable = lambda: True
+        map = {0: disp}
+        try:
+            from waitress import wasyncore
+
+            old_select = wasyncore.select
+            wasyncore.select = dummy_select
+            self.assertRaises(select.error, self._callFUT, map=map)
+        finally:
+            wasyncore.select = old_select
+        self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)])
+
+
+class Test_poll2(unittest.TestCase):
+    # waitress.wasyncore.poll2 (the poll()-based loop); the DummySelect
+    # stands in for the select module and hands back a DummyPollster.
+    def _callFUT(self, timeout=0.0, map=None):
+        from waitress.wasyncore import poll2
+
+        return poll2(timeout, map)
+
+    def test_select_raises_EINTR(self):
+        # i read the mock.patch docs. nerp.
+        pollster = DummyPollster(exc=select.error(errno.EINTR))
+        dummy_select = DummySelect(pollster=pollster)
+        disp = DummyDispatcher()
+        map = {0: disp}
+        try:
+            from waitress import wasyncore
+
+            old_select = wasyncore.select
+            wasyncore.select = dummy_select
+            self._callFUT(map=map)
+        finally:
+            wasyncore.select = old_select
+        self.assertEqual(pollster.polled, [0.0])
+
+    def test_select_raises_non_EINTR(self):
+        # i read the mock.patch docs. nerp.
+        pollster = DummyPollster(exc=select.error(errno.EBADF))
+        dummy_select = DummySelect(pollster=pollster)
+        disp = DummyDispatcher()
+        map = {0: disp}
+        try:
+            from waitress import wasyncore
+
+            old_select = wasyncore.select
+            wasyncore.select = dummy_select
+            self.assertRaises(select.error, self._callFUT, map=map)
+        finally:
+            wasyncore.select = old_select
+        self.assertEqual(pollster.polled, [0.0])
+
+
+class Test_dispatcher(unittest.TestCase):
+    # Unit tests for waitress.wasyncore.dispatcher, driven entirely by
+    # dummysocket stand-ins; no real network traffic occurs.
+    def _makeOne(self, sock=None, map=None):
+        from waitress.wasyncore import dispatcher
+
+        return dispatcher(sock=sock, map=map)
+
+    def test_unexpected_getpeername_exc(self):
+        sock = dummysocket()
+
+        def getpeername():
+            raise socket.error(errno.EBADF)
+
+        map = {}
+        sock.getpeername = getpeername
+        # a non-ENOTCONN getpeername failure propagates, and the
+        # half-constructed dispatcher must not be left in the map
+        self.assertRaises(socket.error, self._makeOne, sock=sock, map=map)
+        self.assertEqual(map, {})
+
+    def test___repr__accepting(self):
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+        inst.accepting = True
+        inst.addr = ("localhost", 8080)
+        result = repr(inst)
+        expected = "<waitress.wasyncore.dispatcher listening localhost:8080 at"
+        self.assertEqual(result[: len(expected)], expected)
+
+    def test___repr__connected(self):
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+        inst.accepting = False
+        inst.connected = True
+        inst.addr = ("localhost", 8080)
+        result = repr(inst)
+        expected = "<waitress.wasyncore.dispatcher connected localhost:8080 at"
+        self.assertEqual(result[: len(expected)], expected)
+
+    def test_set_reuse_addr_with_socketerror(self):
+        # set_reuse_addr() swallows socket errors from setsockopt
+        sock = dummysocket()
+        map = {}
+
+        def setsockopt(*arg, **kw):
+            sock.errored = True
+            raise socket.error
+
+        sock.setsockopt = setsockopt
+        sock.getsockopt = lambda *arg: 0
+        inst = self._makeOne(sock=sock, map=map)
+        inst.set_reuse_addr()
+        self.assertTrue(sock.errored)
+
+    def test_connect_raise_socket_error(self):
+        # connect_ex() returning a nonzero errno surfaces as socket.error
+        sock = dummysocket()
+        map = {}
+        sock.connect_ex = lambda *arg: 1
+        inst = self._makeOne(sock=sock, map=map)
+        self.assertRaises(socket.error, inst.connect, 0)
+
+    def test_accept_raise_TypeError(self):
+        # TypeError from accept() is treated as "nothing to accept"
+        sock = dummysocket()
+        map = {}
+
+        def accept(*arg, **kw):
+            raise TypeError
+
+        sock.accept = accept
+        inst = self._makeOne(sock=sock, map=map)
+        result = inst.accept()
+        self.assertEqual(result, None)
+
+    def test_accept_raise_unexpected_socketerror(self):
+        sock = dummysocket()
+        map = {}
+
+        def accept(*arg, **kw):
+            raise socket.error(122)
+
+        sock.accept = accept
+        inst = self._makeOne(sock=sock, map=map)
+        self.assertRaises(socket.error, inst.accept)
+
+    def test_send_raise_EWOULDBLOCK(self):
+        # EWOULDBLOCK means "sent nothing, try later", reported as 0 bytes
+        sock = dummysocket()
+        map = {}
+
+        def send(*arg, **kw):
+            raise socket.error(errno.EWOULDBLOCK)
+
+        sock.send = send
+        inst = self._makeOne(sock=sock, map=map)
+        result = inst.send("a")
+        self.assertEqual(result, 0)
+
+    def test_send_raise_unexpected_socketerror(self):
+        sock = dummysocket()
+        map = {}
+
+        def send(*arg, **kw):
+            raise socket.error(122)
+
+        sock.send = send
+        inst = self._makeOne(sock=sock, map=map)
+        self.assertRaises(socket.error, inst.send, "a")
+
+    def test_recv_raises_disconnect(self):
+        # disconnect-class errnos during recv() map to handle_close + b""
+        sock = dummysocket()
+        map = {}
+
+        def recv(*arg, **kw):
+            raise socket.error(errno.ECONNRESET)
+
+        def handle_close():
+            inst.close_handled = True
+
+        sock.recv = recv
+        inst = self._makeOne(sock=sock, map=map)
+        inst.handle_close = handle_close
+        result = inst.recv(1)
+        self.assertEqual(result, b"")
+        self.assertTrue(inst.close_handled)
+
+    def test_close_raises_unknown_socket_error(self):
+        sock = dummysocket()
+        map = {}
+
+        def close():
+            raise socket.error(122)
+
+        sock.close = close
+        inst = self._makeOne(sock=sock, map=map)
+        inst.del_channel = lambda: None
+        self.assertRaises(socket.error, inst.close)
+
+    def test_handle_read_event_not_accepting_not_connected_connecting(self):
+        # while still connecting, a read event first completes the
+        # connection (handle_connect_event) and then delivers the read
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+
+        def handle_connect_event():
+            inst.connect_event_handled = True
+
+        def handle_read():
+            inst.read_handled = True
+
+        inst.handle_connect_event = handle_connect_event
+        inst.handle_read = handle_read
+        inst.accepting = False
+        inst.connected = False
+        inst.connecting = True
+        inst.handle_read_event()
+        self.assertTrue(inst.connect_event_handled)
+        self.assertTrue(inst.read_handled)
+
+    def test_handle_connect_event_getsockopt_returns_error(self):
+        # nonzero SO_ERROR on connect completion raises socket.error
+        sock = dummysocket()
+        sock.getsockopt = lambda *arg: 122
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+        self.assertRaises(socket.error, inst.handle_connect_event)
+
+    def test_handle_expt_event_getsockopt_returns_error(self):
+        sock = dummysocket()
+        sock.getsockopt = lambda *arg: 122
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+
+        def handle_close():
+            inst.close_handled = True
+
+        inst.handle_close = handle_close
+        inst.handle_expt_event()
+        self.assertTrue(inst.close_handled)
+
+    def test_handle_write_event_while_accepting(self):
+        # listening sockets ignore spurious write events
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+        inst.accepting = True
+        result = inst.handle_write_event()
+        self.assertEqual(result, None)
+
+    def test_handle_error_gardenpath(self):
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+
+        def handle_close():
+            inst.close_handled = True
+
+        def compact_traceback(*arg, **kw):
+            return None, None, None, None
+
+        def log_info(self, *arg):
+            inst.logged_info = arg
+
+        inst.handle_close = handle_close
+        inst.compact_traceback = compact_traceback
+        inst.log_info = log_info
+        inst.handle_error()
+        self.assertTrue(inst.close_handled)
+        self.assertEqual(inst.logged_info, ("error",))
+
+    def test_handle_close(self):
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+
+        def log_info(self, *arg):
+            inst.logged_info = arg
+
+        def close():
+            inst._closed = True
+
+        inst.log_info = log_info
+        inst.close = close
+        inst.handle_close()
+        self.assertTrue(inst._closed)
+
+    def test_handle_accepted(self):
+        # the default handle_accepted closes the accepted connection
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+        inst.handle_accepted(sock, "1")
+        self.assertTrue(sock.closed)
+
+
+class Test_dispatcher_with_send(unittest.TestCase):
+    # waitress.wasyncore.dispatcher_with_send buffers outgoing data;
+    # it is writable whenever connected with a non-empty buffer.
+    def _makeOne(self, sock=None, map=None):
+        from waitress.wasyncore import dispatcher_with_send
+
+        return dispatcher_with_send(sock=sock, map=map)
+
+    def test_writable(self):
+        sock = dummysocket()
+        map = {}
+        inst = self._makeOne(sock=sock, map=map)
+        inst.out_buffer = b"123"
+        inst.connected = True
+        self.assertTrue(inst.writable())
+
+
+class Test_close_all(unittest.TestCase):
+    # waitress.wasyncore.close_all: closes every dispatcher in the map,
+    # tolerating EBADF but propagating other errors unless ignore_all.
+    def _callFUT(self, map=None, ignore_all=False):
+        from waitress.wasyncore import close_all
+
+        return close_all(map, ignore_all)
+
+    def test_socketerror_on_close_ebadf(self):
+        # EBADF means "already closed"; the map is still emptied
+        disp = DummyDispatcher(exc=socket.error(errno.EBADF))
+        map = {0: disp}
+        self._callFUT(map)
+        self.assertEqual(map, {})
+
+    def test_socketerror_on_close_non_ebadf(self):
+        disp = DummyDispatcher(exc=socket.error(errno.EAGAIN))
+        map = {0: disp}
+        self.assertRaises(socket.error, self._callFUT, map)
+
+    def test_reraised_exc_on_close(self):
+        disp = DummyDispatcher(exc=KeyboardInterrupt)
+        map = {0: disp}
+        self.assertRaises(KeyboardInterrupt, self._callFUT, map)
+
+    def test_unknown_exc_on_close(self):
+        disp = DummyDispatcher(exc=RuntimeError)
+        map = {0: disp}
+        self.assertRaises(RuntimeError, self._callFUT, map)
+
+
+class DummyDispatcher(object):
+    """Test double mimicking a wasyncore dispatcher.
+
+    Records which handle_* callbacks were invoked via the boolean class
+    attributes below; if ``exc`` is given, the event handlers (and
+    close()) raise it after recording the call.
+    """
+
+    read_event_handled = False
+    write_event_handled = False
+    expt_event_handled = False
+    error_handled = False
+    close_handled = False
+    accepting = False
+
+    def __init__(self, exc=None):
+        self.exc = exc
+
+    def handle_read_event(self):
+        self.read_event_handled = True
+        if self.exc is not None:
+            raise self.exc
+
+    def handle_write_event(self):
+        self.write_event_handled = True
+        if self.exc is not None:
+            raise self.exc
+
+    def handle_expt_event(self):
+        self.expt_event_handled = True
+        if self.exc is not None:
+            raise self.exc
+
+    def handle_error(self):
+        self.error_handled = True
+
+    def handle_close(self):
+        self.close_handled = True
+
+    def readable(self):
+        return False
+
+    def writable(self):
+        return False
+
+    def close(self):
+        if self.exc is not None:
+            raise self.exc
+
+
+class DummyTime(object):
+    """Stand-in for the ``time`` module that records sleep() durations."""
+
+    def __init__(self):
+        self.sleepvals = []
+
+    def sleep(self, val):
+        self.sleepvals.append(val)
+
+
+class DummySelect(object):
+    """Stand-in for the ``select`` module.
+
+    Records select() calls (optionally raising ``exc``) and hands out a
+    caller-supplied pollster from poll().
+    """
+
+    error = select.error
+
+    def __init__(self, exc=None, pollster=None):
+        self.selected = []
+        self.pollster = pollster
+        self.exc = exc
+
+    def select(self, *arg):
+        self.selected.append(arg)
+        if self.exc is not None:
+            raise self.exc
+
+    def poll(self):
+        return self.pollster
+
+
+class DummyPollster(object):
+    """Stand-in for a select.poll() pollster; records poll() timeouts."""
+
+    def __init__(self, exc=None):
+        self.polled = []
+        self.exc = exc
+
+    def poll(self, timeout):
+        self.polled.append(timeout)
+        if self.exc is not None:
+            raise self.exc
+        else:  # pragma: no cover
+            return []
diff --git a/libs/waitress/trigger.py b/libs/waitress/trigger.py
new file mode 100644
index 000000000..6a57c1275
--- /dev/null
+++ b/libs/waitress/trigger.py
@@ -0,0 +1,203 @@
+##############################################################################
+#
+# Copyright (c) 2001-2005 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE
+#
+##############################################################################
+
+import os
+import socket
+import errno
+import threading
+
+from . import wasyncore
+
+# Wake up a call to select() running in the main thread.
+#
+# This is useful in a context where you are using Medusa's I/O
+# subsystem to deliver data, but the data is generated by another
+# thread. Normally, if Medusa is in the middle of a call to
+# select(), new output data generated by another thread will have
+# to sit until the call to select() either times out or returns.
+# If the trigger is 'pulled' by another thread, it should immediately
+# generate a READ event on the trigger object, which will force the
+# select() invocation to return.
+#
+# A common use for this facility: letting Medusa manage I/O for a
+# large number of connections; but routing each request through a
+# thread chosen from a fixed-size thread pool. When a thread is
+# acquired, a transaction is performed, but output data is
+# accumulated into buffers that will be emptied more efficiently
+# by Medusa. [picture a server that can process database queries
+# rapidly, but doesn't want to tie up threads waiting to send data
+# to low-bandwidth connections]
+#
+# The other major feature provided by this class is the ability to
+# move work back into the main thread: if you call pull_trigger()
+# with a thunk argument, when select() wakes up and receives the
+# event it will call your thunk from within that thread. The main
+# purpose of this is to remove the need to wrap thread locks around
+# Medusa's data structures, which normally do not need them. [To see
+# why this is true, imagine this scenario: A thread tries to push some
+# new data onto a channel's outgoing data queue at the same time that
+# the main thread is trying to remove some]
+
+
+class _triggerbase(object):
+    """OS-independent base class for OS-dependent trigger class."""
+
+    kind = None  # subclass must set to "pipe" or "loopback"; used by repr
+
+    def __init__(self):
+        self._closed = False
+
+        # `lock` protects the `thunks` list from being traversed and
+        # appended to simultaneously.
+        self.lock = threading.Lock()
+
+        # List of no-argument callbacks to invoke when the trigger is
+        # pulled. These run in the thread running the wasyncore mainloop,
+        # regardless of which thread pulls the trigger.
+        self.thunks = []
+
+    def readable(self):
+        # always watch for reads: a read event is how the pull is signaled
+        return True
+
+    def writable(self):
+        return False
+
+    def handle_connect(self):
+        pass
+
+    def handle_close(self):
+        self.close()
+
+    # Override the wasyncore close() method, because it doesn't know about
+    # (so can't close) all the gimmicks we have open. Subclass must
+    # supply a _close() method to do platform-specific closing work. _close()
+    # will be called iff we're not already closed.
+    def close(self):
+        if not self._closed:
+            self._closed = True
+            self.del_channel()
+            self._close()  # subclass does OS-specific stuff
+
+    def pull_trigger(self, thunk=None):
+        # Called from any thread: optionally queue a callback, then wake
+        # the mainloop via the subclass's OS-specific _physical_pull().
+        if thunk:
+            with self.lock:
+                self.thunks.append(thunk)
+        self._physical_pull()
+
+    def handle_read(self):
+        # Runs in the mainloop thread: drain the wake-up byte(s), then run
+        # and clear all queued thunks under the lock.
+        try:
+            self.recv(8192)
+        except (OSError, socket.error):
+            return
+        with self.lock:
+            for thunk in self.thunks:
+                try:
+                    thunk()
+                except:
+                    # a failing thunk must not kill the mainloop; log and
+                    # continue with the remaining thunks
+                    nil, t, v, tbinfo = wasyncore.compact_traceback()
+                    self.log_info(
+                        "exception in trigger thunk: (%s:%s %s)" % (t, v, tbinfo)
+                    )
+            self.thunks = []
+
+
+if os.name == "posix":
+
+    class trigger(_triggerbase, wasyncore.file_dispatcher):
+        # POSIX implementation: a pipe is select()able, so the read end is
+        # registered with wasyncore and a pull writes one byte to the
+        # write end.
+        kind = "pipe"
+
+        def __init__(self, map):
+            _triggerbase.__init__(self)
+            r, self.trigger = self._fds = os.pipe()
+            wasyncore.file_dispatcher.__init__(self, r, map=map)
+
+        def _close(self):
+            for fd in self._fds:
+                os.close(fd)
+            self._fds = []
+            wasyncore.file_dispatcher.close(self)
+
+        def _physical_pull(self):
+            os.write(self.trigger, b"x")
+
+
+else: # pragma: no cover
+ # Windows version; uses just sockets, because a pipe isn't select'able
+ # on Windows.
+
+ class trigger(_triggerbase, wasyncore.dispatcher):
+ kind = "loopback"
+
+ def __init__(self, map):
+ _triggerbase.__init__(self)
+
+ # Get a pair of connected sockets. The trigger is the 'w'
+ # end of the pair, which is connected to 'r'. 'r' is put
+ # in the wasyncore socket map. "pulling the trigger" then
+ # means writing something on w, which will wake up r.
+
+ w = socket.socket()
+ # Disable buffering -- pulling the trigger sends 1 byte,
+ # and we want that sent immediately, to wake up wasyncore's
+ # select() ASAP.
+ w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+ count = 0
+ while True:
+ count += 1
+ # Bind to a local port; for efficiency, let the OS pick
+ # a free port for us.
+ # Unfortunately, stress tests showed that we may not
+ # be able to connect to that port ("Address already in
+ # use") despite that the OS picked it. This appears
+ # to be a race bug in the Windows socket implementation.
+ # So we loop until a connect() succeeds (almost always
+ # on the first try). See the long thread at
+ # http://mail.zope.org/pipermail/zope/2005-July/160433.html
+ # for hideous details.
+ a = socket.socket()
+ a.bind(("127.0.0.1", 0))
+ connect_address = a.getsockname() # assigned (host, port) pair
+ a.listen(1)
+ try:
+ w.connect(connect_address)
+ break # success
+ except socket.error as detail:
+ if detail[0] != errno.WSAEADDRINUSE:
+ # "Address already in use" is the only error
+ # I've seen on two WinXP Pro SP2 boxes, under
+ # Pythons 2.3.5 and 2.4.1.
+ raise
+ # (10048, 'Address already in use')
+ # assert count <= 2 # never triggered in Tim's tests
+ if count >= 10: # I've never seen it go above 2
+ a.close()
+ w.close()
+ raise RuntimeError("Cannot bind trigger!")
+ # Close `a` and try again. Note: I originally put a short
+ # sleep() here, but it didn't appear to help or hurt.
+ a.close()
+
+ r, addr = a.accept() # r becomes wasyncore's (self.)socket
+ a.close()
+ self.trigger = w
+ wasyncore.dispatcher.__init__(self, r, map=map)
+
+ def _close(self):
+ # self.socket is r, and self.trigger is w, from __init__
+ self.socket.close()
+ self.trigger.close()
+
+ def _physical_pull(self):
+ self.trigger.send(b"x")
diff --git a/libs/waitress/utilities.py b/libs/waitress/utilities.py
new file mode 100644
index 000000000..556bed20a
--- /dev/null
+++ b/libs/waitress/utilities.py
@@ -0,0 +1,320 @@
+##############################################################################
+#
+# Copyright (c) 2004 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""Utility functions
+"""
+
+import calendar
+import errno
+import logging
+import os
+import re
+import stat
+import time
+
+from .rfc7230 import OBS_TEXT, VCHAR
+
+logger = logging.getLogger("waitress")
+queue_logger = logging.getLogger("waitress.queue")
+
+
+def find_double_newline(s):
+ """Returns the position just after a double newline in the given string."""
+ pos = s.find(b"\r\n\r\n")
+
+ if pos >= 0:
+ pos += 4
+
+ return pos
+
+
+def concat(*args):
+    """Concatenate all positional string arguments with no separator."""
+    return "".join(args)
+
+
+def join(seq, field=" "):
+    """Join the strings in ``seq`` with ``field`` (default: one space)."""
+    return field.join(seq)
+
+
+def group(s):
+ return "(" + s + ")"
+
+
+# Lookup tables and regex fragments for parsing HTTP date headers.
+short_days = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"]
+long_days = [
+    "sunday",
+    "monday",
+    "tuesday",
+    "wednesday",
+    "thursday",
+    "friday",
+    "saturday",
+]
+
+short_day_reg = group(join(short_days, "|"))
+long_day_reg = group(join(long_days, "|"))
+
+# daymap: lowercase day name (short or long) -> weekday index 0..6
+daymap = {}
+
+for i in range(7):
+    daymap[short_days[i]] = i
+    daymap[long_days[i]] = i
+
+# hh:mm:ss, each component captured separately
+hms_reg = join(3 * [group("[0-9][0-9]")], ":")
+
+months = [
+    "jan",
+    "feb",
+    "mar",
+    "apr",
+    "may",
+    "jun",
+    "jul",
+    "aug",
+    "sep",
+    "oct",
+    "nov",
+    "dec",
+]
+
+# monmap: lowercase month abbreviation -> month number 1..12
+monmap = {}
+
+for i in range(12):
+    monmap[months[i]] = i + 1
+
+months_reg = group(join(months, "|"))
+
+# From draft-ietf-http-v11-spec-07.txt/3.3.1
+#       Sun, 06 Nov 1994 08:49:37 GMT  ; RFC 822, updated by RFC 1123
+#       Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036
+#       Sun Nov  6 08:49:37 1994       ; ANSI C's asctime() format
+
+# rfc822 format
+rfc822_date = join(
+    [
+        concat(short_day_reg, ","),  # day
+        group("[0-9][0-9]?"),  # date
+        months_reg,  # month
+        group("[0-9]+"),  # year
+        hms_reg,  # hour minute second
+        "gmt",
+    ],
+    " ",
+)
+
+rfc822_reg = re.compile(rfc822_date)
+
+
+def unpack_rfc822(m):
+    """Convert an ``rfc822_reg`` match into a 9-tuple for calendar.timegm."""
+    g = m.group
+
+    return (
+        int(g(4)),  # year
+        monmap[g(3)],  # month
+        int(g(2)),  # day
+        int(g(5)),  # hour
+        int(g(6)),  # minute
+        int(g(7)),  # second
+        0,
+        0,
+        0,
+    )
+
+
+# rfc850 format ("Sunday, 06-Nov-94 08:49:37 GMT")
+rfc850_date = join(
+    [
+        concat(long_day_reg, ","),
+        join([group("[0-9][0-9]?"), months_reg, group("[0-9]+")], "-"),
+        hms_reg,
+        "gmt",
+    ],
+    " ",
+)
+
+rfc850_reg = re.compile(rfc850_date)
+# they actually unpack the same way
+def unpack_rfc850(m):
+    """Convert an ``rfc850_reg`` match into a 9-tuple for calendar.timegm."""
+    g = m.group
+    yr = g(4)
+
+    if len(yr) == 2:
+        # two-digit RFC 850 years are windowed into the 1900s
+        # NOTE(review): RFC 7231 suggests >50-years-in-future heuristics
+        # for 2-digit years; this simpler mapping matches legacy behavior.
+        yr = "19" + yr
+
+    return (
+        int(yr),  # year
+        monmap[g(3)],  # month
+        int(g(2)),  # day
+        int(g(5)),  # hour
+        int(g(6)),  # minute
+        int(g(7)),  # second
+        0,
+        0,
+        0,
+    )
+
+
+# parsdate.parsedate    - ~700/sec.
+# parse_http_date       - ~1333/sec.
+
+# Capitalized names used when *emitting* HTTP dates (build_http_date);
+# monthname is 1-based (index 0 is a placeholder).
+weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
+monthname = [
+    None,
+    "Jan",
+    "Feb",
+    "Mar",
+    "Apr",
+    "May",
+    "Jun",
+    "Jul",
+    "Aug",
+    "Sep",
+    "Oct",
+    "Nov",
+    "Dec",
+]
+
+
+def build_http_date(when):
+ year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when)
+
+ return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
+ weekdayname[wd],
+ day,
+ monthname[month],
+ year,
+ hh,
+ mm,
+ ss,
+ )
+
+
+def parse_http_date(d):
+    """Parse an RFC 850 or RFC 822/1123 HTTP date string to a Unix epoch.
+
+    Returns 0 when the string matches neither format.
+    """
+    d = d.lower()
+    # try the rarer RFC 850 form first, then the common RFC 822 form
+    m = rfc850_reg.match(d)
+
+    if m and m.end() == len(d):
+        retval = int(calendar.timegm(unpack_rfc850(m)))
+    else:
+        m = rfc822_reg.match(d)
+
+        if m and m.end() == len(d):
+            retval = int(calendar.timegm(unpack_rfc822(m)))
+        else:
+            return 0
+
+    return retval
+
+
+# RFC 5234 Appendix B.1 "Core Rules":
+# VCHAR           =  %x21-7E
+#                        ; visible (printing) characters
+vchar_re = VCHAR
+
+# RFC 7230 Section 3.2.6 "Field Value Components":
+# quoted-string   = DQUOTE *( qdtext / quoted-pair ) DQUOTE
+# qdtext          = HTAB / SP /%x21 / %x23-5B / %x5D-7E / obs-text
+# obs-text        = %x80-FF
+# quoted-pair     = "\" ( HTAB / SP / VCHAR / obs-text )
+obs_text_re = OBS_TEXT
+
+# The '\\' between \x5b and \x5d is needed to escape \x5d (']')
+qdtext_re = "[\t \x21\x23-\x5b\\\x5d-\x7e" + obs_text_re + "]"
+
+quoted_pair_re = r"\\" + "([\t " + vchar_re + obs_text_re + "])"
+quoted_string_re = '"(?:(?:' + qdtext_re + ")|(?:" + quoted_pair_re + '))*"'
+
+# compiled forms used by undquote() below
+quoted_string = re.compile(quoted_string_re)
+quoted_pair = re.compile(quoted_pair_re)
+
+
+def undquote(value):
+    """Strip RFC 7230 double-quoting (and backslash escapes) from ``value``.
+
+    Unquoted values are returned unchanged; a value that is quoted on one
+    side only, or quoted but malformed, raises ValueError.
+    """
+    if value.startswith('"') and value.endswith('"'):
+        # So it claims to be DQUOTE'ed, let's validate that
+        matches = quoted_string.match(value)
+
+        if matches and matches.end() == len(value):
+            # Remove the DQUOTE's from the value
+            value = value[1:-1]
+
+            # Remove all backslashes that are followed by a valid vchar or
+            # obs-text
+            value = quoted_pair.sub(r"\1", value)
+
+            return value
+    elif not value.startswith('"') and not value.endswith('"'):
+        return value
+
+    raise ValueError("Invalid quoting in value")
+
+
+def cleanup_unix_socket(path):
+    """Remove ``path`` if it exists and is a unix-domain socket node.
+
+    Missing paths are ignored; other stat() failures propagate.  Regular
+    files are deliberately left alone.
+    """
+    try:
+        st = os.stat(path)
+    except OSError as exc:
+        if exc.errno != errno.ENOENT:
+            raise  # pragma: no cover
+    else:
+        if stat.S_ISSOCK(st.st_mode):
+            try:
+                os.remove(path)
+            except OSError:  # pragma: no cover
+                # avoid race condition error during tests
+                pass
+
+
+class Error(object):
+    """Base class for canned plain-text HTTP error responses.
+
+    Subclasses override ``code`` and ``reason``; ``body`` carries the
+    request-specific detail text.
+    """
+
+    # HTTP status code for the response
+    code = 500
+    # reason phrase accompanying the status code
+    reason = "Internal Server Error"
+
+    def __init__(self, body):
+        self.body = body
+
+    def to_response(self):
+        """Return (status_line, header_list, body_text) for this error."""
+        status = "%s %s" % (self.code, self.reason)
+        body = "%s\r\n\r\n%s" % (self.reason, self.body)
+        tag = "\r\n\r\n(generated by waitress)"
+        body = body + tag
+        headers = [("Content-Type", "text/plain")]
+
+        return status, headers, body
+
+    def wsgi_response(self, environ, start_response):
+        """Serve this error as a minimal WSGI application."""
+        status, headers, body = self.to_response()
+        start_response(status, headers)
+        yield body
+
+
+# Concrete Error subclasses: each pins a specific HTTP status code and
+# reason phrase onto the shared Error response machinery.
+class BadRequest(Error):
+    code = 400
+    reason = "Bad Request"
+
+
+class RequestHeaderFieldsTooLarge(BadRequest):
+    code = 431
+    reason = "Request Header Fields Too Large"
+
+
+class RequestEntityTooLarge(BadRequest):
+    code = 413
+    reason = "Request Entity Too Large"
+
+
+class InternalServerError(Error):
+    code = 500
+    reason = "Internal Server Error"
+
+
+class ServerNotImplemented(Error):
+    code = 501
+    reason = "Not Implemented"
diff --git a/libs/waitress/wasyncore.py b/libs/waitress/wasyncore.py
new file mode 100644
index 000000000..09bcafaa0
--- /dev/null
+++ b/libs/waitress/wasyncore.py
@@ -0,0 +1,693 @@
+# -*- Mode: Python -*-
+# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp
+# Author: Sam Rushing <[email protected]>
+
+# ======================================================================
+# Copyright 1996 by Sam Rushing
+#
+# All Rights Reserved
+#
+# Permission to use, copy, modify, and distribute this software and
+# its documentation for any purpose and without fee is hereby
+# granted, provided that the above copyright notice appear in all
+# copies and that both that copyright notice and this permission
+# notice appear in supporting documentation, and that the name of Sam
+# Rushing not be used in advertising or publicity pertaining to
+# distribution of the software without specific, written prior
+# permission.
+#
+# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
+# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
+# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR
+# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
+# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+# ======================================================================
+
+"""Basic infrastructure for asynchronous socket service clients and servers.
+
+There are only two ways to have a program on a single processor do "more
+than one thing at a time". Multi-threaded programming is the simplest and
+most popular way to do it, but there is another very different technique,
+that lets you have nearly all the advantages of multi-threading, without
+actually using multiple threads. it's really only practical if your program
+is largely I/O bound. If your program is CPU bound, then pre-emptive
+scheduled threads are probably what you really need. Network servers are
+rarely CPU-bound, however.
+
+If your operating system supports the select() system call in its I/O
+library (and nearly all do), then you can use it to juggle multiple
+communication channels at once; doing other work while your I/O is taking
+place in the "background." Although this strategy can seem strange and
+complex, especially at first, it is in many ways easier to understand and
+control than multi-threaded programming. The module documented here solves
+many of the difficult problems for you, making the task of building
+sophisticated high-performance network servers and clients a snap.
+
+NB: this is a fork of asyncore from the stdlib that we've (the waitress
+developers) named 'wasyncore' to ensure forward compatibility, as asyncore
+in the stdlib will be dropped soon. It is neither a copy of the 2.7 asyncore
+nor the 3.X asyncore; it is a version compatible with either 2.7 or 3.X.
+"""
+
+from . import compat
+from . import utilities
+
+import logging
+import select
+import socket
+import sys
+import time
+import warnings
+
+import os
+from errno import (
+ EALREADY,
+ EINPROGRESS,
+ EWOULDBLOCK,
+ ECONNRESET,
+ EINVAL,
+ ENOTCONN,
+ ESHUTDOWN,
+ EISCONN,
+ EBADF,
+ ECONNABORTED,
+ EPIPE,
+ EAGAIN,
+ EINTR,
+ errorcode,
+)
+
# errno values that signal the peer went away (treated as "close the
# channel" rather than as a real error).
_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF})

# Global fd -> dispatcher registry used by the module-level poll loops.
# Guarded so that a module reload() keeps the existing live map.
try:
    socket_map
except NameError:
    socket_map = {}
+
+
+def _strerror(err):
+ try:
+ return os.strerror(err)
+ except (TypeError, ValueError, OverflowError, NameError):
+ return "Unknown error %s" % err
+
+
class ExitNow(Exception):
    """Raised inside a handler to break out of the event loop immediately."""

    pass


# Exceptions that must propagate to the caller instead of being routed to
# a channel's handle_error().
_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit)
+
+
def read(obj):
    """Deliver a read event to *obj*, funnelling failures to handle_error().

    Control-flow exceptions (ExitNow, KeyboardInterrupt, SystemExit) are
    re-raised; everything else is handed to the channel itself.
    """
    try:
        obj.handle_read_event()
    except _reraised_exceptions:
        raise
    except:
        obj.handle_error()
+
+
def write(obj):
    """Deliver a write event to *obj*, funnelling failures to handle_error().

    Control-flow exceptions (ExitNow, KeyboardInterrupt, SystemExit) are
    re-raised; everything else is handed to the channel itself.
    """
    try:
        obj.handle_write_event()
    except _reraised_exceptions:
        raise
    except:
        obj.handle_error()
+
+
+def _exception(obj):
+ try:
+ obj.handle_expt_event()
+ except _reraised_exceptions:
+ raise
+ except:
+ obj.handle_error()
+
+
def readwrite(obj, flags):
    """Dispatch poll()-style *flags* for one channel to its event handlers.

    Events are delivered in read, write, priority, hangup order.  A
    disconnect-class socket error closes the channel; any other failure
    (other than the re-raised control-flow exceptions) is routed to
    handle_error().
    """
    dispatch_table = (
        (select.POLLIN, obj.handle_read_event),
        (select.POLLOUT, obj.handle_write_event),
        (select.POLLPRI, obj.handle_expt_event),
        (select.POLLHUP | select.POLLERR | select.POLLNVAL, obj.handle_close),
    )
    try:
        for mask, handler in dispatch_table:
            if flags & mask:
                handler()
    except socket.error as err:
        if err.args[0] in _DISCONNECTED:
            # Peer went away: treat as an orderly close.
            obj.handle_close()
        else:
            obj.handle_error()
    except _reraised_exceptions:
        raise
    except:
        obj.handle_error()
+
+
def poll(timeout=0.0, map=None):
    """Run one select()-based event-loop iteration over *map*.

    Builds read/write/exception fd lists from each channel's readable()
    and writable() predicates, select()s with *timeout*, then dispatches
    the ready descriptors via read()/write()/_exception().
    """
    if map is None:  # pragma: no cover
        map = socket_map
    if map:
        r = []
        w = []
        e = []
        for fd, obj in list(map.items()):  # list() call FBO py3
            is_r = obj.readable()
            is_w = obj.writable()
            if is_r:
                r.append(fd)
            # accepting sockets should not be writable
            if is_w and not obj.accepting:
                w.append(fd)
            if is_r or is_w:
                e.append(fd)
        if [] == r == w == e:
            # Nothing to watch: sleep instead of spinning through select().
            time.sleep(timeout)
            return

        try:
            r, w, e = select.select(r, w, e, timeout)
        except select.error as err:
            # Interrupted system call: just retry on the next iteration.
            if err.args[0] != EINTR:
                raise
            else:
                return

        for fd in r:
            obj = map.get(fd)
            if obj is None:  # pragma: no cover
                continue
            read(obj)

        for fd in w:
            obj = map.get(fd)
            if obj is None:  # pragma: no cover
                continue
            write(obj)

        for fd in e:
            obj = map.get(fd)
            if obj is None:  # pragma: no cover
                continue
            _exception(obj)
+
+
def poll2(timeout=0.0, map=None):
    """Run one poll()-based event-loop iteration over *map*.

    Registers each channel's interest flags with a select.poll object,
    polls with *timeout* (seconds, converted to milliseconds), and
    dispatches results through readwrite().
    """
    # Use the poll() support added to the select module in Python 2.0
    if map is None:  # pragma: no cover
        map = socket_map
    if timeout is not None:
        # timeout is in milliseconds
        timeout = int(timeout * 1000)
    pollster = select.poll()
    if map:
        for fd, obj in list(map.items()):
            flags = 0
            if obj.readable():
                flags |= select.POLLIN | select.POLLPRI
            # accepting sockets should not be writable
            if obj.writable() and not obj.accepting:
                flags |= select.POLLOUT
            if flags:
                pollster.register(fd, flags)

        try:
            r = pollster.poll(timeout)
        except select.error as err:
            # Interrupted system call: deliver nothing this iteration.
            if err.args[0] != EINTR:
                raise
            r = []

        for fd, flags in r:
            obj = map.get(fd)
            if obj is None:  # pragma: no cover
                continue
            readwrite(obj, flags)


poll3 = poll2  # Alias for backward compatibility
+
+
def loop(timeout=30.0, use_poll=False, map=None, count=None):
    """Drive the event loop over *map*.

    Prefers poll() when *use_poll* is set and the platform supports it,
    otherwise falls back to select().  With count=None the loop runs
    until the map is empty; otherwise it runs at most *count* iterations.
    """
    if map is None:  # pragma: no cover
        map = socket_map

    if use_poll and hasattr(select, "poll"):
        poll_fun = poll2
    else:
        poll_fun = poll

    if count is None:  # pragma: no cover
        # Run until every channel has been closed and removed.
        while map:
            poll_fun(timeout, map)
    else:
        while map and count > 0:
            poll_fun(timeout, map)
            count -= 1
+
+
def compact_traceback():
    """Summarize the exception currently being handled.

    Returns ((file, function, line), exc_type, exc_value, info) where
    *info* is a compact one-line rendering of every traceback frame and
    the first element describes the innermost frame.
    """
    exc_type, exc_value, tb = sys.exc_info()
    frames = []
    if not tb:  # pragma: no cover
        raise AssertionError("traceback does not exist")
    while tb:
        code = tb.tb_frame.f_code
        frames.append((code.co_filename, code.co_name, str(tb.tb_lineno)))
        tb = tb.tb_next

    # just to be safe: drop the reference to the traceback object
    del tb

    file, function, line = frames[-1]
    info = " ".join("[%s|%s|%s]" % entry for entry in frames)
    return (file, function, line), exc_type, exc_value, info
+
+
class dispatcher:
    """Wrapper around a non-blocking socket registered in an event-loop map.

    The module-level poll functions look channels up in the map by file
    descriptor and invoke handle_read_event()/handle_write_event()/
    handle_expt_event() when select()/poll() reports readiness; subclasses
    override the handle_*() callbacks to implement protocol behaviour.
    """

    debug = False
    connected = False
    accepting = False
    connecting = False
    closing = False
    addr = None
    ignore_log_types = frozenset({"warning"})
    logger = utilities.logger
    compact_traceback = staticmethod(compact_traceback)  # for testing

    def __init__(self, sock=None, map=None):
        # The map ties this channel's fileno to the channel itself so the
        # module-level poll functions can find it.
        if map is None:  # pragma: no cover
            self._map = socket_map
        else:
            self._map = map

        self._fileno = None

        if sock:
            # Set to nonblocking just to make sure for cases where we
            # get a socket from a blocking source.
            sock.setblocking(0)
            self.set_socket(sock, map)
            self.connected = True
            # The constructor no longer requires that the socket
            # passed be connected.
            try:
                self.addr = sock.getpeername()
            except socket.error as err:
                if err.args[0] in (ENOTCONN, EINVAL):
                    # To handle the case where we got an unconnected
                    # socket.
                    self.connected = False
                else:
                    # The socket is broken in some unknown way, alert
                    # the user and remove it from the map (to prevent
                    # polling of broken sockets).
                    self.del_channel(map)
                    raise
        else:
            self.socket = None

    def __repr__(self):
        status = [self.__class__.__module__ + "." + compat.qualname(self.__class__)]
        if self.accepting and self.addr:
            status.append("listening")
        elif self.connected:
            status.append("connected")
        if self.addr is not None:
            try:
                status.append("%s:%d" % self.addr)
            except TypeError:  # pragma: no cover
                status.append(repr(self.addr))
        return "<%s at %#x>" % (" ".join(status), id(self))

    __str__ = __repr__

    def add_channel(self, map=None):
        """Register this channel in the socket map under its fileno."""
        # self.log_info('adding channel %s' % self)
        if map is None:
            map = self._map
        map[self._fileno] = self

    def del_channel(self, map=None):
        """Remove this channel from the socket map and forget its fileno."""
        fd = self._fileno
        if map is None:
            map = self._map
        if fd in map:
            # self.log_info('closing channel %d:%s' % (fd, self))
            del map[fd]
        self._fileno = None

    def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM):
        """Create a fresh non-blocking socket and adopt it as our own."""
        self.family_and_type = family, type
        sock = socket.socket(family, type)
        sock.setblocking(0)
        self.set_socket(sock)

    def set_socket(self, sock, map=None):
        """Adopt *sock* as this channel's socket and register its fileno."""
        self.socket = sock
        self._fileno = sock.fileno()
        self.add_channel(map)

    def set_reuse_addr(self):
        # try to re-use a server port if possible
        try:
            self.socket.setsockopt(
                socket.SOL_SOCKET,
                socket.SO_REUSEADDR,
                self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1,
            )
        except socket.error:
            pass

    # ==================================================
    # predicates for select()
    # these are used as filters for the lists of sockets
    # to pass to select().
    # ==================================================

    def readable(self):
        return True

    def writable(self):
        return True

    # ==================================================
    # socket object methods.
    # ==================================================

    def listen(self, num):
        """Begin listening; Windows caps the backlog at 5."""
        self.accepting = True
        if os.name == "nt" and num > 5:  # pragma: no cover
            num = 5
        return self.socket.listen(num)

    def bind(self, addr):
        """Bind the socket and remember the local address."""
        self.addr = addr
        return self.socket.bind(addr)

    def connect(self, address):
        """Start a non-blocking connect to *address*."""
        self.connected = False
        self.connecting = True
        err = self.socket.connect_ex(address)
        if (
            err in (EINPROGRESS, EALREADY, EWOULDBLOCK)
            or err == EINVAL
            and os.name == "nt"
        ):  # pragma: no cover
            # Connect is in progress; completion arrives as a write event.
            self.addr = address
            return
        if err in (0, EISCONN):
            self.addr = address
            self.handle_connect_event()
        else:
            raise socket.error(err, errorcode[err])

    def accept(self):
        """Accept one connection; None on transient (retryable) errors."""
        # XXX can return either an address pair or None
        try:
            conn, addr = self.socket.accept()
        except TypeError:
            return None
        except socket.error as why:
            if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN):
                return None
            else:
                raise
        else:
            return conn, addr

    def send(self, data):
        """Send *data*; returns bytes written (0 if the call would block)."""
        try:
            result = self.socket.send(data)
            return result
        except socket.error as why:
            if why.args[0] == EWOULDBLOCK:
                return 0
            elif why.args[0] in _DISCONNECTED:
                self.handle_close()
                return 0
            else:
                raise

    def recv(self, buffer_size):
        """Receive up to *buffer_size* bytes; b'' means the peer closed."""
        try:
            data = self.socket.recv(buffer_size)
            if not data:
                # a closed connection is indicated by signaling
                # a read condition, and having recv() return 0.
                self.handle_close()
                return b""
            else:
                return data
        except socket.error as why:
            # winsock sometimes raises ENOTCONN
            if why.args[0] in _DISCONNECTED:
                self.handle_close()
                return b""
            else:
                raise

    def close(self):
        """Unregister the channel and close the underlying socket."""
        self.connected = False
        self.accepting = False
        self.connecting = False
        self.del_channel()
        if self.socket is not None:
            try:
                self.socket.close()
            except socket.error as why:
                if why.args[0] not in (ENOTCONN, EBADF):
                    raise

    # log and log_info may be overridden to provide more sophisticated
    # logging and warning methods. In general, log is for 'hit' logging
    # and 'log_info' is for informational, warning and error logging.

    def log(self, message):
        self.logger.log(logging.DEBUG, message)

    def log_info(self, message, type="info"):
        severity = {
            "info": logging.INFO,
            "warning": logging.WARN,
            "error": logging.ERROR,
        }
        self.logger.log(severity.get(type, logging.INFO), message)

    def handle_read_event(self):
        """Route a read-readiness notification to the right handler."""
        if self.accepting:
            # accepting sockets are never connected, they "spawn" new
            # sockets that are connected
            self.handle_accept()
        elif not self.connected:
            if self.connecting:
                self.handle_connect_event()
            self.handle_read()
        else:
            self.handle_read()

    def handle_connect_event(self):
        """Finish a pending connect, raising if the socket reports an error."""
        err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        if err != 0:
            raise socket.error(err, _strerror(err))
        self.handle_connect()
        self.connected = True
        self.connecting = False

    def handle_write_event(self):
        """Route a write-readiness notification to the right handler."""
        if self.accepting:
            # Accepting sockets shouldn't get a write event.
            # We will pretend it didn't happen.
            return

        if not self.connected:
            if self.connecting:
                self.handle_connect_event()
        self.handle_write()

    def handle_expt_event(self):
        # handle_expt_event() is called if there might be an error on the
        # socket, or if there is OOB data
        # check for the error condition first
        err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        if err != 0:
            # we can get here when select.select() says that there is an
            # exceptional condition on the socket
            # since there is an error, we'll go ahead and close the socket
            # like we would in a subclassed handle_read() that received no
            # data
            self.handle_close()
        else:
            self.handle_expt()

    def handle_error(self):
        """Log an uncaptured handler exception and close the channel."""
        nil, t, v, tbinfo = self.compact_traceback()

        # sometimes a user repr method will crash.
        try:
            self_repr = repr(self)
        except:  # pragma: no cover
            self_repr = "<__repr__(self) failed for object at %0x>" % id(self)

        self.log_info(
            "uncaptured python exception, closing channel %s (%s:%s %s)"
            % (self_repr, t, v, tbinfo),
            "error",
        )
        self.handle_close()

    def handle_expt(self):
        self.log_info("unhandled incoming priority event", "warning")

    def handle_read(self):
        self.log_info("unhandled read event", "warning")

    def handle_write(self):
        self.log_info("unhandled write event", "warning")

    def handle_connect(self):
        self.log_info("unhandled connect event", "warning")

    def handle_accept(self):
        """Accept an incoming connection and hand it to handle_accepted()."""
        pair = self.accept()
        if pair is not None:
            self.handle_accepted(*pair)

    def handle_accepted(self, sock, addr):
        sock.close()
        self.log_info("unhandled accepted event", "warning")

    def handle_close(self):
        self.log_info("unhandled close event", "warning")
        self.close()
+
+
+# ---------------------------------------------------------------------------
+# adds simple buffered output capability, useful for simple clients.
+# [for more sophisticated usage use asynchat.async_chat]
+# ---------------------------------------------------------------------------
+
+
class dispatcher_with_send(dispatcher):
    """A dispatcher with simple buffered output, useful for simple clients.

    Outgoing bytes accumulate in ``out_buffer`` and are flushed to the
    socket in chunks of at most 64KiB as the channel becomes writable.
    """

    def __init__(self, sock=None, map=None):
        dispatcher.__init__(self, sock, map)
        # Bytes queued for transmission but not yet written to the socket.
        self.out_buffer = b""

    def initiate_send(self):
        """Write up to 64KiB of the pending buffer to the socket."""
        # Fix: removed a dead store (`num_sent = 0`) that was immediately
        # overwritten by the send() result.
        num_sent = dispatcher.send(self, self.out_buffer[:65536])
        self.out_buffer = self.out_buffer[num_sent:]

    handle_write = initiate_send

    def writable(self):
        # Stay writable until the connection is up and the buffer drained.
        return (not self.connected) or len(self.out_buffer)

    def send(self, data):
        """Queue *data* and attempt an immediate (partial) send."""
        if self.debug:  # pragma: no cover
            self.log_info("sending %s" % repr(data))
        self.out_buffer = self.out_buffer + data
        self.initiate_send()
+
+
def close_all(map=None, ignore_all=False):
    """Close every channel in *map* and clear the map.

    EBADF socket errors are always ignored; other failures are re-raised
    unless *ignore_all* is true (control-flow exceptions always propagate).
    """
    if map is None:  # pragma: no cover
        map = socket_map
    # Snapshot the values: close() mutates the map as channels deregister.
    for channel in list(map.values()):
        try:
            channel.close()
        except socket.error as err:
            if err.args[0] == EBADF:
                pass
            elif not ignore_all:
                raise
        except _reraised_exceptions:
            raise
        except:
            if not ignore_all:
                raise
    map.clear()
+
+
+# Asynchronous File I/O:
+#
+# After a little research (reading man pages on various unixen, and
+# digging through the linux kernel), I've determined that select()
+# isn't meant for doing asynchronous file i/o.
+# Heartening, though - reading linux/mm/filemap.c shows that linux
+# supports asynchronous read-ahead. So _MOST_ of the time, the data
+# will be sitting in memory for us already when we go to read it.
+#
+# What other OS's (besides NT) support async file i/o? [VMS?]
+#
+# Regardless, this is useful for pipes, and stdin/stdout...
+
if os.name == "posix":

    class file_wrapper:
        """Adapter giving a raw file descriptor a minimal socket-like API."""

        # Here we override just enough to make a file
        # look like a socket for the purposes of asyncore.
        # The passed fd is automatically os.dup()'d

        def __init__(self, fd):
            # Duplicate so closing this wrapper never closes the caller's fd.
            self.fd = os.dup(fd)

        def __del__(self):
            if self.fd >= 0:
                warnings.warn("unclosed file %r" % self, compat.ResourceWarning)
            self.close()

        def recv(self, *args):
            return os.read(self.fd, *args)

        def send(self, *args):
            return os.write(self.fd, *args)

        def getsockopt(self, level, optname, buflen=None):  # pragma: no cover
            # Only the SO_ERROR probe used by dispatcher is supported.
            if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen:
                return 0
            raise NotImplementedError(
                "Only asyncore specific behaviour " "implemented."
            )

        read = recv
        write = send

        def close(self):
            # Idempotent: -1 marks the descriptor as already closed.
            if self.fd < 0:
                return
            fd = self.fd
            self.fd = -1
            os.close(fd)

        def fileno(self):
            return self.fd

    class file_dispatcher(dispatcher):
        """Dispatcher over an arbitrary file descriptor (pipe, tty, ...)."""

        def __init__(self, fd, map=None):
            dispatcher.__init__(self, None, map)
            self.connected = True
            try:
                fd = fd.fileno()
            except AttributeError:
                pass
            self.set_file(fd)
            # set it to non-blocking mode
            compat.set_nonblocking(fd)

        def set_file(self, fd):
            """Wrap *fd* in a file_wrapper and register the channel."""
            self.socket = file_wrapper(fd)
            self._fileno = self.socket.fileno()
            self.add_channel()