author | morpheus65535 <[email protected]> | 2021-10-20 20:46:22 -0400
committer | morpheus65535 <[email protected]> | 2021-10-20 20:46:22 -0400
commit | 39fe3141d51b01479e7b585ad0b5ee5df1767226 (patch)
tree | 8f2617fbe98281498fd9617c5346a5bfdc841963
parent | 8b0b965c8f2bbeccf277892c617bfa602028bd8c (diff)
download | bazarr-39fe3141d51b01479e7b585ad0b5ee5df1767226.tar.gz, bazarr-39fe3141d51b01479e7b585ad0b5ee5df1767226.zip
Moved back from gevent to waitress web server. This should prevent UI disconnection occurring during heavy tasks like syncing subtitles. (tag: v1.0.1-beta.1)
35 files changed, 4955 insertions, 87 deletions
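The core of the change is visible in the `bazarr/app.py` and `bazarr/server.py` hunks below: Socket.IO switches from `async_mode='gevent'` to `async_mode='threading'`, and the gevent `pywsgi.WSGIServer` is replaced by a waitress server. The following is a minimal sketch of that swap, not the exact Bazarr code; `app`, the host/port values, and the standalone `SocketIO()` setup are placeholders standing in for Bazarr's own objects and settings.

```python
# Minimal sketch of the server swap made in this commit (simplified from the
# bazarr/app.py and bazarr/server.py hunks; names and values are illustrative).
from flask import Flask
from flask_socketio import SocketIO
from waitress.server import create_server

app = Flask(__name__)
socketio = SocketIO()

# Before: async_mode='gevent' tied Socket.IO to the gevent event loop.
# After: async_mode='threading' runs it on plain OS threads.
socketio.init_app(app, path='/api/socket.io',
                  cors_allowed_origins='*', async_mode='threading')

# Before: pywsgi.WSGIServer(..., handler_class=WebSocketHandler).serve_forever()
# After: a waitress server with a worker-thread pool.
server = create_server(app, host='0.0.0.0', port=6767, threads=100)
server.run()      # blocks and serves requests until close() is called
# server.close()  # clean shutdown, as used by Bazarr's shutdown/restart paths
```

The `threads=100` pool mirrors the value the commit passes to `create_server`, so heavy background tasks run in ordinary threads instead of competing with request handling on a single gevent loop, which is the UI-disconnection problem the commit message describes.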
diff --git a/bazarr/app.py b/bazarr/app.py index 1a9e25f4f..a67810c57 100644 --- a/bazarr/app.py +++ b/bazarr/app.py @@ -28,7 +28,7 @@ def create_app(): else: app.config["DEBUG"] = False - socketio.init_app(app, path=base_url.rstrip('/')+'/api/socket.io', cors_allowed_origins='*', async_mode='gevent') + socketio.init_app(app, path=base_url.rstrip('/')+'/api/socket.io', cors_allowed_origins='*', async_mode='threading') return app diff --git a/bazarr/database.py b/bazarr/database.py index 9924f8d72..85f420110 100644 --- a/bazarr/database.py +++ b/bazarr/database.py @@ -3,7 +3,7 @@ import atexit import json import ast import logging -import gevent +import time from peewee import * from playhouse.sqliteq import SqliteQueueDatabase from playhouse.shortcuts import model_to_dict @@ -15,7 +15,7 @@ from config import settings, get_array_from from get_args import args database = SqliteQueueDatabase(os.path.join(args.config_dir, 'db', 'bazarr.db'), - use_gevent=True, + use_gevent=False, autostart=True, queue_max_size=256) migrator = SqliteMigrator(database) @@ -284,7 +284,7 @@ def init_db(): if not System.select().count(): System.insert({System.configured: '0', System.updated: '0'}).execute() except: - gevent.sleep(0.1) + time.sleep(0.1) else: tables_created = True diff --git a/bazarr/get_episodes.py b/bazarr/get_episodes.py index ebbec73ec..c93f2a693 100644 --- a/bazarr/get_episodes.py +++ b/bazarr/get_episodes.py @@ -3,7 +3,6 @@ import os import requests import logging -from gevent import sleep from peewee import DoesNotExist from database import get_exclusion_clause, TableEpisodes, TableShows @@ -45,7 +44,6 @@ def sync_episodes(series_id=None, send_event=True): series_count = len(seriesIdList) for i, seriesId in enumerate(seriesIdList): - sleep() if send_event: show_progress(id='episodes_progress', header='Syncing episodes...', @@ -70,7 +68,6 @@ def sync_episodes(series_id=None, send_event=True): episode['episodeFile'] = item[0] for episode in episodes: - sleep() if 'hasFile' in episode: if episode['hasFile'] is True: if 'episodeFile' in episode: @@ -91,7 +88,6 @@ def sync_episodes(series_id=None, send_event=True): removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr)) for removed_episode in removed_episodes: - sleep() episode_to_delete = TableEpisodes.select(TableEpisodes.sonarrSeriesId, TableEpisodes.sonarrEpisodeId)\ .where(TableEpisodes.sonarrEpisodeId == removed_episode)\ .dicts()\ @@ -124,7 +120,6 @@ def sync_episodes(series_id=None, send_event=True): episodes_to_update_list = [i for i in episodes_to_update if i not in episode_in_db_list] for updated_episode in episodes_to_update_list: - sleep() TableEpisodes.update(updated_episode).where(TableEpisodes.sonarrEpisodeId == updated_episode['sonarrEpisodeId']).execute() altered_episodes.append([updated_episode['sonarrEpisodeId'], @@ -133,7 +128,6 @@ def sync_episodes(series_id=None, send_event=True): # Insert new episodes in DB for added_episode in episodes_to_add: - sleep() result = TableEpisodes.insert(added_episode).on_conflict(action='IGNORE').execute() if result > 0: altered_episodes.append([added_episode['sonarrEpisodeId'], @@ -147,7 +141,6 @@ def sync_episodes(series_id=None, send_event=True): # Store subtitles for added or modified episodes for i, altered_episode in enumerate(altered_episodes, 1): - sleep() store_subtitles(altered_episode[1], path_mappings.path_replace(altered_episode[1])) logging.debug('BAZARR All episodes synced from Sonarr into database.') diff --git a/bazarr/get_movies.py 
b/bazarr/get_movies.py index 8749ff6fd..0ae5b17c2 100644 --- a/bazarr/get_movies.py +++ b/bazarr/get_movies.py @@ -5,7 +5,6 @@ import requests import logging import operator from functools import reduce -from gevent import sleep from peewee import DoesNotExist from config import settings, url_radarr @@ -64,7 +63,6 @@ def update_movies(send_event=True): # Build new and updated movies movies_count = len(movies) for i, movie in enumerate(movies): - sleep() if send_event: show_progress(id='movies_progress', header='Syncing movies...', @@ -96,7 +94,6 @@ def update_movies(send_event=True): removed_movies = list(set(current_movies_db_list) - set(current_movies_radarr)) for removed_movie in removed_movies: - sleep() TableMovies.delete().where(TableMovies.tmdbId == removed_movie).execute() # Update movies in DB @@ -129,7 +126,6 @@ def update_movies(send_event=True): movies_to_update_list = [i for i in movies_to_update if i not in movies_in_db_list] for updated_movie in movies_to_update_list: - sleep() TableMovies.update(updated_movie).where(TableMovies.tmdbId == updated_movie['tmdbId']).execute() altered_movies.append([updated_movie['tmdbId'], updated_movie['path'], @@ -138,7 +134,6 @@ def update_movies(send_event=True): # Insert new movies in DB for added_movie in movies_to_add: - sleep() result = TableMovies.insert(added_movie).on_conflict(action='IGNORE').execute() if result > 0: altered_movies.append([added_movie['tmdbId'], @@ -153,7 +148,6 @@ def update_movies(send_event=True): # Store subtitles for added or modified movies for i, altered_movie in enumerate(altered_movies, 1): - sleep() store_subtitles_movie(altered_movie[1], path_mappings.path_replace_movie(altered_movie[1])) logging.debug('BAZARR All movies synced from Radarr into database.') diff --git a/bazarr/get_series.py b/bazarr/get_series.py index 45b0941f6..c92af286b 100644 --- a/bazarr/get_series.py +++ b/bazarr/get_series.py @@ -3,7 +3,6 @@ import os import requests import logging -from gevent import sleep from peewee import DoesNotExist from config import settings, url_sonarr @@ -51,7 +50,6 @@ def update_series(send_event=True): series_count = len(series) for i, show in enumerate(series): - sleep() if send_event: show_progress(id='series_progress', header='Syncing series...', @@ -78,7 +76,6 @@ def update_series(send_event=True): removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr)) for series in removed_series: - sleep() TableShows.delete().where(TableShows.sonarrSeriesId == series).execute() if send_event: event_stream(type='series', action='delete', payload=series) @@ -106,7 +103,6 @@ def update_series(send_event=True): series_to_update_list = [i for i in series_to_update if i not in series_in_db_list] for updated_series in series_to_update_list: - sleep() TableShows.update(updated_series).where(TableShows.sonarrSeriesId == updated_series['sonarrSeriesId']).execute() if send_event: @@ -114,7 +110,6 @@ def update_series(send_event=True): # Insert new series in DB for added_series in series_to_add: - sleep() result = TableShows.insert(added_series).on_conflict(action='IGNORE').execute() if result: list_missing_subtitles(no=added_series['sonarrSeriesId']) diff --git a/bazarr/init.py b/bazarr/init.py index 6560ad911..f90c9947e 100644 --- a/bazarr/init.py +++ b/bazarr/init.py @@ -54,7 +54,7 @@ def is_virtualenv(): # deploy requirements.txt if not args.no_update: try: - import lxml, numpy, webrtcvad, gevent, geventwebsocket, setuptools + import lxml, numpy, webrtcvad, setuptools except ImportError: try: import 
pip diff --git a/bazarr/list_subtitles.py b/bazarr/list_subtitles.py index 10d1a87c2..aa81b025e 100644 --- a/bazarr/list_subtitles.py +++ b/bazarr/list_subtitles.py @@ -8,7 +8,6 @@ import re from guess_language import guess_language from subliminal_patch import core, search_external_subtitles from subzero.language import Language -from gevent import sleep from custom_lang import CustomLanguage from database import get_profiles_list, get_profile_cutoff, TableEpisodes, TableShows, TableMovies @@ -237,7 +236,6 @@ def list_missing_subtitles(no=None, epno=None, send_event=True): use_embedded_subs = settings.general.getboolean('use_embedded_subs') for episode_subtitles in episodes_subtitles: - sleep() missing_subtitles_text = '[]' if episode_subtitles['profileId']: # get desired subtitles @@ -348,7 +346,6 @@ def list_missing_subtitles_movies(no=None, send_event=True): use_embedded_subs = settings.general.getboolean('use_embedded_subs') for movie_subtitles in movies_subtitles: - sleep() missing_subtitles_text = '[]' if movie_subtitles['profileId']: # get desired subtitles @@ -450,7 +447,6 @@ def series_full_scan_subtitles(): count_episodes = len(episodes) for i, episode in enumerate(episodes): - sleep() show_progress(id='episodes_disk_scan', header='Full disk scan...', name='Episodes subtitles', @@ -470,7 +466,6 @@ def movies_full_scan_subtitles(): count_movies = len(movies) for i, movie in enumerate(movies): - sleep() show_progress(id='movies_disk_scan', header='Full disk scan...', name='Movies subtitles', @@ -491,7 +486,6 @@ def series_scan_subtitles(no): .dicts() for episode in episodes: - sleep() store_subtitles(episode['path'], path_mappings.path_replace(episode['path']), use_cache=False) @@ -502,7 +496,6 @@ def movies_scan_subtitles(no): .dicts() for movie in movies: - sleep() store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']), use_cache=False) diff --git a/bazarr/logger.py b/bazarr/logger.py index 44812a90f..1893bb534 100644 --- a/bazarr/logger.py +++ b/bazarr/logger.py @@ -117,9 +117,8 @@ def configure_logging(debug=False): logging.getLogger("srt").setLevel(logging.ERROR) logging.getLogger("SignalRCoreClient").setLevel(logging.CRITICAL) logging.getLogger("websocket").setLevel(logging.CRITICAL) - logging.getLogger("geventwebsocket.handler").setLevel(logging.WARNING) - logging.getLogger("geventwebsocket.handler").setLevel(logging.WARNING) + logging.getLogger("werkzeug").setLevel(logging.WARNING) logging.getLogger("engineio.server").setLevel(logging.WARNING) logging.getLogger("knowit").setLevel(logging.CRITICAL) logging.getLogger("enzyme").setLevel(logging.CRITICAL) diff --git a/bazarr/main.py b/bazarr/main.py index 92b077817..9b84ed0c8 100644 --- a/bazarr/main.py +++ b/bazarr/main.py @@ -1,13 +1,5 @@ # coding=utf-8 -# Gevent monkey patch if gevent available. If not, it will be installed on during the init process. 
-try: - from gevent import monkey, Greenlet, joinall -except ImportError: - pass -else: - monkey.patch_all() - import os bazarr_version = 'unknown' @@ -34,6 +26,7 @@ from urllib.parse import unquote from get_languages import load_language_in_db from flask import make_response, request, redirect, abort, render_template, Response, session, flash, url_for, \ send_file, stream_with_context +from threading import Thread from get_series import * from get_episodes import * @@ -202,11 +195,10 @@ def proxy(protocol, url): return dict(status=False, error=result.raise_for_status()) -greenlets = [] if settings.general.getboolean('use_sonarr'): - greenlets.append(Greenlet.spawn(sonarr_signalr_client.start)) + Thread(target=sonarr_signalr_client.start).start() if settings.general.getboolean('use_radarr'): - greenlets.append(Greenlet.spawn(radarr_signalr_client.start)) + Thread(target=radarr_signalr_client.start).start() if __name__ == "__main__": diff --git a/bazarr/scheduler.py b/bazarr/scheduler.py index 36d18191d..8cde8262c 100644 --- a/bazarr/scheduler.py +++ b/bazarr/scheduler.py @@ -12,7 +12,7 @@ if not args.no_update: from check_update import check_if_new_update, check_releases else: from check_update import check_releases -from apscheduler.schedulers.gevent import GeventScheduler +from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.interval import IntervalTrigger from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.date import DateTrigger @@ -30,7 +30,7 @@ class Scheduler: def __init__(self): self.__running_tasks = [] - self.aps_scheduler = GeventScheduler() + self.aps_scheduler = BackgroundScheduler() # task listener def task_listener_add(event): diff --git a/bazarr/server.py b/bazarr/server.py index 1a9053ee2..b414d8ee6 100644 --- a/bazarr/server.py +++ b/bazarr/server.py @@ -4,8 +4,7 @@ import warnings import logging import os import io -from gevent import pywsgi -from geventwebsocket.handler import WebSocketHandler +from waitress.server import create_server from get_args import args from config import settings, base_url @@ -27,23 +26,23 @@ class Server: # Mute Python3 BrokenPipeError warnings.simplefilter("ignore", BrokenPipeError) - self.server = pywsgi.WSGIServer((str(settings.general.ip), - int(args.port) if args.port else int(settings.general.port)), - app, - handler_class=WebSocketHandler) + self.server = create_server(app, + host=str(settings.general.ip), + port=int(args.port) if args.port else int(settings.general.port), + threads=100) def start(self): try: logging.info( 'BAZARR is started and waiting for request on http://' + str(settings.general.ip) + ':' + (str( args.port) if args.port else str(settings.general.port)) + str(base_url)) - self.server.serve_forever() + self.server.run() except KeyboardInterrupt: self.shutdown() def shutdown(self): try: - self.server.stop() + self.server.close() except Exception as e: logging.error('BAZARR Cannot stop Waitress: ' + repr(e)) else: @@ -60,7 +59,7 @@ class Server: def restart(self): try: - self.server.stop() + self.server.close() except Exception as e: logging.error('BAZARR Cannot stop Waitress: ' + repr(e)) else: diff --git a/bazarr/signalr_client.py b/bazarr/signalr_client.py index 9968bacfb..f8c3c7e8f 100644 --- a/bazarr/signalr_client.py +++ b/bazarr/signalr_client.py @@ -2,9 +2,9 @@ import logging -import gevent import json import os +import time from requests import Session from signalr import Connection from requests.exceptions import ConnectionError @@ -36,7 +36,6 @@ 
class SonarrSignalrClient: if get_sonarr_info.is_legacy(): logging.warning('BAZARR can only sync from Sonarr v3 SignalR feed to get real-time update. You should ' 'consider upgrading your version({}).'.format(get_sonarr_info.version())) - raise gevent.GreenletExit else: logging.info('BAZARR trying to connect to Sonarr SignalR feed...') self.configure() @@ -44,14 +43,13 @@ class SonarrSignalrClient: try: self.connection.start() except ConnectionError: - gevent.sleep(5) + time.sleep(5) except json.decoder.JSONDecodeError: logging.error("BAZARR cannot parse JSON returned by SignalR feed. This is caused by a permissions " "issue when Sonarr try to access its /config/.config directory. You should fix " "permissions on that directory and restart Sonarr. Also, if you're a Docker image " "user, you should make sure you properly defined PUID/PGID environment variables. " "Otherwise, please contact Sonarr support.") - raise gevent.GreenletExit else: logging.info('BAZARR SignalR client for Sonarr is connected and waiting for events.') finally: @@ -107,7 +105,7 @@ class RadarrSignalrClient: try: self.connection.start() except ConnectionError: - gevent.sleep(5) + time.sleep(5) def stop(self): logging.info('BAZARR SignalR client for Radarr is now disconnected.') diff --git a/libs/signalr/__init__.py b/libs/signalr/__init__.py index 3d155c5c6..7742eeb58 100644 --- a/libs/signalr/__init__.py +++ b/libs/signalr/__init__.py @@ -1,8 +1,3 @@ -from gevent import monkey - -monkey.patch_socket() -monkey.patch_ssl() - from ._connection import Connection -__version__ = '0.0.7' +__version__ = '0.0.12' diff --git a/libs/signalr/_connection.py b/libs/signalr/_connection.py index 377606f99..6471ba670 100644 --- a/libs/signalr/_connection.py +++ b/libs/signalr/_connection.py @@ -1,6 +1,6 @@ import json -import gevent import sys +from threading import Thread from signalr.events import EventHook from signalr.hubs import Hub from signalr.transports import AutoTransport @@ -15,14 +15,16 @@ class Connection: self.qs = {} self.__send_counter = -1 self.token = None + self.id = None self.data = None self.received = EventHook() self.error = EventHook() self.starting = EventHook() self.stopping = EventHook() self.exception = EventHook() + self.is_open = False self.__transport = AutoTransport(session, self) - self.__greenlet = None + self.__listener_thread = None self.started = False def handle_error(**kwargs): @@ -48,27 +50,32 @@ class Connection: negotiate_data = self.__transport.negotiate() self.token = negotiate_data['ConnectionToken'] + self.id = negotiate_data['ConnectionId'] listener = self.__transport.start() def wrapped_listener(): - try: - listener() - gevent.sleep() - except: - self.exception.fire(*sys.exc_info()) - - self.__greenlet = gevent.spawn(wrapped_listener) + while self.is_open: + try: + listener() + except: + self.exception.fire(*sys.exc_info()) + self.is_open = False + + self.is_open = True + self.__listener_thread = Thread(target=wrapped_listener) + self.__listener_thread.start() self.started = True def wait(self, timeout=30): - gevent.joinall([self.__greenlet], timeout) + Thread.join(self.__listener_thread, timeout) def send(self, data): self.__transport.send(data) def close(self): - gevent.kill(self.__greenlet) + self.is_open = False + self.__listener_thread.join() self.__transport.close() def register_hub(self, name): diff --git a/libs/signalr/transports/_sse_transport.py b/libs/signalr/transports/_sse_transport.py index 63d978643..7faaf936a 100644 --- a/libs/signalr/transports/_sse_transport.py +++ 
b/libs/signalr/transports/_sse_transport.py @@ -12,11 +12,16 @@ class ServerSentEventsTransport(Transport): return 'serverSentEvents' def start(self): - self.__response = sseclient.SSEClient(self._get_url('connect'), session=self._session) + connect_url = self._get_url('connect') + self.__response = iter(sseclient.SSEClient(connect_url, session=self._session)) self._session.get(self._get_url('start')) def _receive(): - for notification in self.__response: + try: + notification = next(self.__response) + except StopIteration: + return + else: if notification.data != 'initialized': self._handle_notification(notification.data) diff --git a/libs/signalr/transports/_transport.py b/libs/signalr/transports/_transport.py index c0d0d4278..af62672fd 100644 --- a/libs/signalr/transports/_transport.py +++ b/libs/signalr/transports/_transport.py @@ -1,13 +1,12 @@ from abc import abstractmethod import json import sys - +import threading if sys.version_info[0] < 3: from urllib import quote_plus else: from urllib.parse import quote_plus -import gevent class Transport: @@ -48,7 +47,7 @@ class Transport: if len(message) > 0: data = json.loads(message) self._connection.received.fire(**data) - gevent.sleep() + #thread.sleep() #TODO: investigate if we should sleep here def _get_url(self, action, **kwargs): args = kwargs.copy() diff --git a/libs/signalr/transports/_ws_transport.py b/libs/signalr/transports/_ws_transport.py index 14fefa6cc..4d9a80ad1 100644 --- a/libs/signalr/transports/_ws_transport.py +++ b/libs/signalr/transports/_ws_transport.py @@ -1,7 +1,6 @@ import json import sys -import gevent if sys.version_info[0] < 3: from urlparse import urlparse, urlunparse @@ -39,14 +38,14 @@ class WebSocketsTransport(Transport): self._session.get(self._get_url('start')) def _receive(): - for notification in self.ws: - self._handle_notification(notification) + notification = self.ws.recv() + self._handle_notification(notification) return _receive def send(self, data): self.ws.send(json.dumps(data)) - gevent.sleep() + #thread.sleep() #TODO: inveistage if we should sleep here or not def close(self): self.ws.close() diff --git a/libs/version.txt b/libs/version.txt index c3fc11bff..51a2ab472 100644 --- a/libs/version.txt +++ b/libs/version.txt @@ -44,6 +44,7 @@ subliminal=2.1.0dev tzlocal=2.1b1 twine=3.4.1 urllib3=1.23 +waitress=2.0.0 websocket-client=1.0.0 ## indirect dependencies diff --git a/libs/waitress/__init__.py b/libs/waitress/__init__.py new file mode 100644 index 000000000..bbb99da03 --- /dev/null +++ b/libs/waitress/__init__.py @@ -0,0 +1,46 @@ +import logging + +from waitress.server import create_server + + +def serve(app, **kw): + _server = kw.pop("_server", create_server) # test shim + _quiet = kw.pop("_quiet", False) # test shim + _profile = kw.pop("_profile", False) # test shim + if not _quiet: # pragma: no cover + # idempotent if logging has already been set up + logging.basicConfig() + server = _server(app, **kw) + if not _quiet: # pragma: no cover + server.print_listen("Serving on http://{}:{}") + if _profile: # pragma: no cover + profile("server.run()", globals(), locals(), (), False) + else: + server.run() + + +def serve_paste(app, global_conf, **kw): + serve(app, **kw) + return 0 + + +def profile(cmd, globals, locals, sort_order, callers): # pragma: no cover + # runs a command under the profiler and print profiling output at shutdown + import os + import profile + import pstats + import tempfile + + fd, fn = tempfile.mkstemp() + try: + profile.runctx(cmd, globals, locals, fn) + stats = 
pstats.Stats(fn) + stats.strip_dirs() + # calls,time,cumulative and cumulative,calls,time are useful + stats.sort_stats(*(sort_order or ("cumulative", "calls", "time"))) + if callers: + stats.print_callers(0.3) + else: + stats.print_stats(0.3) + finally: + os.remove(fn) diff --git a/libs/waitress/__main__.py b/libs/waitress/__main__.py new file mode 100644 index 000000000..9bcd07e59 --- /dev/null +++ b/libs/waitress/__main__.py @@ -0,0 +1,3 @@ +from waitress.runner import run # pragma nocover + +run() # pragma nocover diff --git a/libs/waitress/adjustments.py b/libs/waitress/adjustments.py new file mode 100644 index 000000000..466b5c4a9 --- /dev/null +++ b/libs/waitress/adjustments.py @@ -0,0 +1,523 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Adjustments are tunable parameters. +""" +import getopt +import socket +import warnings + +from .compat import HAS_IPV6, WIN +from .proxy_headers import PROXY_HEADERS + +truthy = frozenset(("t", "true", "y", "yes", "on", "1")) + +KNOWN_PROXY_HEADERS = frozenset( + header.lower().replace("_", "-") for header in PROXY_HEADERS +) + + +def asbool(s): + """Return the boolean value ``True`` if the case-lowered value of string + input ``s`` is any of ``t``, ``true``, ``y``, ``on``, or ``1``, otherwise + return the boolean value ``False``. If ``s`` is the value ``None``, + return ``False``. 
If ``s`` is already one of the boolean values ``True`` + or ``False``, return it.""" + if s is None: + return False + if isinstance(s, bool): + return s + s = str(s).strip() + return s.lower() in truthy + + +def asoctal(s): + """Convert the given octal string to an actual number.""" + return int(s, 8) + + +def aslist_cronly(value): + if isinstance(value, str): + value = filter(None, [x.strip() for x in value.splitlines()]) + return list(value) + + +def aslist(value): + """Return a list of strings, separating the input based on newlines + and, if flatten=True (the default), also split on spaces within + each line.""" + values = aslist_cronly(value) + result = [] + for value in values: + subvalues = value.split() + result.extend(subvalues) + return result + + +def asset(value): + return set(aslist(value)) + + +def slash_fixed_str(s): + s = s.strip() + if s: + # always have a leading slash, replace any number of leading slashes + # with a single slash, and strip any trailing slashes + s = "/" + s.lstrip("/").rstrip("/") + return s + + +def str_iftruthy(s): + return str(s) if s else None + + +def as_socket_list(sockets): + """Checks if the elements in the list are of type socket and + removes them if not.""" + return [sock for sock in sockets if isinstance(sock, socket.socket)] + + +class _str_marker(str): + pass + + +class _int_marker(int): + pass + + +class _bool_marker: + pass + + +class Adjustments: + """This class contains tunable parameters.""" + + _params = ( + ("host", str), + ("port", int), + ("ipv4", asbool), + ("ipv6", asbool), + ("listen", aslist), + ("threads", int), + ("trusted_proxy", str_iftruthy), + ("trusted_proxy_count", int), + ("trusted_proxy_headers", asset), + ("log_untrusted_proxy_headers", asbool), + ("clear_untrusted_proxy_headers", asbool), + ("url_scheme", str), + ("url_prefix", slash_fixed_str), + ("backlog", int), + ("recv_bytes", int), + ("send_bytes", int), + ("outbuf_overflow", int), + ("outbuf_high_watermark", int), + ("inbuf_overflow", int), + ("connection_limit", int), + ("cleanup_interval", int), + ("channel_timeout", int), + ("log_socket_errors", asbool), + ("max_request_header_size", int), + ("max_request_body_size", int), + ("expose_tracebacks", asbool), + ("ident", str_iftruthy), + ("asyncore_loop_timeout", int), + ("asyncore_use_poll", asbool), + ("unix_socket", str), + ("unix_socket_perms", asoctal), + ("sockets", as_socket_list), + ("channel_request_lookahead", int), + ("server_name", str), + ) + + _param_map = dict(_params) + + # hostname or IP address to listen on + host = _str_marker("0.0.0.0") + + # TCP port to listen on + port = _int_marker(8080) + + listen = ["{}:{}".format(host, port)] + + # number of threads available for tasks + threads = 4 + + # Host allowed to overrid ``wsgi.url_scheme`` via header + trusted_proxy = None + + # How many proxies we trust when chained + # + # X-Forwarded-For: 192.0.2.1, "[2001:db8::1]" + # + # or + # + # Forwarded: for=192.0.2.1, For="[2001:db8::1]" + # + # means there were (potentially), two proxies involved. If we know there is + # only 1 valid proxy, then that initial IP address "192.0.2.1" is not + # trusted and we completely ignore it. If there are two trusted proxies in + # the path, this value should be set to a higher number. + trusted_proxy_count = None + + # Which of the proxy headers should we trust, this is a set where you + # either specify forwarded or one or more of forwarded-host, forwarded-for, + # forwarded-proto, forwarded-port. 
+ trusted_proxy_headers = set() + + # Would you like waitress to log warnings about untrusted proxy headers + # that were encountered while processing the proxy headers? This only makes + # sense to set when you have a trusted_proxy, and you expect the upstream + # proxy server to filter invalid headers + log_untrusted_proxy_headers = False + + # Should waitress clear any proxy headers that are not deemed trusted from + # the environ? Change to True by default in 2.x + clear_untrusted_proxy_headers = _bool_marker + + # default ``wsgi.url_scheme`` value + url_scheme = "http" + + # default ``SCRIPT_NAME`` value, also helps reset ``PATH_INFO`` + # when nonempty + url_prefix = "" + + # server identity (sent in Server: header) + ident = "waitress" + + # backlog is the value waitress passes to pass to socket.listen() This is + # the maximum number of incoming TCP connections that will wait in an OS + # queue for an available channel. From listen(1): "If a connection + # request arrives when the queue is full, the client may receive an error + # with an indication of ECONNREFUSED or, if the underlying protocol + # supports retransmission, the request may be ignored so that a later + # reattempt at connection succeeds." + backlog = 1024 + + # recv_bytes is the argument to pass to socket.recv(). + recv_bytes = 8192 + + # deprecated setting controls how many bytes will be buffered before + # being flushed to the socket + send_bytes = 1 + + # A tempfile should be created if the pending output is larger than + # outbuf_overflow, which is measured in bytes. The default is 1MB. This + # is conservative. + outbuf_overflow = 1048576 + + # The app_iter will pause when pending output is larger than this value + # in bytes. + outbuf_high_watermark = 16777216 + + # A tempfile should be created if the pending input is larger than + # inbuf_overflow, which is measured in bytes. The default is 512K. This + # is conservative. + inbuf_overflow = 524288 + + # Stop creating new channels if too many are already active (integer). + # Each channel consumes at least one file descriptor, and, depending on + # the input and output body sizes, potentially up to three. The default + # is conservative, but you may need to increase the number of file + # descriptors available to the Waitress process on most platforms in + # order to safely change it (see ``ulimit -a`` "open files" setting). + # Note that this doesn't control the maximum number of TCP connections + # that can be waiting for processing; the ``backlog`` argument controls + # that. + connection_limit = 100 + + # Minimum seconds between cleaning up inactive channels. + cleanup_interval = 30 + + # Maximum seconds to leave an inactive connection open. + channel_timeout = 120 + + # Boolean: turn off to not log premature client disconnects. + log_socket_errors = True + + # maximum number of bytes of all request headers combined (256K default) + max_request_header_size = 262144 + + # maximum number of bytes in request body (1GB default) + max_request_body_size = 1073741824 + + # expose tracebacks of uncaught exceptions + expose_tracebacks = False + + # Path to a Unix domain socket to use. + unix_socket = None + + # Path to a Unix domain socket to use. + unix_socket_perms = 0o600 + + # The socket options to set on receiving a connection. It is a list of + # (level, optname, value) tuples. TCP_NODELAY disables the Nagle + # algorithm for writes (Waitress already buffers its writes). 
+ socket_options = [ + (socket.SOL_TCP, socket.TCP_NODELAY, 1), + ] + + # The asyncore.loop timeout value + asyncore_loop_timeout = 1 + + # The asyncore.loop flag to use poll() instead of the default select(). + asyncore_use_poll = False + + # Enable IPv4 by default + ipv4 = True + + # Enable IPv6 by default + ipv6 = True + + # A list of sockets that waitress will use to accept connections. They can + # be used for e.g. socket activation + sockets = [] + + # By setting this to a value larger than zero, each channel stays readable + # and continues to read requests from the client even if a request is still + # running, until the number of buffered requests exceeds this value. + # This allows detecting if a client closed the connection while its request + # is being processed. + channel_request_lookahead = 0 + + # This setting controls the SERVER_NAME of the WSGI environment, this is + # only ever used if the remote client sent a request without a Host header + # (or when using the Proxy settings, without forwarding a Host header) + server_name = "waitress.invalid" + + def __init__(self, **kw): + + if "listen" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if listen is set.") + + if "listen" in kw and "sockets" in kw: + raise ValueError("socket may not be set if listen is set.") + + if "sockets" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if sockets is set.") + + if "sockets" in kw and "unix_socket" in kw: + raise ValueError("unix_socket may not be set if sockets is set") + + if "unix_socket" in kw and ("host" in kw or "port" in kw): + raise ValueError("unix_socket may not be set if host or port is set") + + if "unix_socket" in kw and "listen" in kw: + raise ValueError("unix_socket may not be set if listen is set") + + if "send_bytes" in kw: + warnings.warn( + "send_bytes will be removed in a future release", DeprecationWarning + ) + + for k, v in kw.items(): + if k not in self._param_map: + raise ValueError("Unknown adjustment %r" % k) + setattr(self, k, self._param_map[k](v)) + + if not isinstance(self.host, _str_marker) or not isinstance( + self.port, _int_marker + ): + self.listen = ["{}:{}".format(self.host, self.port)] + + enabled_families = socket.AF_UNSPEC + + if not self.ipv4 and not HAS_IPV6: # pragma: no cover + raise ValueError( + "IPv4 is disabled but IPv6 is not available. Cowardly refusing to start." 
+ ) + + if self.ipv4 and not self.ipv6: + enabled_families = socket.AF_INET + + if not self.ipv4 and self.ipv6 and HAS_IPV6: + enabled_families = socket.AF_INET6 + + wanted_sockets = [] + hp_pairs = [] + for i in self.listen: + if ":" in i: + (host, port) = i.rsplit(":", 1) + + # IPv6 we need to make sure that we didn't split on the address + if "]" in port: # pragma: nocover + (host, port) = (i, str(self.port)) + else: + (host, port) = (i, str(self.port)) + + if WIN: # pragma: no cover + try: + # Try turning the port into an integer + port = int(port) + + except Exception: + raise ValueError( + "Windows does not support service names instead of port numbers" + ) + + try: + if "[" in host and "]" in host: # pragma: nocover + host = host.strip("[").rstrip("]") + + if host == "*": + host = None + + for s in socket.getaddrinfo( + host, + port, + enabled_families, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE, + ): + (family, socktype, proto, _, sockaddr) = s + + # It seems that getaddrinfo() may sometimes happily return + # the same result multiple times, this of course makes + # bind() very unhappy... + # + # Split on %, and drop the zone-index from the host in the + # sockaddr. Works around a bug in OS X whereby + # getaddrinfo() returns the same link-local interface with + # two different zone-indices (which makes no sense what so + # ever...) yet treats them equally when we attempt to bind(). + if ( + sockaddr[1] == 0 + or (sockaddr[0].split("%", 1)[0], sockaddr[1]) not in hp_pairs + ): + wanted_sockets.append((family, socktype, proto, sockaddr)) + hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) + + except Exception: + raise ValueError("Invalid host/port specified.") + + if self.trusted_proxy_count is not None and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_count has no meaning without setting " "trusted_proxy" + ) + + elif self.trusted_proxy_count is None: + self.trusted_proxy_count = 1 + + if self.trusted_proxy_headers and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_headers has no meaning without setting " "trusted_proxy" + ) + + if self.trusted_proxy_headers: + self.trusted_proxy_headers = { + header.lower() for header in self.trusted_proxy_headers + } + + unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS + if unknown_values: + raise ValueError( + "Received unknown trusted_proxy_headers value (%s) expected one " + "of %s" + % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) + ) + + if ( + "forwarded" in self.trusted_proxy_headers + and self.trusted_proxy_headers - {"forwarded"} + ): + raise ValueError( + "The Forwarded proxy header and the " + "X-Forwarded-{By,Host,Proto,Port,For} headers are mutually " + "exclusive. Can't trust both!" + ) + + elif self.trusted_proxy is not None: + warnings.warn( + "No proxy headers were marked as trusted, but trusted_proxy was set. " + "Implicitly trusting X-Forwarded-Proto for backwards compatibility. " + "This will be removed in future versions of waitress.", + DeprecationWarning, + ) + self.trusted_proxy_headers = {"x-forwarded-proto"} + + if self.clear_untrusted_proxy_headers is _bool_marker: + warnings.warn( + "In future versions of Waitress clear_untrusted_proxy_headers will be " + "set to True by default. 
You may opt-out by setting this value to " + "False, or opt-in explicitly by setting this to True.", + DeprecationWarning, + ) + self.clear_untrusted_proxy_headers = False + + self.listen = wanted_sockets + + self.check_sockets(self.sockets) + + @classmethod + def parse_args(cls, argv): + """Pre-parse command line arguments for input into __init__. Note that + this does not cast values into adjustment types, it just creates a + dictionary suitable for passing into __init__, where __init__ does the + casting. + """ + long_opts = ["help", "call"] + for opt, cast in cls._params: + opt = opt.replace("_", "-") + if cast is asbool: + long_opts.append(opt) + long_opts.append("no-" + opt) + else: + long_opts.append(opt + "=") + + kw = { + "help": False, + "call": False, + } + + opts, args = getopt.getopt(argv, "", long_opts) + for opt, value in opts: + param = opt.lstrip("-").replace("-", "_") + + if param == "listen": + kw["listen"] = "{} {}".format(kw.get("listen", ""), value) + continue + + if param.startswith("no_"): + param = param[3:] + kw[param] = "false" + elif param in ("help", "call"): + kw[param] = True + elif cls._param_map[param] is asbool: + kw[param] = "true" + else: + kw[param] = value + + return kw, args + + @classmethod + def check_sockets(cls, sockets): + has_unix_socket = False + has_inet_socket = False + has_unsupported_socket = False + for sock in sockets: + if ( + sock.family == socket.AF_INET or sock.family == socket.AF_INET6 + ) and sock.type == socket.SOCK_STREAM: + has_inet_socket = True + elif ( + hasattr(socket, "AF_UNIX") + and sock.family == socket.AF_UNIX + and sock.type == socket.SOCK_STREAM + ): + has_unix_socket = True + else: + has_unsupported_socket = True + if has_unix_socket and has_inet_socket: + raise ValueError("Internet and UNIX sockets may not be mixed.") + if has_unsupported_socket: + raise ValueError("Only Internet or UNIX stream sockets may be used.") diff --git a/libs/waitress/buffers.py b/libs/waitress/buffers.py new file mode 100644 index 000000000..0086fe8f3 --- /dev/null +++ b/libs/waitress/buffers.py @@ -0,0 +1,308 @@ +############################################################################## +# +# Copyright (c) 2001-2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Buffers +""" +from io import BytesIO + +# copy_bytes controls the size of temp. strings for shuffling data around. +COPY_BYTES = 1 << 18 # 256K + +# The maximum number of bytes to buffer in a simple string. 
+STRBUF_LIMIT = 8192 + + +class FileBasedBuffer: + + remain = 0 + + def __init__(self, file, from_buffer=None): + self.file = file + if from_buffer is not None: + from_file = from_buffer.getfile() + read_pos = from_file.tell() + from_file.seek(0) + while True: + data = from_file.read(COPY_BYTES) + if not data: + break + file.write(data) + self.remain = int(file.tell() - read_pos) + from_file.seek(read_pos) + file.seek(read_pos) + + def __len__(self): + return self.remain + + def __nonzero__(self): + return True + + __bool__ = __nonzero__ # py3 + + def append(self, s): + file = self.file + read_pos = file.tell() + file.seek(0, 2) + file.write(s) + file.seek(read_pos) + self.remain = self.remain + len(s) + + def get(self, numbytes=-1, skip=False): + file = self.file + if not skip: + read_pos = file.tell() + if numbytes < 0: + # Read all + res = file.read() + else: + res = file.read(numbytes) + if skip: + self.remain -= len(res) + else: + file.seek(read_pos) + return res + + def skip(self, numbytes, allow_prune=0): + if self.remain < numbytes: + raise ValueError( + "Can't skip %d bytes in buffer of %d bytes" % (numbytes, self.remain) + ) + self.file.seek(numbytes, 1) + self.remain = self.remain - numbytes + + def newfile(self): + raise NotImplementedError() + + def prune(self): + file = self.file + if self.remain == 0: + read_pos = file.tell() + file.seek(0, 2) + sz = file.tell() + file.seek(read_pos) + if sz == 0: + # Nothing to prune. + return + nf = self.newfile() + while True: + data = file.read(COPY_BYTES) + if not data: + break + nf.write(data) + self.file = nf + + def getfile(self): + return self.file + + def close(self): + if hasattr(self.file, "close"): + self.file.close() + self.remain = 0 + + +class TempfileBasedBuffer(FileBasedBuffer): + def __init__(self, from_buffer=None): + FileBasedBuffer.__init__(self, self.newfile(), from_buffer) + + def newfile(self): + from tempfile import TemporaryFile + + return TemporaryFile("w+b") + + +class BytesIOBasedBuffer(FileBasedBuffer): + def __init__(self, from_buffer=None): + if from_buffer is not None: + FileBasedBuffer.__init__(self, BytesIO(), from_buffer) + else: + # Shortcut. 
:-) + self.file = BytesIO() + + def newfile(self): + return BytesIO() + + +def _is_seekable(fp): + if hasattr(fp, "seekable"): + return fp.seekable() + return hasattr(fp, "seek") and hasattr(fp, "tell") + + +class ReadOnlyFileBasedBuffer(FileBasedBuffer): + # used as wsgi.file_wrapper + + def __init__(self, file, block_size=32768): + self.file = file + self.block_size = block_size # for __iter__ + + def prepare(self, size=None): + if _is_seekable(self.file): + start_pos = self.file.tell() + self.file.seek(0, 2) + end_pos = self.file.tell() + self.file.seek(start_pos) + fsize = end_pos - start_pos + if size is None: + self.remain = fsize + else: + self.remain = min(fsize, size) + return self.remain + + def get(self, numbytes=-1, skip=False): + # never read more than self.remain (it can be user-specified) + if numbytes == -1 or numbytes > self.remain: + numbytes = self.remain + file = self.file + if not skip: + read_pos = file.tell() + res = file.read(numbytes) + if skip: + self.remain -= len(res) + else: + file.seek(read_pos) + return res + + def __iter__(self): # called by task if self.filelike has no seek/tell + return self + + def next(self): + val = self.file.read(self.block_size) + if not val: + raise StopIteration + return val + + __next__ = next # py3 + + def append(self, s): + raise NotImplementedError + + +class OverflowableBuffer: + """ + This buffer implementation has four stages: + - No data + - Bytes-based buffer + - BytesIO-based buffer + - Temporary file storage + The first two stages are fastest for simple transfers. + """ + + overflowed = False + buf = None + strbuf = b"" # Bytes-based buffer. + + def __init__(self, overflow): + # overflow is the maximum to be stored in a StringIO buffer. + self.overflow = overflow + + def __len__(self): + buf = self.buf + if buf is not None: + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + return buf.__len__() + else: + return self.strbuf.__len__() + + def __nonzero__(self): + # use self.__len__ rather than len(self) FBO of not getting + # OverflowError on Python 2 + return self.__len__() > 0 + + __bool__ = __nonzero__ # py3 + + def _create_buffer(self): + strbuf = self.strbuf + if len(strbuf) >= self.overflow: + self._set_large_buffer() + else: + self._set_small_buffer() + buf = self.buf + if strbuf: + buf.append(self.strbuf) + self.strbuf = b"" + return buf + + def _set_small_buffer(self): + self.buf = BytesIOBasedBuffer(self.buf) + self.overflowed = False + + def _set_large_buffer(self): + self.buf = TempfileBasedBuffer(self.buf) + self.overflowed = True + + def append(self, s): + buf = self.buf + if buf is None: + strbuf = self.strbuf + if len(strbuf) + len(s) < STRBUF_LIMIT: + self.strbuf = strbuf + s + return + buf = self._create_buffer() + buf.append(s) + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + sz = buf.__len__() + if not self.overflowed: + if sz >= self.overflow: + self._set_large_buffer() + + def get(self, numbytes=-1, skip=False): + buf = self.buf + if buf is None: + strbuf = self.strbuf + if not skip: + return strbuf + buf = self._create_buffer() + return buf.get(numbytes, skip) + + def skip(self, numbytes, allow_prune=False): + buf = self.buf + if buf is None: + if allow_prune and numbytes == len(self.strbuf): + # We could slice instead of converting to + # a buffer, but that would eat up memory in + # large transfers. 
+ self.strbuf = b"" + return + buf = self._create_buffer() + buf.skip(numbytes, allow_prune) + + def prune(self): + """ + A potentially expensive operation that removes all data + already retrieved from the buffer. + """ + buf = self.buf + if buf is None: + self.strbuf = b"" + return + buf.prune() + if self.overflowed: + # use buf.__len__ rather than len(buf) FBO of not getting + # OverflowError on Python 2 + sz = buf.__len__() + if sz < self.overflow: + # Revert to a faster buffer. + self._set_small_buffer() + + def getfile(self): + buf = self.buf + if buf is None: + buf = self._create_buffer() + return buf.getfile() + + def close(self): + buf = self.buf + if buf is not None: + buf.close() diff --git a/libs/waitress/channel.py b/libs/waitress/channel.py new file mode 100644 index 000000000..296a16aaf --- /dev/null +++ b/libs/waitress/channel.py @@ -0,0 +1,487 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +import socket +import threading +import time +import traceback + +from waitress.buffers import OverflowableBuffer, ReadOnlyFileBasedBuffer +from waitress.parser import HTTPRequestParser +from waitress.task import ErrorTask, WSGITask +from waitress.utilities import InternalServerError + +from . import wasyncore + + +class ClientDisconnected(Exception): + """ Raised when attempting to write to a closed socket.""" + + +class HTTPChannel(wasyncore.dispatcher): + """ + Setting self.requests = [somerequest] prevents more requests from being + received until the out buffers have been flushed. + + Setting self.requests = [] allows more requests to be received. + """ + + task_class = WSGITask + error_task_class = ErrorTask + parser_class = HTTPRequestParser + + # A request that has not been received yet completely is stored here + request = None + last_activity = 0 # Time of last activity + will_close = False # set to True to close the socket. + close_when_flushed = False # set to True to close the socket when flushed + sent_continue = False # used as a latch after sending 100 continue + total_outbufs_len = 0 # total bytes ready to send + current_outbuf_count = 0 # total bytes written to current outbuf + + # + # ASYNCHRONOUS METHODS (including __init__) + # + + def __init__(self, server, sock, addr, adj, map=None): + self.server = server + self.adj = adj + self.outbufs = [OverflowableBuffer(adj.outbuf_overflow)] + self.creation_time = self.last_activity = time.time() + self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) + + # requests_lock used to push/pop requests and modify the request that is + # currently being created + self.requests_lock = threading.Lock() + # outbuf_lock used to access any outbuf (expected to use an RLock) + self.outbuf_lock = threading.Condition() + + wasyncore.dispatcher.__init__(self, sock, map=map) + + # Don't let wasyncore.dispatcher throttle self.addr on us. 
+ self.addr = addr + self.requests = [] + + def check_client_disconnected(self): + """ + This method is inserted into the environment of any created task so it + may occasionally check if the client has disconnected and interrupt + execution. + """ + return not self.connected + + def writable(self): + # if there's data in the out buffer or we've been instructed to close + # the channel (possibly by our server maintenance logic), run + # handle_write + + return self.total_outbufs_len or self.will_close or self.close_when_flushed + + def handle_write(self): + # Precondition: there's data in the out buffer to be sent, or + # there's a pending will_close request + + if not self.connected: + # we dont want to close the channel twice + + return + + # try to flush any pending output + + if not self.requests: + # 1. There are no running tasks, so we don't need to try to lock + # the outbuf before sending + # 2. The data in the out buffer should be sent as soon as possible + # because it's either data left over from task output + # or a 100 Continue line sent within "received". + flush = self._flush_some + elif self.total_outbufs_len >= self.adj.send_bytes: + # 1. There's a running task, so we need to try to lock + # the outbuf before sending + # 2. Only try to send if the data in the out buffer is larger + # than self.adj_bytes to avoid TCP fragmentation + flush = self._flush_some_if_lockable + else: + # 1. There's not enough data in the out buffer to bother to send + # right now. + flush = None + + if flush: + try: + flush() + except OSError: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.will_close = True + except Exception: # pragma: nocover + self.logger.exception("Unexpected exception when flushing") + self.will_close = True + + if self.close_when_flushed and not self.total_outbufs_len: + self.close_when_flushed = False + self.will_close = True + + if self.will_close: + self.handle_close() + + def readable(self): + # We might want to read more requests. We can only do this if: + # 1. We're not already about to close the connection. + # 2. We're not waiting to flush remaining data before closing the + # connection + # 3. There are not too many tasks already queued + # 4. There's no data in the output buffer that needs to be sent + # before we potentially create a new task. + + return not ( + self.will_close + or self.close_when_flushed + or len(self.requests) > self.adj.channel_request_lookahead + or self.total_outbufs_len + ) + + def handle_read(self): + try: + data = self.recv(self.adj.recv_bytes) + except OSError: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.handle_close() + + return + + if data: + self.last_activity = time.time() + self.received(data) + else: + # Client disconnected. + self.connected = False + + def send_continue(self): + """ + Send a 100-Continue header to the client. This is either called from + receive (if no requests are running and the client expects it) or at + the end of service (if no more requests are queued and a request has + been read partially that expects it). 
+ """ + self.request.expect_continue = False + outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n" + num_bytes = len(outbuf_payload) + with self.outbuf_lock: + self.outbufs[-1].append(outbuf_payload) + self.current_outbuf_count += num_bytes + self.total_outbufs_len += num_bytes + self.sent_continue = True + self._flush_some() + self.request.completed = False + + def received(self, data): + """ + Receives input asynchronously and assigns one or more requests to the + channel. + """ + if not data: + return False + + with self.requests_lock: + while data: + if self.request is None: + self.request = self.parser_class(self.adj) + n = self.request.received(data) + + # if there are requests queued, we can not send the continue + # header yet since the responses need to be kept in order + if ( + self.request.expect_continue + and self.request.headers_finished + and not self.requests + and not self.sent_continue + ): + self.send_continue() + + if self.request.completed: + # The request (with the body) is ready to use. + self.sent_continue = False + + if not self.request.empty: + self.requests.append(self.request) + if len(self.requests) == 1: + # self.requests was empty before so the main thread + # is in charge of starting the task. Otherwise, + # service() will add a new task after each request + # has been processed + self.server.add_task(self) + self.request = None + + if n >= len(data): + break + data = data[n:] + + return True + + def _flush_some_if_lockable(self): + # Since our task may be appending to the outbuf, we try to acquire + # the lock, but we don't block if we can't. + + if self.outbuf_lock.acquire(False): + try: + self._flush_some() + + if self.total_outbufs_len < self.adj.outbuf_high_watermark: + self.outbuf_lock.notify() + finally: + self.outbuf_lock.release() + + def _flush_some(self): + # Send as much data as possible to our client + + sent = 0 + dobreak = False + + while True: + outbuf = self.outbufs[0] + # use outbuf.__len__ rather than len(outbuf) FBO of not getting + # OverflowError on 32-bit Python + outbuflen = outbuf.__len__() + + while outbuflen > 0: + chunk = outbuf.get(self.sendbuf_len) + num_sent = self.send(chunk) + + if num_sent: + outbuf.skip(num_sent, True) + outbuflen -= num_sent + sent += num_sent + self.total_outbufs_len -= num_sent + else: + # failed to write anything, break out entirely + dobreak = True + + break + else: + # self.outbufs[-1] must always be a writable outbuf + + if len(self.outbufs) > 1: + toclose = self.outbufs.pop(0) + try: + toclose.close() + except Exception: + self.logger.exception("Unexpected error when closing an outbuf") + else: + # caught up, done flushing for now + dobreak = True + + if dobreak: + break + + if sent: + self.last_activity = time.time() + + return True + + return False + + def handle_close(self): + with self.outbuf_lock: + for outbuf in self.outbufs: + try: + outbuf.close() + except Exception: + self.logger.exception( + "Unknown exception while trying to close outbuf" + ) + self.total_outbufs_len = 0 + self.connected = False + self.outbuf_lock.notify() + wasyncore.dispatcher.close(self) + + def add_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of opened channels. + """ + wasyncore.dispatcher.add_channel(self, map) + self.server.active_channels[self._fileno] = self + + def del_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of closed channels. 
+ """ + fd = self._fileno # next line sets this to None + wasyncore.dispatcher.del_channel(self, map) + ac = self.server.active_channels + + if fd in ac: + del ac[fd] + + # + # SYNCHRONOUS METHODS + # + + def write_soon(self, data): + if not self.connected: + # if the socket is closed then interrupt the task so that it + # can cleanup possibly before the app_iter is exhausted + raise ClientDisconnected + + if data: + # the async mainloop might be popping data off outbuf; we can + # block here waiting for it because we're in a task thread + with self.outbuf_lock: + self._flush_outbufs_below_high_watermark() + + if not self.connected: + raise ClientDisconnected + num_bytes = len(data) + + if data.__class__ is ReadOnlyFileBasedBuffer: + # they used wsgi.file_wrapper + self.outbufs.append(data) + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + else: + if self.current_outbuf_count >= self.adj.outbuf_high_watermark: + # rotate to a new buffer if the current buffer has hit + # the watermark to avoid it growing unbounded + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + self.outbufs[-1].append(data) + self.current_outbuf_count += num_bytes + self.total_outbufs_len += num_bytes + + if self.total_outbufs_len >= self.adj.send_bytes: + self.server.pull_trigger() + + return num_bytes + + return 0 + + def _flush_outbufs_below_high_watermark(self): + # check first to avoid locking if possible + + if self.total_outbufs_len > self.adj.outbuf_high_watermark: + with self.outbuf_lock: + while ( + self.connected + and self.total_outbufs_len > self.adj.outbuf_high_watermark + ): + self.server.pull_trigger() + self.outbuf_lock.wait() + + def service(self): + """Execute one request. 
If there are more, we add another task to the + server at the end.""" + + request = self.requests[0] + + if request.error: + task = self.error_task_class(self, request) + else: + task = self.task_class(self, request) + + try: + if self.connected: + task.service() + else: + task.close_on_finish = True + except ClientDisconnected: + self.logger.info("Client disconnected while serving %s" % task.request.path) + task.close_on_finish = True + except Exception: + self.logger.exception("Exception while serving %s" % task.request.path) + + if not task.wrote_header: + if self.adj.expose_tracebacks: + body = traceback.format_exc() + else: + body = "The server encountered an unexpected internal server error" + req_version = request.version + req_headers = request.headers + err_request = self.parser_class(self.adj) + err_request.error = InternalServerError(body) + # copy some original request attributes to fulfill + # HTTP 1.1 requirements + err_request.version = req_version + try: + err_request.headers["CONNECTION"] = req_headers["CONNECTION"] + except KeyError: + pass + task = self.error_task_class(self, err_request) + try: + task.service() # must not fail + except ClientDisconnected: + task.close_on_finish = True + else: + task.close_on_finish = True + + if task.close_on_finish: + with self.requests_lock: + self.close_when_flushed = True + + for request in self.requests: + request.close() + self.requests = [] + else: + # before processing a new request, ensure there is not too + # much data in the outbufs waiting to be flushed + # NB: currently readable() returns False while we are + # flushing data so we know no new requests will come in + # that we need to account for, otherwise it'd be better + # to do this check at the start of the request instead of + # at the end to account for consecutive service() calls + + if len(self.requests) > 1: + self._flush_outbufs_below_high_watermark() + + # this is a little hacky but basically it's forcing the + # next request to create a new outbuf to avoid sharing + # outbufs across requests which can cause outbufs to + # not be deallocated regularly when a connection is open + # for a long time + + if self.current_outbuf_count > 0: + self.current_outbuf_count = self.adj.outbuf_high_watermark + + request.close() + + # Add new task to process the next request + with self.requests_lock: + self.requests.pop(0) + if self.connected and self.requests: + self.server.add_task(self) + elif ( + self.connected + and self.request is not None + and self.request.expect_continue + and self.request.headers_finished + and not self.sent_continue + ): + # A request waits for a signal to continue, but we could + # not send it until now because requests were being + # processed and the output needs to be kept in order + self.send_continue() + + if self.connected: + self.server.pull_trigger() + + self.last_activity = time.time() + + def cancel(self): + """ Cancels all pending / active requests """ + self.will_close = True + self.connected = False + self.last_activity = time.time() + self.requests = [] diff --git a/libs/waitress/compat.py b/libs/waitress/compat.py new file mode 100644 index 000000000..67543b9ca --- /dev/null +++ b/libs/waitress/compat.py @@ -0,0 +1,29 @@ +import platform + +# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, +# Python on Windows may not define IPPROTO_IPV6 in socket. 
+import socket +import sys +import warnings + +# True if we are running on Windows +WIN = platform.system() == "Windows" + +MAXINT = sys.maxsize +HAS_IPV6 = socket.has_ipv6 + +if hasattr(socket, "IPPROTO_IPV6") and hasattr(socket, "IPV6_V6ONLY"): + IPPROTO_IPV6 = socket.IPPROTO_IPV6 + IPV6_V6ONLY = socket.IPV6_V6ONLY +else: # pragma: no cover + if WIN: + IPPROTO_IPV6 = 41 + IPV6_V6ONLY = 27 + else: + warnings.warn( + "OS does not support required IPv6 socket flags. This is requirement " + "for Waitress. Please open an issue at https://github.com/Pylons/waitress. " + "IPv6 support has been disabled.", + RuntimeWarning, + ) + HAS_IPV6 = False diff --git a/libs/waitress/parser.py b/libs/waitress/parser.py new file mode 100644 index 000000000..3b99921b0 --- /dev/null +++ b/libs/waitress/parser.py @@ -0,0 +1,439 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""HTTP Request Parser + +This server uses asyncore to accept connections and do initial +processing but threads to do work. +""" +from io import BytesIO +import re +from urllib import parse +from urllib.parse import unquote_to_bytes + +from waitress.buffers import OverflowableBuffer +from waitress.receiver import ChunkedReceiver, FixedStreamReceiver +from waitress.utilities import ( + BadRequest, + RequestEntityTooLarge, + RequestHeaderFieldsTooLarge, + ServerNotImplemented, + find_double_newline, +) + +from .rfc7230 import HEADER_FIELD + + +def unquote_bytes_to_wsgi(bytestring): + return unquote_to_bytes(bytestring).decode("latin-1") + + +class ParsingError(Exception): + pass + + +class TransferEncodingNotImplemented(Exception): + pass + + +class HTTPRequestParser: + """A structure that collects the HTTP request. + + Once the stream is completed, the instance is passed to + a server task constructor. + """ + + completed = False # Set once request is completed. + empty = False # Set if no request was made. + expect_continue = False # client sent "Expect: 100-continue" header + headers_finished = False # True when headers have been read + header_plus = b"" + chunked = False + content_length = 0 + header_bytes_received = 0 + body_bytes_received = 0 + body_rcv = None + version = "1.0" + error = None + connection_close = False + + # Other attributes: first_line, header, headers, command, uri, version, + # path, query, fragment + + def __init__(self, adj): + """ + adj is an Adjustments object. + """ + # headers is a mapping containing keys translated to uppercase + # with dashes turned into underscores. + self.headers = {} + self.adj = adj + + def received(self, data): + """ + Receives the HTTP stream for one request. Returns the number of + bytes consumed. Sets the completed flag once both the header and the + body have been received. + """ + + if self.completed: + return 0 # Can't consume any more. + + datalen = len(data) + br = self.body_rcv + + if br is None: + # In header. 
+ max_header = self.adj.max_request_header_size + + s = self.header_plus + data + index = find_double_newline(s) + consumed = 0 + + if index >= 0: + # If the headers have ended, and we also have part of the body + # message in data we still want to validate we aren't going + # over our limit for received headers. + self.header_bytes_received += index + consumed = datalen - (len(s) - index) + else: + self.header_bytes_received += datalen + consumed = datalen + + # If the first line + headers is over the max length, we return a + # RequestHeaderFieldsTooLarge error rather than continuing to + # attempt to parse the headers. + + if self.header_bytes_received >= max_header: + self.parse_header(b"GET / HTTP/1.0\r\n") + self.error = RequestHeaderFieldsTooLarge( + "exceeds max_header of %s" % max_header + ) + self.completed = True + + return consumed + + if index >= 0: + # Header finished. + header_plus = s[:index] + + # Remove preceeding blank lines. This is suggested by + # https://tools.ietf.org/html/rfc7230#section-3.5 to support + # clients sending an extra CR LF after another request when + # using HTTP pipelining + header_plus = header_plus.lstrip() + + if not header_plus: + self.empty = True + self.completed = True + else: + try: + self.parse_header(header_plus) + except ParsingError as e: + self.error = BadRequest(e.args[0]) + self.completed = True + except TransferEncodingNotImplemented as e: + self.error = ServerNotImplemented(e.args[0]) + self.completed = True + else: + if self.body_rcv is None: + # no content-length header and not a t-e: chunked + # request + self.completed = True + + if self.content_length > 0: + max_body = self.adj.max_request_body_size + # we won't accept this request if the content-length + # is too large + + if self.content_length >= max_body: + self.error = RequestEntityTooLarge( + "exceeds max_body of %s" % max_body + ) + self.completed = True + self.headers_finished = True + + return consumed + + # Header not finished yet. + self.header_plus = s + + return datalen + else: + # In body. + consumed = br.received(data) + self.body_bytes_received += consumed + max_body = self.adj.max_request_body_size + + if self.body_bytes_received >= max_body: + # this will only be raised during t-e: chunked requests + self.error = RequestEntityTooLarge("exceeds max_body of %s" % max_body) + self.completed = True + elif br.error: + # garbage in chunked encoding input probably + self.error = br.error + self.completed = True + elif br.completed: + # The request (with the body) is ready to use. + self.completed = True + + if self.chunked: + # We've converted the chunked transfer encoding request + # body into a normal request body, so we know its content + # length; set the header here. We already popped the + # TRANSFER_ENCODING header in parse_header, so this will + # appear to the client to be an entirely non-chunked HTTP + # request with a valid content-length. + self.headers["CONTENT_LENGTH"] = str(br.__len__()) + + return consumed + + def parse_header(self, header_plus): + """ + Parses the header_plus block of text (the headers plus the + first line of the request). 
+ """ + index = header_plus.find(b"\r\n") + + if index >= 0: + first_line = header_plus[:index].rstrip() + header = header_plus[index + 2 :] + else: + raise ParsingError("HTTP message header invalid") + + if b"\r" in first_line or b"\n" in first_line: + raise ParsingError("Bare CR or LF found in HTTP message") + + self.first_line = first_line # for testing + + lines = get_header_lines(header) + + headers = self.headers + + for line in lines: + header = HEADER_FIELD.match(line) + + if not header: + raise ParsingError("Invalid header") + + key, value = header.group("name", "value") + + if b"_" in key: + # TODO(xistence): Should we drop this request instead? + + continue + + # Only strip off whitespace that is considered valid whitespace by + # RFC7230, don't strip the rest + value = value.strip(b" \t") + key1 = key.upper().replace(b"-", b"_").decode("latin-1") + # If a header already exists, we append subsequent values + # separated by a comma. Applications already need to handle + # the comma separated values, as HTTP front ends might do + # the concatenation for you (behavior specified in RFC2616). + try: + headers[key1] += (b", " + value).decode("latin-1") + except KeyError: + headers[key1] = value.decode("latin-1") + + # command, uri, version will be bytes + command, uri, version = crack_first_line(first_line) + version = version.decode("latin-1") + command = command.decode("latin-1") + self.command = command + self.version = version + ( + self.proxy_scheme, + self.proxy_netloc, + self.path, + self.query, + self.fragment, + ) = split_uri(uri) + self.url_scheme = self.adj.url_scheme + connection = headers.get("CONNECTION", "") + + if version == "1.0": + if connection.lower() != "keep-alive": + self.connection_close = True + + if version == "1.1": + # since the server buffers data from chunked transfers and clients + # never need to deal with chunked requests, downstream clients + # should not see the HTTP_TRANSFER_ENCODING header; we pop it + # here + te = headers.pop("TRANSFER_ENCODING", "") + + # NB: We can not just call bare strip() here because it will also + # remove other non-printable characters that we explicitly do not + # want removed so that if someone attempts to smuggle a request + # with these characters we don't fall prey to it. + # + # For example \x85 is stripped by default, but it is not considered + # valid whitespace to be stripped by RFC7230. + encodings = [ + encoding.strip(" \t").lower() for encoding in te.split(",") if encoding + ] + + for encoding in encodings: + # Out of the transfer-codings listed in + # https://tools.ietf.org/html/rfc7230#section-4 we only support + # chunked at this time. + + # Note: the identity transfer-coding was removed in RFC7230: + # https://tools.ietf.org/html/rfc7230#appendix-A.2 and is thus + # not supported + + if encoding not in {"chunked"}: + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." + ) + + if encodings and encodings[-1] == "chunked": + self.chunked = True + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = ChunkedReceiver(buf) + elif encodings: # pragma: nocover + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." 
+ ) + + expect = headers.get("EXPECT", "").lower() + self.expect_continue = expect == "100-continue" + + if connection.lower() == "close": + self.connection_close = True + + if not self.chunked: + try: + cl = int(headers.get("CONTENT_LENGTH", 0)) + except ValueError: + raise ParsingError("Content-Length is invalid") + + self.content_length = cl + + if cl > 0: + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = FixedStreamReceiver(cl, buf) + + def get_body_stream(self): + body_rcv = self.body_rcv + + if body_rcv is not None: + return body_rcv.getfile() + else: + return BytesIO() + + def close(self): + body_rcv = self.body_rcv + + if body_rcv is not None: + body_rcv.getbuf().close() + + +def split_uri(uri): + # urlsplit handles byte input by returning bytes on py3, so + # scheme, netloc, path, query, and fragment are bytes + + scheme = netloc = path = query = fragment = b"" + + # urlsplit below will treat this as a scheme-less netloc, thereby losing + # the original intent of the request. Here we shamelessly stole 4 lines of + # code from the CPython stdlib to parse out the fragment and query but + # leave the path alone. See + # https://github.com/python/cpython/blob/8c9e9b0cd5b24dfbf1424d1f253d02de80e8f5ef/Lib/urllib/parse.py#L465-L468 + # and https://github.com/Pylons/waitress/issues/260 + + if uri[:2] == b"//": + path = uri + + if b"#" in path: + path, fragment = path.split(b"#", 1) + + if b"?" in path: + path, query = path.split(b"?", 1) + else: + try: + scheme, netloc, path, query, fragment = parse.urlsplit(uri) + except UnicodeError: + raise ParsingError("Bad URI") + + return ( + scheme.decode("latin-1"), + netloc.decode("latin-1"), + unquote_bytes_to_wsgi(path), + query.decode("latin-1"), + fragment.decode("latin-1"), + ) + + +def get_header_lines(header): + """ + Splits the header into lines, putting multi-line headers together. + """ + r = [] + lines = header.split(b"\r\n") + + for line in lines: + if not line: + continue + + if b"\r" in line or b"\n" in line: + raise ParsingError( + 'Bare CR or LF found in header line "%s"' % str(line, "latin-1") + ) + + if line.startswith((b" ", b"\t")): + if not r: + # https://corte.si/posts/code/pathod/pythonservers/index.html + raise ParsingError('Malformed header line "%s"' % str(line, "latin-1")) + r[-1] += line + else: + r.append(line) + + return r + + +first_line_re = re.compile( + b"([^ ]+) " + b"((?:[^ :?#]+://[^ ?#/]*(?:[0-9]{1,5})?)?[^ ]+)" + b"(( HTTP/([0-9.]+))$|$)" +) + + +def crack_first_line(line): + m = first_line_re.match(line) + + if m is not None and m.end() == len(line): + if m.group(3): + version = m.group(5) + else: + version = b"" + method = m.group(1) + + # the request methods that are currently defined are all uppercase: + # https://www.iana.org/assignments/http-methods/http-methods.xhtml and + # the request method is case sensitive according to + # https://tools.ietf.org/html/rfc7231#section-4.1 + + # By disallowing anything but uppercase methods we save poor + # unsuspecting souls from sending lowercase HTTP methods to waitress + # and having the request complete, while servers like nginx drop the + # request onto the floor. 
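To see the parser above end to end, here is a short sketch that feeds one complete GET request into HTTPRequestParser; it assumes a default Adjustments() and an arbitrary request line and Host header:

    from waitress.adjustments import Adjustments
    from waitress.parser import HTTPRequestParser

    parser = HTTPRequestParser(Adjustments())
    raw = b"GET /hello?x=1 HTTP/1.1\r\nHost: example.com\r\n\r\n"
    consumed = parser.received(raw)

    # No body and no Transfer-Encoding, so the request completes in a single call.
    assert parser.completed and parser.error is None
    assert parser.command == "GET" and parser.version == "1.1"
    assert parser.path == "/hello" and parser.query == "x=1"
    assert parser.headers["HOST"] == "example.com"

Header names end up uppercased with dashes replaced by underscores, which is the form the task layer later translates into the WSGI environ.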
+ + if method != method.upper(): + raise ParsingError('Malformed HTTP method "%s"' % str(method, "latin-1")) + uri = m.group(2) + + return method, uri, version + else: + return b"", b"", b"" diff --git a/libs/waitress/proxy_headers.py b/libs/waitress/proxy_headers.py new file mode 100644 index 000000000..5d6164670 --- /dev/null +++ b/libs/waitress/proxy_headers.py @@ -0,0 +1,330 @@ +from collections import namedtuple + +from .utilities import BadRequest, logger, undquote + +PROXY_HEADERS = frozenset( + { + "X_FORWARDED_FOR", + "X_FORWARDED_HOST", + "X_FORWARDED_PROTO", + "X_FORWARDED_PORT", + "X_FORWARDED_BY", + "FORWARDED", + } +) + +Forwarded = namedtuple("Forwarded", ["by", "for_", "host", "proto"]) + + +class MalformedProxyHeader(Exception): + def __init__(self, header, reason, value): + self.header = header + self.reason = reason + self.value = value + super().__init__(header, reason, value) + + +def proxy_headers_middleware( + app, + trusted_proxy=None, + trusted_proxy_count=1, + trusted_proxy_headers=None, + clear_untrusted=True, + log_untrusted=False, + logger=logger, +): + def translate_proxy_headers(environ, start_response): + untrusted_headers = PROXY_HEADERS + remote_peer = environ["REMOTE_ADDR"] + if trusted_proxy == "*" or remote_peer == trusted_proxy: + try: + untrusted_headers = parse_proxy_headers( + environ, + trusted_proxy_count=trusted_proxy_count, + trusted_proxy_headers=trusted_proxy_headers, + logger=logger, + ) + except MalformedProxyHeader as ex: + logger.warning( + 'Malformed proxy header "%s" from "%s": %s value: %s', + ex.header, + remote_peer, + ex.reason, + ex.value, + ) + error = BadRequest('Header "{}" malformed.'.format(ex.header)) + return error.wsgi_response(environ, start_response) + + # Clear out the untrusted proxy headers + if clear_untrusted: + clear_untrusted_headers( + environ, untrusted_headers, log_warning=log_untrusted, logger=logger + ) + + return app(environ, start_response) + + return translate_proxy_headers + + +def parse_proxy_headers( + environ, trusted_proxy_count, trusted_proxy_headers, logger=logger +): + if trusted_proxy_headers is None: + trusted_proxy_headers = set() + + forwarded_for = [] + forwarded_host = forwarded_proto = forwarded_port = forwarded = "" + client_addr = None + untrusted_headers = set(PROXY_HEADERS) + + def raise_for_multiple_values(): + raise ValueError("Unspecified behavior for multiple values found in header") + + if "x-forwarded-for" in trusted_proxy_headers and "HTTP_X_FORWARDED_FOR" in environ: + try: + forwarded_for = [] + + for forward_hop in environ["HTTP_X_FORWARDED_FOR"].split(","): + forward_hop = forward_hop.strip() + forward_hop = undquote(forward_hop) + + # Make sure that all IPv6 addresses are surrounded by brackets, + # this is assuming that the IPv6 representation here does not + # include a port number. + + if "." 
not in forward_hop and ( + ":" in forward_hop and forward_hop[-1] != "]" + ): + forwarded_for.append("[{}]".format(forward_hop)) + else: + forwarded_for.append(forward_hop) + + forwarded_for = forwarded_for[-trusted_proxy_count:] + client_addr = forwarded_for[0] + + untrusted_headers.remove("X_FORWARDED_FOR") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-For", str(ex), environ["HTTP_X_FORWARDED_FOR"] + ) + + if ( + "x-forwarded-host" in trusted_proxy_headers + and "HTTP_X_FORWARDED_HOST" in environ + ): + try: + forwarded_host_multiple = [] + + for forward_host in environ["HTTP_X_FORWARDED_HOST"].split(","): + forward_host = forward_host.strip() + forward_host = undquote(forward_host) + forwarded_host_multiple.append(forward_host) + + forwarded_host_multiple = forwarded_host_multiple[-trusted_proxy_count:] + forwarded_host = forwarded_host_multiple[0] + + untrusted_headers.remove("X_FORWARDED_HOST") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Host", str(ex), environ["HTTP_X_FORWARDED_HOST"] + ) + + if "x-forwarded-proto" in trusted_proxy_headers: + try: + forwarded_proto = undquote(environ.get("HTTP_X_FORWARDED_PROTO", "")) + if "," in forwarded_proto: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PROTO") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Proto", str(ex), environ["HTTP_X_FORWARDED_PROTO"] + ) + + if "x-forwarded-port" in trusted_proxy_headers: + try: + forwarded_port = undquote(environ.get("HTTP_X_FORWARDED_PORT", "")) + if "," in forwarded_port: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PORT") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Port", str(ex), environ["HTTP_X_FORWARDED_PORT"] + ) + + if "x-forwarded-by" in trusted_proxy_headers: + # Waitress itself does not use X-Forwarded-By, but we can not + # remove it so it can get set in the environ + untrusted_headers.remove("X_FORWARDED_BY") + + if "forwarded" in trusted_proxy_headers: + forwarded = environ.get("HTTP_FORWARDED", None) + untrusted_headers = PROXY_HEADERS - {"FORWARDED"} + + # If the Forwarded header exists, it gets priority + if forwarded: + proxies = [] + try: + for forwarded_element in forwarded.split(","): + # Remove whitespace that may have been introduced when + # appending a new entry + forwarded_element = forwarded_element.strip() + + forwarded_for = forwarded_host = forwarded_proto = "" + forwarded_port = forwarded_by = "" + + for pair in forwarded_element.split(";"): + pair = pair.lower() + + if not pair: + continue + + token, equals, value = pair.partition("=") + + if equals != "=": + raise ValueError('Invalid forwarded-pair missing "="') + + if token.strip() != token: + raise ValueError("Token may not be surrounded by whitespace") + + if value.strip() != value: + raise ValueError("Value may not be surrounded by whitespace") + + if token == "by": + forwarded_by = undquote(value) + + elif token == "for": + forwarded_for = undquote(value) + + elif token == "host": + forwarded_host = undquote(value) + + elif token == "proto": + forwarded_proto = undquote(value) + + else: + logger.warning("Unknown Forwarded token: %s" % token) + + proxies.append( + Forwarded( + forwarded_by, forwarded_for, forwarded_host, forwarded_proto + ) + ) + except Exception as ex: + raise MalformedProxyHeader("Forwarded", str(ex), environ["HTTP_FORWARDED"]) + + proxies = proxies[-trusted_proxy_count:] + + # Iterate backwards and fill in some values, the oldest entry that + # 
contains the information we expect is the one we use. We expect + # that intermediate proxies may re-write the host header or proto, + # but the oldest entry is the one that contains the information the + # client expects when generating URL's + # + # Forwarded: for="[2001:db8::1]";host="example.com:8443";proto="https" + # Forwarded: for=192.0.2.1;host="example.internal:8080" + # + # (After HTTPS header folding) should mean that we use as values: + # + # Host: example.com + # Protocol: https + # Port: 8443 + + for proxy in proxies[::-1]: + client_addr = proxy.for_ or client_addr + forwarded_host = proxy.host or forwarded_host + forwarded_proto = proxy.proto or forwarded_proto + + if forwarded_proto: + forwarded_proto = forwarded_proto.lower() + + if forwarded_proto not in {"http", "https"}: + raise MalformedProxyHeader( + "Forwarded Proto=" if forwarded else "X-Forwarded-Proto", + "unsupported proto value", + forwarded_proto, + ) + + # Set the URL scheme to the proxy provided proto + environ["wsgi.url_scheme"] = forwarded_proto + + if not forwarded_port: + if forwarded_proto == "http": + forwarded_port = "80" + + if forwarded_proto == "https": + forwarded_port = "443" + + if forwarded_host: + if ":" in forwarded_host and forwarded_host[-1] != "]": + host, port = forwarded_host.rsplit(":", 1) + host, port = host.strip(), str(port) + + # We trust the port in the Forwarded Host/X-Forwarded-Host over + # X-Forwarded-Port, or whatever we got from Forwarded + # Proto/X-Forwarded-Proto. + + if forwarded_port != port: + forwarded_port = port + + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = host + environ["HTTP_HOST"] = forwarded_host + else: + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = forwarded_host + environ["HTTP_HOST"] = forwarded_host + + if forwarded_port: + if forwarded_port not in {"443", "80"}: + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "80" and environ["wsgi.url_scheme"] != "http": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "443" and environ["wsgi.url_scheme"] != "https": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + + if forwarded_port: + environ["SERVER_PORT"] = str(forwarded_port) + + if client_addr: + if ":" in client_addr and client_addr[-1] != "]": + addr, port = client_addr.rsplit(":", 1) + environ["REMOTE_ADDR"] = strip_brackets(addr.strip()) + environ["REMOTE_PORT"] = port.strip() + else: + environ["REMOTE_ADDR"] = strip_brackets(client_addr.strip()) + environ["REMOTE_HOST"] = environ["REMOTE_ADDR"] + + return untrusted_headers + + +def strip_brackets(addr): + if addr[0] == "[" and addr[-1] == "]": + return addr[1:-1] + return addr + + +def clear_untrusted_headers( + environ, untrusted_headers, log_warning=False, logger=logger +): + untrusted_headers_removed = [ + header + for header in untrusted_headers + if environ.pop("HTTP_" + header, False) is not False + ] + + if log_warning and untrusted_headers_removed: + untrusted_headers_removed = [ + "-".join(x.capitalize() for x in header.split("_")) + for header in untrusted_headers_removed + ] + logger.warning( + "Removed untrusted headers (%s). 
Waitress recommends these be " + "removed upstream.", + ", ".join(untrusted_headers_removed), + ) diff --git a/libs/waitress/receiver.py b/libs/waitress/receiver.py new file mode 100644 index 000000000..878528087 --- /dev/null +++ b/libs/waitress/receiver.py @@ -0,0 +1,186 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Data Chunk Receiver +""" + +from waitress.utilities import BadRequest, find_double_newline + + +class FixedStreamReceiver: + + # See IStreamConsumer + completed = False + error = None + + def __init__(self, cl, buf): + self.remain = cl + self.buf = buf + + def __len__(self): + return self.buf.__len__() + + def received(self, data): + "See IStreamConsumer" + rm = self.remain + + if rm < 1: + self.completed = True # Avoid any chance of spinning + + return 0 + datalen = len(data) + + if rm <= datalen: + self.buf.append(data[:rm]) + self.remain = 0 + self.completed = True + + return rm + else: + self.buf.append(data) + self.remain -= datalen + + return datalen + + def getfile(self): + return self.buf.getfile() + + def getbuf(self): + return self.buf + + +class ChunkedReceiver: + + chunk_remainder = 0 + validate_chunk_end = False + control_line = b"" + chunk_end = b"" + all_chunks_received = False + trailer = b"" + completed = False + error = None + + # max_control_line = 1024 + # max_trailer = 65536 + + def __init__(self, buf): + self.buf = buf + + def __len__(self): + return self.buf.__len__() + + def received(self, s): + # Returns the number of bytes consumed. + + if self.completed: + return 0 + orig_size = len(s) + + while s: + rm = self.chunk_remainder + + if rm > 0: + # Receive the remainder of a chunk. + to_write = s[:rm] + self.buf.append(to_write) + written = len(to_write) + s = s[written:] + + self.chunk_remainder -= written + + if self.chunk_remainder == 0: + self.validate_chunk_end = True + elif self.validate_chunk_end: + s = self.chunk_end + s + + pos = s.find(b"\r\n") + + if pos < 0 and len(s) < 2: + self.chunk_end = s + s = b"" + else: + self.chunk_end = b"" + if pos == 0: + # Chop off the terminating CR LF from the chunk + s = s[2:] + else: + self.error = BadRequest("Chunk not properly terminated") + self.all_chunks_received = True + + # Always exit this loop + self.validate_chunk_end = False + elif not self.all_chunks_received: + # Receive a control line. + s = self.control_line + s + pos = s.find(b"\r\n") + + if pos < 0: + # Control line not finished. + self.control_line = s + s = b"" + else: + # Control line finished. + line = s[:pos] + s = s[pos + 2 :] + self.control_line = b"" + line = line.strip() + + if line: + # Begin a new chunk. + semi = line.find(b";") + + if semi >= 0: + # discard extension info. + line = line[:semi] + try: + sz = int(line.strip(), 16) # hexadecimal + except ValueError: # garbage in input + self.error = BadRequest("garbage in chunked encoding input") + sz = 0 + + if sz > 0: + # Start a new chunk. 
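The FixedStreamReceiver above can be exercised on its own; a small sketch with an arbitrary 11-byte body followed by pipelined leftovers (the buffer overflow threshold is arbitrary too):

    from waitress.buffers import OverflowableBuffer
    from waitress.receiver import FixedStreamReceiver

    buf = OverflowableBuffer(524288)        # spill to a temp file past ~512KB
    rcv = FixedStreamReceiver(11, buf)      # Content-Length: 11
    consumed = rcv.received(b"hello world, plus bytes of the next request")

    assert consumed == 11                   # only the declared body is consumed
    assert rcv.completed
    assert rcv.getfile().read() == b"hello world"

ChunkedReceiver, being defined here, plays the same role for Transfer-Encoding: chunked bodies, accumulating the de-chunked data into the same kind of buffer.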
+ self.chunk_remainder = sz + else: + # Finished chunks. + self.all_chunks_received = True + # else expect a control line. + else: + # Receive the trailer. + trailer = self.trailer + s + + if trailer.startswith(b"\r\n"): + # No trailer. + self.completed = True + + return orig_size - (len(trailer) - 2) + pos = find_double_newline(trailer) + + if pos < 0: + # Trailer not finished. + self.trailer = trailer + s = b"" + else: + # Finished the trailer. + self.completed = True + self.trailer = trailer[:pos] + + return orig_size - (len(trailer) - pos) + + return orig_size + + def getfile(self): + return self.buf.getfile() + + def getbuf(self): + return self.buf diff --git a/libs/waitress/rfc7230.py b/libs/waitress/rfc7230.py new file mode 100644 index 000000000..9b25fbd9a --- /dev/null +++ b/libs/waitress/rfc7230.py @@ -0,0 +1,50 @@ +""" +This contains a bunch of RFC7230 definitions and regular expressions that are +needed to properly parse HTTP messages. +""" + +import re + +WS = "[ \t]" +OWS = WS + "{0,}?" +RWS = WS + "{1,}?" +BWS = OWS + +# RFC 7230 Section 3.2.6 "Field Value Components": +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# obs-text = %x80-FF +TCHAR = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]" +OBS_TEXT = r"\x80-\xff" + +TOKEN = TCHAR + "{1,}" + +# RFC 5234 Appendix B.1 "Core Rules": +# VCHAR = %x21-7E +# ; visible (printing) characters +VCHAR = r"\x21-\x7e" + +# header-field = field-name ":" OWS field-value OWS +# field-name = token +# field-value = *( field-content / obs-fold ) +# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +# field-vchar = VCHAR / obs-text + +# Errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 +# changes field-content to: +# +# field-content = field-vchar [ 1*( SP / HTAB / field-vchar ) +# field-vchar ] + +FIELD_VCHAR = "[" + VCHAR + OBS_TEXT + "]" +# Field content is more greedy than the ABNF, in that it will match the whole value +FIELD_CONTENT = FIELD_VCHAR + "+(?:[ \t]+" + FIELD_VCHAR + "+)*" +# Which allows the field value here to just see if there is even a value in the first place +FIELD_VALUE = "(?:" + FIELD_CONTENT + ")?" + +HEADER_FIELD = re.compile( + ( + "^(?P<name>" + TOKEN + "):" + OWS + "(?P<value>" + FIELD_VALUE + ")" + OWS + "$" + ).encode("latin-1") +) diff --git a/libs/waitress/runner.py b/libs/waitress/runner.py new file mode 100644 index 000000000..949fdb9e9 --- /dev/null +++ b/libs/waitress/runner.py @@ -0,0 +1,299 @@ +############################################################################## +# +# Copyright (c) 2013 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Command line runner. +""" + + +import getopt +import logging +import os +import os.path +import re +import sys + +from waitress import serve +from waitress.adjustments import Adjustments +from waitress.utilities import logger + +HELP = """\ +Usage: + + {0} [OPTS] MODULE:OBJECT + +Standard options: + + --help + Show this information. 
+ + --call + Call the given object to get the WSGI application. + + --host=ADDR + Hostname or IP address on which to listen, default is '0.0.0.0', + which means "all IP addresses on this host". + + Note: May not be used together with --listen + + --port=PORT + TCP port on which to listen, default is '8080' + + Note: May not be used together with --listen + + --listen=ip:port + Tell waitress to listen on an ip port combination. + + Example: + + --listen=127.0.0.1:8080 + --listen=[::1]:8080 + --listen=*:8080 + + This option may be used multiple times to listen on multiple sockets. + A wildcard for the hostname is also supported and will bind to both + IPv4/IPv6 depending on whether they are enabled or disabled. + + --[no-]ipv4 + Toggle on/off IPv4 support. + + Example: + + --no-ipv4 + + This will disable IPv4 socket support. This affects wildcard matching + when generating the list of sockets. + + --[no-]ipv6 + Toggle on/off IPv6 support. + + Example: + + --no-ipv6 + + This will turn on IPv6 socket support. This affects wildcard matching + when generating a list of sockets. + + --unix-socket=PATH + Path of Unix socket. If a socket path is specified, a Unix domain + socket is made instead of the usual inet domain socket. + + Not available on Windows. + + --unix-socket-perms=PERMS + Octal permissions to use for the Unix domain socket, default is + '600'. + + --url-scheme=STR + Default wsgi.url_scheme value, default is 'http'. + + --url-prefix=STR + The ``SCRIPT_NAME`` WSGI environment value. Setting this to anything + except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be + the value passed minus any trailing slashes you add, and it will cause + the ``PATH_INFO`` of any request which is prefixed with this value to + be stripped of the prefix. Default is the empty string. + + --ident=STR + Server identity used in the 'Server' header in responses. Default + is 'waitress'. + +Tuning options: + + --threads=INT + Number of threads used to process application logic, default is 4. + + --backlog=INT + Connection backlog for the server. Default is 1024. + + --recv-bytes=INT + Number of bytes to request when calling socket.recv(). Default is + 8192. + + --send-bytes=INT + Number of bytes to send to socket.send(). Default is 18000. + Multiples of 9000 should avoid partly-filled TCP packets. + + --outbuf-overflow=INT + A temporary file should be created if the pending output is larger + than this. Default is 1048576 (1MB). + + --outbuf-high-watermark=INT + The app_iter will pause when pending output is larger than this value + and will resume once enough data is written to the socket to fall below + this threshold. Default is 16777216 (16MB). + + --inbuf-overflow=INT + A temporary file should be created if the pending input is larger + than this. Default is 524288 (512KB). + + --connection-limit=INT + Stop creating new channels if too many are already active. + Default is 100. + + --cleanup-interval=INT + Minimum seconds between cleaning up inactive channels. Default + is 30. See '--channel-timeout'. + + --channel-timeout=INT + Maximum number of seconds to leave inactive connections open. + Default is 120. 'Inactive' is defined as 'has received no data + from the client and has sent no data to the client'. + + --[no-]log-socket-errors + Toggle whether premature client disconnect tracebacks ought to be + logged. On by default. + + --max-request-header-size=INT + Maximum size of all request headers combined. Default is 262144 + (256KB). 
+ + --max-request-body-size=INT + Maximum size of request body. Default is 1073741824 (1GB). + + --[no-]expose-tracebacks + Toggle whether to expose tracebacks of unhandled exceptions to the + client. Off by default. + + --asyncore-loop-timeout=INT + The timeout value in seconds passed to asyncore.loop(). Default is 1. + + --asyncore-use-poll + The use_poll argument passed to ``asyncore.loop()``. Helps overcome + open file descriptors limit. Default is False. + + --channel-request-lookahead=INT + Allows channels to stay readable and buffer more requests up to the + given maximum even if a request is already being processed. This allows + detecting if a client closed the connection while its request is being + processed. Default is 0. + +""" + +RUNNER_PATTERN = re.compile( + r""" + ^ + (?P<module> + [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* + ) + : + (?P<object> + [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* + ) + $ + """, + re.I | re.X, +) + + +def match(obj_name): + matches = RUNNER_PATTERN.match(obj_name) + if not matches: + raise ValueError("Malformed application '{}'".format(obj_name)) + return matches.group("module"), matches.group("object") + + +def resolve(module_name, object_name): + """Resolve a named object in a module.""" + # We cast each segments due to an issue that has been found to manifest + # in Python 2.6.6, but not 2.6.8, and may affect other revisions of Python + # 2.6 and 2.7, whereby ``__import__`` chokes if the list passed in the + # ``fromlist`` argument are unicode strings rather than 8-bit strings. + # The error triggered is "TypeError: Item in ``fromlist '' not a string". + # My guess is that this was fixed by checking against ``basestring`` + # rather than ``str`` sometime between the release of 2.6.6 and 2.6.8, + # but I've yet to go over the commits. I know, however, that the NEWS + # file makes no mention of such a change to the behaviour of + # ``__import__``. + segments = [str(segment) for segment in object_name.split(".")] + obj = __import__(module_name, fromlist=segments[:1]) + for segment in segments: + obj = getattr(obj, segment) + return obj + + +def show_help(stream, name, error=None): # pragma: no cover + if error is not None: + print("Error: {}\n".format(error), file=stream) + print(HELP.format(name), file=stream) + + +def show_exception(stream): + exc_type, exc_value = sys.exc_info()[:2] + args = getattr(exc_value, "args", None) + print( + ("There was an exception ({}) importing your module.\n").format( + exc_type.__name__, + ), + file=stream, + ) + if args: + print("It had these arguments: ", file=stream) + for idx, arg in enumerate(args, start=1): + print("{}. 
{}\n".format(idx, arg), file=stream) + else: + print("It had no arguments.", file=stream) + + +def run(argv=sys.argv, _serve=serve): + """Command line runner.""" + name = os.path.basename(argv[0]) + + try: + kw, args = Adjustments.parse_args(argv[1:]) + except getopt.GetoptError as exc: + show_help(sys.stderr, name, str(exc)) + return 1 + + if kw["help"]: + show_help(sys.stdout, name) + return 0 + + if len(args) != 1: + show_help(sys.stderr, name, "Specify one application only") + return 1 + + # set a default level for the logger only if it hasn't been set explicitly + # note that this level does not override any parent logger levels, + # handlers, etc but without it no log messages are emitted by default + if logger.level == logging.NOTSET: + logger.setLevel(logging.INFO) + + try: + module, obj_name = match(args[0]) + except ValueError as exc: + show_help(sys.stderr, name, str(exc)) + show_exception(sys.stderr) + return 1 + + # Add the current directory onto sys.path + sys.path.append(os.getcwd()) + + # Get the WSGI function. + try: + app = resolve(module, obj_name) + except ImportError: + show_help(sys.stderr, name, "Bad module '{}'".format(module)) + show_exception(sys.stderr) + return 1 + except AttributeError: + show_help(sys.stderr, name, "Bad object name '{}'".format(obj_name)) + show_exception(sys.stderr) + return 1 + if kw["call"]: + app = app() + + # These arguments are specific to the runner, not waitress itself. + del kw["call"], kw["help"] + + _serve(app, **kw) + return 0 diff --git a/libs/waitress/server.py b/libs/waitress/server.py new file mode 100644 index 000000000..55cffe9ba --- /dev/null +++ b/libs/waitress/server.py @@ -0,0 +1,417 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +import os +import os.path +import socket +import time + +from waitress import trigger +from waitress.adjustments import Adjustments +from waitress.channel import HTTPChannel +from waitress.compat import IPPROTO_IPV6, IPV6_V6ONLY +from waitress.task import ThreadedTaskDispatcher +from waitress.utilities import cleanup_unix_socket + +from . import wasyncore +from .proxy_headers import proxy_headers_middleware + + +def create_server( + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + _dispatcher=None, # test shim + **kw # adjustments +): + """ + if __name__ == '__main__': + server = create_server(app) + server.run() + """ + if application is None: + raise ValueError( + 'The "app" passed to ``create_server`` was ``None``. You forgot ' + "to return a WSGI app within your application." 
+ ) + adj = Adjustments(**kw) + + if map is None: # pragma: nocover + map = {} + + dispatcher = _dispatcher + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(adj.threads) + + if adj.unix_socket and hasattr(socket, "AF_UNIX"): + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + return UnixWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + + effective_listen = [] + last_serv = None + if not adj.sockets: + for sockinfo in adj.listen: + # When TcpWSGIServer is called, it registers itself in the map. This + # side-effect is all we need it for, so we don't store a reference to + # or return it to the user. + last_serv = TcpWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + for sock in adj.sockets: + sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) + if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + last_serv = TcpWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: + last_serv = UnixWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + # We are running a single server, so we can just return the last server, + # saves us from having to create one more object + if len(effective_listen) == 1: + # In this case we have no need to use a MultiSocketServer + return last_serv + + log_info = last_serv.log_info + # Return a class that has a utility function to print out the sockets it's + # listening on, and has a .run() function. All of the TcpWSGIServers + # registered themselves in the map above. + return MultiSocketServer(map, adj, effective_listen, dispatcher, log_info) + + +# This class is only ever used if we have multiple listen sockets. It allows +# the serve() API to call .run() which starts the wasyncore loop, and catches +# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. 
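create_server() above is the programmatic entry point that waitress.serve() wraps; each TcpWSGIServer/UnixWSGIServer registers itself in the socket map, and run() then drives the wasyncore loop. A minimal usage sketch; the application, address and thread count are arbitrary:

    from waitress import serve

    def app(environ, start_response):
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"hello from waitress\n"]

    # Roughly equivalent to create_server(app, listen="127.0.0.1:8080", threads=4).run()
    serve(app, listen="127.0.0.1:8080", threads=4)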
+class MultiSocketServer: + asyncore = wasyncore # test shim + + def __init__( + self, + map=None, + adj=None, + effective_listen=None, + dispatcher=None, + log_info=None, + ): + self.adj = adj + self.map = map + self.effective_listen = effective_listen + self.task_dispatcher = dispatcher + self.log_info = log_info + + def print_listen(self, format_str): # pragma: nocover + for l in self.effective_listen: + l = list(l) + + if ":" in l[0]: + l[0] = "[{}]".format(l[0]) + + self.log_info(format_str.format(*l)) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self.map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.close() + + def close(self): + self.task_dispatcher.shutdown() + wasyncore.close_all(self.map) + + +class BaseWSGIServer(wasyncore.dispatcher): + + channel_class = HTTPChannel + next_channel_cleanup = 0 + socketmod = socket # test shim + asyncore = wasyncore # test shim + in_connection_overflow = False + + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + bind_socket=True, + **kw + ): + if adj is None: + adj = Adjustments(**kw) + + if adj.trusted_proxy or adj.clear_untrusted_proxy_headers: + # wrap the application to deal with proxy headers + # we wrap it here because webtest subclasses the TcpWSGIServer + # directly and thus doesn't run any code that's in create_server + application = proxy_headers_middleware( + application, + trusted_proxy=adj.trusted_proxy, + trusted_proxy_count=adj.trusted_proxy_count, + trusted_proxy_headers=adj.trusted_proxy_headers, + clear_untrusted=adj.clear_untrusted_proxy_headers, + log_untrusted=adj.log_untrusted_proxy_headers, + logger=self.logger, + ) + + if map is None: + # use a nonglobal socket map by default to hopefully prevent + # conflicts with apps and libs that use the wasyncore global socket + # map ala https://github.com/Pylons/waitress/issues/63 + map = {} + if sockinfo is None: + sockinfo = adj.listen[0] + + self.sockinfo = sockinfo + self.family = sockinfo[0] + self.socktype = sockinfo[1] + self.application = application + self.adj = adj + self.trigger = trigger.trigger(map) + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(self.adj.threads) + + self.task_dispatcher = dispatcher + self.asyncore.dispatcher.__init__(self, _sock, map=map) + if _sock is None: + self.create_socket(self.family, self.socktype) + if self.family == socket.AF_INET6: # pragma: nocover + self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) + + self.set_reuse_addr() + + if bind_socket: + self.bind_server_socket() + + self.effective_host, self.effective_port = self.getsockname() + self.server_name = adj.server_name + self.active_channels = {} + if _start: + self.accept_connections() + + def bind_server_socket(self): + raise NotImplementedError # pragma: no cover + + def getsockname(self): + raise NotImplementedError # pragma: no cover + + def accept_connections(self): + self.accepting = True + self.socket.listen(self.adj.backlog) # Get around asyncore NT limit + + def add_task(self, task): + self.task_dispatcher.add_task(task) + + def readable(self): + now = time.time() + if now >= self.next_channel_cleanup: + self.next_channel_cleanup = now + self.adj.cleanup_interval + self.maintenance(now) + + if self.accepting: + if ( + not self.in_connection_overflow + and len(self._map) >= 
self.adj.connection_limit + ): + self.in_connection_overflow = True + self.logger.warning( + "total open connections reached the connection limit, " + "no longer accepting new connections" + ) + elif ( + self.in_connection_overflow + and len(self._map) < self.adj.connection_limit + ): + self.in_connection_overflow = False + self.logger.info( + "total open connections dropped below the connection limit, " + "listening again" + ) + return not self.in_connection_overflow + return False + + def writable(self): + return False + + def handle_read(self): + pass + + def handle_connect(self): + pass + + def handle_accept(self): + try: + v = self.accept() + if v is None: + return + conn, addr = v + except OSError: + # Linux: On rare occasions we get a bogus socket back from + # accept. socketmodule.c:makesockaddr complains that the + # address family is unknown. We don't want the whole server + # to shut down because of this. + if self.adj.log_socket_errors: + self.logger.warning("server accept() threw an exception", exc_info=True) + return + self.set_socket_options(conn) + addr = self.fix_addr(addr) + self.channel_class(self, conn, addr, self.adj, map=self._map) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self._map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.task_dispatcher.shutdown() + + def pull_trigger(self): + self.trigger.pull_trigger() + + def set_socket_options(self, conn): + pass + + def fix_addr(self, addr): + return addr + + def maintenance(self, now): + """ + Closes channels that have not had any activity in a while. + + The timeout is configured through adj.channel_timeout (seconds). + """ + cutoff = now - self.adj.channel_timeout + for channel in self.active_channels.values(): + if (not channel.requests) and channel.last_activity < cutoff: + channel.will_close = True + + def print_listen(self, format_str): # pragma: no cover + self.log_info(format_str.format(self.effective_host, self.effective_port)) + + def close(self): + self.trigger.close() + return wasyncore.dispatcher.close(self) + + +class TcpWSGIServer(BaseWSGIServer): + def bind_server_socket(self): + (_, _, _, sockaddr) = self.sockinfo + self.bind(sockaddr) + + def getsockname(self): + # Return the IP address, port as numeric + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV, + ) + + def set_socket_options(self, conn): + for (level, optname, value) in self.adj.socket_options: + conn.setsockopt(level, optname, value) + + +if hasattr(socket, "AF_UNIX"): + + class UnixWSGIServer(BaseWSGIServer): + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + **kw + ): + if sockinfo is None: + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + + super().__init__( + application, + map=map, + _start=_start, + _sock=_sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + **kw, + ) + + def bind_server_socket(self): + cleanup_unix_socket(self.adj.unix_socket) + self.bind(self.adj.unix_socket) + if os.path.exists(self.adj.unix_socket): + os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms) + + def getsockname(self): + return ("unix", self.socket.getsockname()) + + def fix_addr(self, addr): + return ("localhost", None) + + +# Compatibility alias. 
+WSGIServer = TcpWSGIServer diff --git a/libs/waitress/task.py b/libs/waitress/task.py new file mode 100644 index 000000000..2ac8f4c81 --- /dev/null +++ b/libs/waitress/task.py @@ -0,0 +1,570 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +from collections import deque +import socket +import sys +import threading +import time + +from .buffers import ReadOnlyFileBasedBuffer +from .utilities import build_http_date, logger, queue_logger + +rename_headers = { # or keep them without the HTTP_ prefix added + "CONTENT_LENGTH": "CONTENT_LENGTH", + "CONTENT_TYPE": "CONTENT_TYPE", +} + +hop_by_hop = frozenset( + ( + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ) +) + + +class ThreadedTaskDispatcher: + """A Task Dispatcher that creates a thread for each task.""" + + stop_count = 0 # Number of threads that will stop soon. + active_count = 0 # Number of currently active threads + logger = logger + queue_logger = queue_logger + + def __init__(self): + self.threads = set() + self.queue = deque() + self.lock = threading.Lock() + self.queue_cv = threading.Condition(self.lock) + self.thread_exit_cv = threading.Condition(self.lock) + + def start_new_thread(self, target, thread_no): + t = threading.Thread( + target=target, name="waitress-{}".format(thread_no), args=(thread_no,) + ) + t.daemon = True + t.start() + + def handler_thread(self, thread_no): + while True: + with self.lock: + while not self.queue and self.stop_count == 0: + # Mark ourselves as idle before waiting to be + # woken up, then we will once again be active + self.active_count -= 1 + self.queue_cv.wait() + self.active_count += 1 + + if self.stop_count > 0: + self.active_count -= 1 + self.stop_count -= 1 + self.threads.discard(thread_no) + self.thread_exit_cv.notify() + break + + task = self.queue.popleft() + try: + task.service() + except BaseException: + self.logger.exception("Exception when servicing %r", task) + + def set_thread_count(self, count): + with self.lock: + threads = self.threads + thread_no = 0 + running = len(threads) - self.stop_count + while running < count: + # Start threads. + while thread_no in threads: + thread_no = thread_no + 1 + threads.add(thread_no) + running += 1 + self.start_new_thread(self.handler_thread, thread_no) + self.active_count += 1 + thread_no = thread_no + 1 + if running > count: + # Stop threads. + self.stop_count += running - count + self.queue_cv.notify_all() + + def add_task(self, task): + with self.lock: + self.queue.append(task) + self.queue_cv.notify() + queue_size = len(self.queue) + idle_threads = len(self.threads) - self.stop_count - self.active_count + if queue_size > idle_threads: + self.queue_logger.warning( + "Task queue depth is %d", queue_size - idle_threads + ) + + def shutdown(self, cancel_pending=True, timeout=5): + self.set_thread_count(0) + # Ensure the threads shut down. 
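The ThreadedTaskDispatcher above is the worker pool behind the server: set_thread_count() starts daemon threads and add_task() queues anything exposing service() and cancel(). A standalone sketch with an illustrative task class:

    import time
    from waitress.task import ThreadedTaskDispatcher

    class EchoTask:
        def service(self):
            print("handled on a worker thread")

        def cancel(self):
            pass

    dispatcher = ThreadedTaskDispatcher()
    dispatcher.set_thread_count(2)     # start two daemon worker threads
    dispatcher.add_task(EchoTask())    # an idle worker picks this up
    time.sleep(0.2)                    # give the worker a moment to service it
    dispatcher.shutdown()              # stop workers and cancel anything still queued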
+ threads = self.threads + expiration = time.time() + timeout + with self.lock: + while threads: + if time.time() >= expiration: + self.logger.warning("%d thread(s) still running", len(threads)) + break + self.thread_exit_cv.wait(0.1) + if cancel_pending: + # Cancel remaining tasks. + queue = self.queue + if len(queue) > 0: + self.logger.warning("Canceling %d pending task(s)", len(queue)) + while queue: + task = queue.popleft() + task.cancel() + self.queue_cv.notify_all() + return True + return False + + +class Task: + close_on_finish = False + status = "200 OK" + wrote_header = False + start_time = 0 + content_length = None + content_bytes_written = 0 + logged_write_excess = False + logged_write_no_body = False + complete = False + chunked_response = False + logger = logger + + def __init__(self, channel, request): + self.channel = channel + self.request = request + self.response_headers = [] + version = request.version + if version not in ("1.0", "1.1"): + # fall back to a version we support. + version = "1.0" + self.version = version + + def service(self): + try: + self.start() + self.execute() + self.finish() + except OSError: + self.close_on_finish = True + if self.channel.adj.log_socket_errors: + raise + + @property + def has_body(self): + return not ( + self.status.startswith("1") + or self.status.startswith("204") + or self.status.startswith("304") + ) + + def build_response_header(self): + version = self.version + # Figure out whether the connection should be closed. + connection = self.request.headers.get("CONNECTION", "").lower() + response_headers = [] + content_length_header = None + date_header = None + server_header = None + connection_close_header = None + + for (headername, headerval) in self.response_headers: + headername = "-".join([x.capitalize() for x in headername.split("-")]) + + if headername == "Content-Length": + if self.has_body: + content_length_header = headerval + else: + continue # pragma: no cover + + if headername == "Date": + date_header = headerval + + if headername == "Server": + server_header = headerval + + if headername == "Connection": + connection_close_header = headerval.lower() + # replace with properly capitalized version + response_headers.append((headername, headerval)) + + if ( + content_length_header is None + and self.content_length is not None + and self.has_body + ): + content_length_header = str(self.content_length) + response_headers.append(("Content-Length", content_length_header)) + + def close_on_finish(): + if connection_close_header is None: + response_headers.append(("Connection", "close")) + self.close_on_finish = True + + if version == "1.0": + if connection == "keep-alive": + if not content_length_header: + close_on_finish() + else: + response_headers.append(("Connection", "Keep-Alive")) + else: + close_on_finish() + + elif version == "1.1": + if connection == "close": + close_on_finish() + + if not content_length_header: + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx, 204 or 304. + + if self.has_body: + response_headers.append(("Transfer-Encoding", "chunked")) + self.chunked_response = True + + if not self.close_on_finish: + close_on_finish() + + # under HTTP 1.1 keep-alive is default, no need to set the header + else: + raise AssertionError("neither HTTP/1.0 or HTTP/1.1") + + # Set the Server and Date field, if not yet specified. This is needed + # if the server is used as a proxy. 
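A practical consequence of build_response_header() above: an HTTP/1.1 response for which the application supplies no Content-Length is sent with Transfer-Encoding: chunked, and the write() method further down frames each non-empty piece the application yields as a chunk, closing the connection once the response finishes. A WSGI-level sketch of an app that triggers this path:

    def streaming_app(environ, start_response):
        # No Content-Length header here, so for HTTP/1.1 clients the response
        # is framed as chunked by the task code in this file.
        start_response("200 OK", [("Content-Type", "text/plain")])
        yield b"part one\n"
        yield b"part two\n"

Supplying an explicit Content-Length instead lets HTTP/1.1 connections stay on the default keep-alive path.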
+ ident = self.channel.server.adj.ident + + if not server_header: + if ident: + response_headers.append(("Server", ident)) + else: + response_headers.append(("Via", ident or "waitress")) + + if not date_header: + response_headers.append(("Date", build_http_date(self.start_time))) + + self.response_headers = response_headers + + first_line = "HTTP/%s %s" % (self.version, self.status) + # NB: sorting headers needs to preserve same-named-header order + # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; + # rely on stable sort to keep relative position of same-named headers + next_lines = [ + "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0]) + ] + lines = [first_line] + next_lines + res = "%s\r\n\r\n" % "\r\n".join(lines) + + return res.encode("latin-1") + + def remove_content_length_header(self): + response_headers = [] + + for header_name, header_value in self.response_headers: + if header_name.lower() == "content-length": + continue # pragma: nocover + response_headers.append((header_name, header_value)) + + self.response_headers = response_headers + + def start(self): + self.start_time = time.time() + + def finish(self): + if not self.wrote_header: + self.write(b"") + if self.chunked_response: + # not self.write, it will chunk it! + self.channel.write_soon(b"0\r\n\r\n") + + def write(self, data): + if not self.complete: + raise RuntimeError("start_response was not called before body written") + channel = self.channel + if not self.wrote_header: + rh = self.build_response_header() + channel.write_soon(rh) + self.wrote_header = True + + if data and self.has_body: + towrite = data + cl = self.content_length + if self.chunked_response: + # use chunked encoding response + towrite = hex(len(data))[2:].upper().encode("latin-1") + b"\r\n" + towrite += data + b"\r\n" + elif cl is not None: + towrite = data[: cl - self.content_bytes_written] + self.content_bytes_written += len(towrite) + if towrite != data and not self.logged_write_excess: + self.logger.warning( + "application-written content exceeded the number of " + "bytes specified by Content-Length header (%s)" % cl + ) + self.logged_write_excess = True + if towrite: + channel.write_soon(towrite) + elif data: + # Cheat, and tell the application we have written all of the bytes, + # even though the response shouldn't have a body and we are + # ignoring it entirely. + self.content_bytes_written += len(data) + + if not self.logged_write_no_body: + self.logger.warning( + "application-written content was ignored due to HTTP " + "response that may not contain a message-body: (%s)" % self.status + ) + self.logged_write_no_body = True + + +class ErrorTask(Task): + """An error task produces an error response""" + + complete = True + + def execute(self): + e = self.request.error + status, headers, body = e.to_response() + self.status = status + self.response_headers.extend(headers) + # We need to explicitly tell the remote client we are closing the + # connection, because self.close_on_finish is set, and we are going to + # slam the door in the clients face. 
+ self.response_headers.append(("Connection", "close")) + self.close_on_finish = True + self.content_length = len(body) + self.write(body.encode("latin-1")) + + +class WSGITask(Task): + """A WSGI task produces a response from a WSGI application.""" + + environ = None + + def execute(self): + environ = self.get_environment() + + def start_response(status, headers, exc_info=None): + if self.complete and not exc_info: + raise AssertionError( + "start_response called a second time without providing exc_info." + ) + if exc_info: + try: + if self.wrote_header: + # higher levels will catch and handle raised exception: + # 1. "service" method in task.py + # 2. "service" method in channel.py + # 3. "handler_thread" method in task.py + raise exc_info[1] + else: + # As per WSGI spec existing headers must be cleared + self.response_headers = [] + finally: + exc_info = None + + self.complete = True + + if not status.__class__ is str: + raise AssertionError("status %s is not a string" % status) + if "\n" in status or "\r" in status: + raise ValueError( + "carriage return/line feed character present in status" + ) + + self.status = status + + # Prepare the headers for output + for k, v in headers: + if not k.__class__ is str: + raise AssertionError( + "Header name %r is not a string in %r" % (k, (k, v)) + ) + if not v.__class__ is str: + raise AssertionError( + "Header value %r is not a string in %r" % (v, (k, v)) + ) + + if "\n" in v or "\r" in v: + raise ValueError( + "carriage return/line feed character present in header value" + ) + if "\n" in k or "\r" in k: + raise ValueError( + "carriage return/line feed character present in header name" + ) + + kl = k.lower() + if kl == "content-length": + self.content_length = int(v) + elif kl in hop_by_hop: + raise AssertionError( + '%s is a "hop-by-hop" header; it cannot be used by ' + "a WSGI application (see PEP 3333)" % k + ) + + self.response_headers.extend(headers) + + # Return a method used to write the response data. + return self.write + + # Call the application to handle the request and write a response + app_iter = self.channel.server.application(environ, start_response) + + can_close_app_iter = True + try: + if app_iter.__class__ is ReadOnlyFileBasedBuffer: + cl = self.content_length + size = app_iter.prepare(cl) + if size: + if cl != size: + if cl is not None: + self.remove_content_length_header() + self.content_length = size + self.write(b"") # generate headers + # if the write_soon below succeeds then the channel will + # take over closing the underlying file via the channel's + # _flush_some or handle_close so we intentionally avoid + # calling close in the finally block + self.channel.write_soon(app_iter) + can_close_app_iter = False + return + + first_chunk_len = None + for chunk in app_iter: + if first_chunk_len is None: + first_chunk_len = len(chunk) + # Set a Content-Length header if one is not supplied. 
+ # start_response may not have been called until first + # iteration as per PEP, so we must reinterrogate + # self.content_length here + if self.content_length is None: + app_iter_len = None + if hasattr(app_iter, "__len__"): + app_iter_len = len(app_iter) + if app_iter_len == 1: + self.content_length = first_chunk_len + # transmit headers only after first iteration of the iterable + # that returns a non-empty bytestring (PEP 3333) + if chunk: + self.write(chunk) + + cl = self.content_length + if cl is not None: + if self.content_bytes_written != cl: + # close the connection so the client isn't sitting around + # waiting for more data when there are too few bytes + # to service content-length + self.close_on_finish = True + if self.request.command != "HEAD": + self.logger.warning( + "application returned too few bytes (%s) " + "for specified Content-Length (%s) via app_iter" + % (self.content_bytes_written, cl), + ) + finally: + if can_close_app_iter and hasattr(app_iter, "close"): + app_iter.close() + + def get_environment(self): + """Returns a WSGI environment.""" + environ = self.environ + if environ is not None: + # Return the cached copy. + return environ + + request = self.request + path = request.path + channel = self.channel + server = channel.server + url_prefix = server.adj.url_prefix + + if path.startswith("/"): + # strip extra slashes at the beginning of a path that starts + # with any number of slashes + path = "/" + path.lstrip("/") + + if url_prefix: + # NB: url_prefix is guaranteed by the configuration machinery to + # be either the empty string or a string that starts with a single + # slash and ends without any slashes + if path == url_prefix: + # if the path is the same as the url prefix, the SCRIPT_NAME + # should be the url_prefix and PATH_INFO should be empty + path = "" + else: + # if the path starts with the url prefix plus a slash, + # the SCRIPT_NAME should be the url_prefix and PATH_INFO should + # the value of path from the slash until its end + url_prefix_with_trailing_slash = url_prefix + "/" + if path.startswith(url_prefix_with_trailing_slash): + path = path[len(url_prefix) :] + + environ = { + "REMOTE_ADDR": channel.addr[0], + # Nah, we aren't actually going to look up the reverse DNS for + # REMOTE_ADDR, but we will happily set this environment variable + # for the WSGI application. Spec says we can just set this to + # REMOTE_ADDR, so we do. 
+ "REMOTE_HOST": channel.addr[0], + # try and set the REMOTE_PORT to something useful, but maybe None + "REMOTE_PORT": str(channel.addr[1]), + "REQUEST_METHOD": request.command.upper(), + "SERVER_PORT": str(server.effective_port), + "SERVER_NAME": server.server_name, + "SERVER_SOFTWARE": server.adj.ident, + "SERVER_PROTOCOL": "HTTP/%s" % self.version, + "SCRIPT_NAME": url_prefix, + "PATH_INFO": path, + "QUERY_STRING": request.query, + "wsgi.url_scheme": request.url_scheme, + # the following environment variables are required by the WSGI spec + "wsgi.version": (1, 0), + # apps should use the logging module + "wsgi.errors": sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "wsgi.input": request.get_body_stream(), + "wsgi.file_wrapper": ReadOnlyFileBasedBuffer, + "wsgi.input_terminated": True, # wsgi.input is EOF terminated + } + + for key, value in dict(request.headers).items(): + value = value.strip() + mykey = rename_headers.get(key, None) + if mykey is None: + mykey = "HTTP_" + key + if mykey not in environ: + environ[mykey] = value + + # Insert a callable into the environment that allows the application to + # check if the client disconnected. Only works with + # channel_request_lookahead larger than 0. + environ["waitress.client_disconnected"] = self.channel.check_client_disconnected + + # cache the environ for this request + self.environ = environ + return environ diff --git a/libs/waitress/trigger.py b/libs/waitress/trigger.py new file mode 100644 index 000000000..24c4d0d6b --- /dev/null +++ b/libs/waitress/trigger.py @@ -0,0 +1,203 @@ +############################################################################## +# +# Copyright (c) 2001-2005 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE +# +############################################################################## + +import errno +import os +import socket +import threading + +from . import wasyncore + +# Wake up a call to select() running in the main thread. +# +# This is useful in a context where you are using Medusa's I/O +# subsystem to deliver data, but the data is generated by another +# thread. Normally, if Medusa is in the middle of a call to +# select(), new output data generated by another thread will have +# to sit until the call to select() either times out or returns. +# If the trigger is 'pulled' by another thread, it should immediately +# generate a READ event on the trigger object, which will force the +# select() invocation to return. +# +# A common use for this facility: letting Medusa manage I/O for a +# large number of connections; but routing each request through a +# thread chosen from a fixed-size thread pool. When a thread is +# acquired, a transaction is performed, but output data is +# accumulated into buffers that will be emptied more efficiently +# by Medusa. 
[picture a server that can process database queries +# rapidly, but doesn't want to tie up threads waiting to send data +# to low-bandwidth connections] +# +# The other major feature provided by this class is the ability to +# move work back into the main thread: if you call pull_trigger() +# with a thunk argument, when select() wakes up and receives the +# event it will call your thunk from within that thread. The main +# purpose of this is to remove the need to wrap thread locks around +# Medusa's data structures, which normally do not need them. [To see +# why this is true, imagine this scenario: A thread tries to push some +# new data onto a channel's outgoing data queue at the same time that +# the main thread is trying to remove some] + + +class _triggerbase: + """OS-independent base class for OS-dependent trigger class.""" + + kind = None # subclass must set to "pipe" or "loopback"; used by repr + + def __init__(self): + self._closed = False + + # `lock` protects the `thunks` list from being traversed and + # appended to simultaneously. + self.lock = threading.Lock() + + # List of no-argument callbacks to invoke when the trigger is + # pulled. These run in the thread running the wasyncore mainloop, + # regardless of which thread pulls the trigger. + self.thunks = [] + + def readable(self): + return True + + def writable(self): + return False + + def handle_connect(self): + pass + + def handle_close(self): + self.close() + + # Override the wasyncore close() method, because it doesn't know about + # (so can't close) all the gimmicks we have open. Subclass must + # supply a _close() method to do platform-specific closing work. _close() + # will be called iff we're not already closed. + def close(self): + if not self._closed: + self._closed = True + self.del_channel() + self._close() # subclass does OS-specific stuff + + def pull_trigger(self, thunk=None): + if thunk: + with self.lock: + self.thunks.append(thunk) + self._physical_pull() + + def handle_read(self): + try: + self.recv(8192) + except OSError: + return + with self.lock: + for thunk in self.thunks: + try: + thunk() + except: + nil, t, v, tbinfo = wasyncore.compact_traceback() + self.log_info( + "exception in trigger thunk: (%s:%s %s)" % (t, v, tbinfo) + ) + self.thunks = [] + + +if os.name == "posix": + + class trigger(_triggerbase, wasyncore.file_dispatcher): + kind = "pipe" + + def __init__(self, map): + _triggerbase.__init__(self) + r, self.trigger = self._fds = os.pipe() + wasyncore.file_dispatcher.__init__(self, r, map=map) + + def _close(self): + for fd in self._fds: + os.close(fd) + self._fds = [] + wasyncore.file_dispatcher.close(self) + + def _physical_pull(self): + os.write(self.trigger, b"x") + + +else: # pragma: no cover + # Windows version; uses just sockets, because a pipe isn't select'able + # on Windows. + + class trigger(_triggerbase, wasyncore.dispatcher): + kind = "loopback" + + def __init__(self, map): + _triggerbase.__init__(self) + + # Get a pair of connected sockets. The trigger is the 'w' + # end of the pair, which is connected to 'r'. 'r' is put + # in the wasyncore socket map. "pulling the trigger" then + # means writing something on w, which will wake up r. + + w = socket.socket() + # Disable buffering -- pulling the trigger sends 1 byte, + # and we want that sent immediately, to wake up wasyncore's + # select() ASAP. + w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + count = 0 + while True: + count += 1 + # Bind to a local port; for efficiency, let the OS pick + # a free port for us. 
+ # Unfortunately, stress tests showed that we may not + # be able to connect to that port ("Address already in + # use") despite that the OS picked it. This appears + # to be a race bug in the Windows socket implementation. + # So we loop until a connect() succeeds (almost always + # on the first try). See the long thread at + # http://mail.zope.org/pipermail/zope/2005-July/160433.html + # for hideous details. + a = socket.socket() + a.bind(("127.0.0.1", 0)) + connect_address = a.getsockname() # assigned (host, port) pair + a.listen(1) + try: + w.connect(connect_address) + break # success + except OSError as detail: + if detail[0] != errno.WSAEADDRINUSE: + # "Address already in use" is the only error + # I've seen on two WinXP Pro SP2 boxes, under + # Pythons 2.3.5 and 2.4.1. + raise + # (10048, 'Address already in use') + # assert count <= 2 # never triggered in Tim's tests + if count >= 10: # I've never seen it go above 2 + a.close() + w.close() + raise RuntimeError("Cannot bind trigger!") + # Close `a` and try again. Note: I originally put a short + # sleep() here, but it didn't appear to help or hurt. + a.close() + + r, addr = a.accept() # r becomes wasyncore's (self.)socket + a.close() + self.trigger = w + wasyncore.dispatcher.__init__(self, r, map=map) + + def _close(self): + # self.socket is r, and self.trigger is w, from __init__ + self.socket.close() + self.trigger.close() + + def _physical_pull(self): + self.trigger.send(b"x") diff --git a/libs/waitress/utilities.py b/libs/waitress/utilities.py new file mode 100644 index 000000000..3caaa336f --- /dev/null +++ b/libs/waitress/utilities.py @@ -0,0 +1,320 @@ +############################################################################## +# +# Copyright (c) 2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. 
+# +############################################################################## +"""Utility functions +""" + +import calendar +import errno +import logging +import os +import re +import stat +import time + +from .rfc7230 import OBS_TEXT, VCHAR + +logger = logging.getLogger("waitress") +queue_logger = logging.getLogger("waitress.queue") + + +def find_double_newline(s): + """Returns the position just after a double newline in the given string.""" + pos = s.find(b"\r\n\r\n") + + if pos >= 0: + pos += 4 + + return pos + + +def concat(*args): + return "".join(args) + + +def join(seq, field=" "): + return field.join(seq) + + +def group(s): + return "(" + s + ")" + + +short_days = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"] +long_days = [ + "sunday", + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", +] + +short_day_reg = group(join(short_days, "|")) +long_day_reg = group(join(long_days, "|")) + +daymap = {} + +for i in range(7): + daymap[short_days[i]] = i + daymap[long_days[i]] = i + +hms_reg = join(3 * [group("[0-9][0-9]")], ":") + +months = [ + "jan", + "feb", + "mar", + "apr", + "may", + "jun", + "jul", + "aug", + "sep", + "oct", + "nov", + "dec", +] + +monmap = {} + +for i in range(12): + monmap[months[i]] = i + 1 + +months_reg = group(join(months, "|")) + +# From draft-ietf-http-v11-spec-07.txt/3.3.1 +# Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 +# Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036 +# Sun Nov 6 08:49:37 1994 ; ANSI C's asctime() format + +# rfc822 format +rfc822_date = join( + [ + concat(short_day_reg, ","), # day + group("[0-9][0-9]?"), # date + months_reg, # month + group("[0-9]+"), # year + hms_reg, # hour minute second + "gmt", + ], + " ", +) + +rfc822_reg = re.compile(rfc822_date) + + +def unpack_rfc822(m): + g = m.group + + return ( + int(g(4)), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second + 0, + 0, + 0, + ) + + +# rfc850 format +rfc850_date = join( + [ + concat(long_day_reg, ","), + join([group("[0-9][0-9]?"), months_reg, group("[0-9]+")], "-"), + hms_reg, + "gmt", + ], + " ", +) + +rfc850_reg = re.compile(rfc850_date) +# they actually unpack the same way +def unpack_rfc850(m): + g = m.group + yr = g(4) + + if len(yr) == 2: + yr = "19" + yr + + return ( + int(yr), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second + 0, + 0, + 0, + ) + + +# parsdate.parsedate - ~700/sec. +# parse_http_date - ~1333/sec. 
+ +weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] +monthname = [ + None, + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", +] + + +def build_http_date(when): + year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when) + + return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + weekdayname[wd], + day, + monthname[month], + year, + hh, + mm, + ss, + ) + + +def parse_http_date(d): + d = d.lower() + m = rfc850_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc850(m))) + else: + m = rfc822_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc822(m))) + else: + return 0 + + return retval + + +# RFC 5234 Appendix B.1 "Core Rules": +# VCHAR = %x21-7E +# ; visible (printing) characters +vchar_re = VCHAR + +# RFC 7230 Section 3.2.6 "Field Value Components": +# quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE +# qdtext = HTAB / SP /%x21 / %x23-5B / %x5D-7E / obs-text +# obs-text = %x80-FF +# quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text ) +obs_text_re = OBS_TEXT + +# The '\\' between \x5b and \x5d is needed to escape \x5d (']') +qdtext_re = "[\t \x21\x23-\x5b\\\x5d-\x7e" + obs_text_re + "]" + +quoted_pair_re = r"\\" + "([\t " + vchar_re + obs_text_re + "])" +quoted_string_re = '"(?:(?:' + qdtext_re + ")|(?:" + quoted_pair_re + '))*"' + +quoted_string = re.compile(quoted_string_re) +quoted_pair = re.compile(quoted_pair_re) + + +def undquote(value): + if value.startswith('"') and value.endswith('"'): + # So it claims to be DQUOTE'ed, let's validate that + matches = quoted_string.match(value) + + if matches and matches.end() == len(value): + # Remove the DQUOTE's from the value + value = value[1:-1] + + # Remove all backslashes that are followed by a valid vchar or + # obs-text + value = quoted_pair.sub(r"\1", value) + + return value + elif not value.startswith('"') and not value.endswith('"'): + return value + + raise ValueError("Invalid quoting in value") + + +def cleanup_unix_socket(path): + try: + st = os.stat(path) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise # pragma: no cover + else: + if stat.S_ISSOCK(st.st_mode): + try: + os.remove(path) + except OSError: # pragma: no cover + # avoid race condition error during tests + pass + + +class Error: + code = 500 + reason = "Internal Server Error" + + def __init__(self, body): + self.body = body + + def to_response(self): + status = "%s %s" % (self.code, self.reason) + body = "%s\r\n\r\n%s" % (self.reason, self.body) + tag = "\r\n\r\n(generated by waitress)" + body = body + tag + headers = [("Content-Type", "text/plain")] + + return status, headers, body + + def wsgi_response(self, environ, start_response): + status, headers, body = self.to_response() + start_response(status, headers) + yield body + + +class BadRequest(Error): + code = 400 + reason = "Bad Request" + + +class RequestHeaderFieldsTooLarge(BadRequest): + code = 431 + reason = "Request Header Fields Too Large" + + +class RequestEntityTooLarge(BadRequest): + code = 413 + reason = "Request Entity Too Large" + + +class InternalServerError(Error): + code = 500 + reason = "Internal Server Error" + + +class ServerNotImplemented(Error): + code = 501 + reason = "Not Implemented" diff --git a/libs/waitress/wasyncore.py b/libs/waitress/wasyncore.py new file mode 100644 index 000000000..9a68c5171 --- /dev/null +++ b/libs/waitress/wasyncore.py @@ -0,0 +1,691 @@ +# -*- Mode: Python -*- +# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 
rushing Exp +# Author: Sam Rushing <[email protected]> + +# ====================================================================== +# Copyright 1996 by Sam Rushing +# +# All Rights Reserved +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby +# granted, provided that the above copyright notice appear in all +# copies and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of Sam +# Rushing not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN +# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# ====================================================================== + +"""Basic infrastructure for asynchronous socket service clients and servers. + +There are only two ways to have a program on a single processor do "more +than one thing at a time". Multi-threaded programming is the simplest and +most popular way to do it, but there is another very different technique, +that lets you have nearly all the advantages of multi-threading, without +actually using multiple threads. it's really only practical if your program +is largely I/O bound. If your program is CPU bound, then pre-emptive +scheduled threads are probably what you really need. Network servers are +rarely CPU-bound, however. + +If your operating system supports the select() system call in its I/O +library (and nearly all do), then you can use it to juggle multiple +communication channels at once; doing other work while your I/O is taking +place in the "background." Although this strategy can seem strange and +complex, especially at first, it is in many ways easier to understand and +control than multi-threaded programming. The module documented here solves +many of the difficult problems for you, making the task of building +sophisticated high-performance network servers and clients a snap. + +NB: this is a fork of asyncore from the stdlib that we've (the waitress +developers) named 'wasyncore' to ensure forward compatibility, as asyncore +in the stdlib will be dropped soon. It is neither a copy of the 2.7 asyncore +nor the 3.X asyncore; it is a version compatible with either 2.7 or 3.X. +""" + +from errno import ( + EAGAIN, + EALREADY, + EBADF, + ECONNABORTED, + ECONNRESET, + EINPROGRESS, + EINTR, + EINVAL, + EISCONN, + ENOTCONN, + EPIPE, + ESHUTDOWN, + EWOULDBLOCK, + errorcode, +) +import logging +import os +import select +import socket +import sys +import time +import warnings + +from . 
import compat, utilities + +_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF}) + +try: + socket_map +except NameError: + socket_map = {} + + +def _strerror(err): + try: + return os.strerror(err) + except (TypeError, ValueError, OverflowError, NameError): + return "Unknown error %s" % err + + +class ExitNow(Exception): + pass + + +_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit) + + +def read(obj): + try: + obj.handle_read_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def write(obj): + try: + obj.handle_write_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def _exception(obj): + try: + obj.handle_expt_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def readwrite(obj, flags): + try: + if flags & select.POLLIN: + obj.handle_read_event() + if flags & select.POLLOUT: + obj.handle_write_event() + if flags & select.POLLPRI: + obj.handle_expt_event() + if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): + obj.handle_close() + except OSError as e: + if e.args[0] not in _DISCONNECTED: + obj.handle_error() + else: + obj.handle_close() + except _reraised_exceptions: + raise + except: + obj.handle_error() + + +def poll(timeout=0.0, map=None): + if map is None: # pragma: no cover + map = socket_map + if map: + r = [] + w = [] + e = [] + for fd, obj in list(map.items()): # list() call FBO py3 + is_r = obj.readable() + is_w = obj.writable() + if is_r: + r.append(fd) + # accepting sockets should not be writable + if is_w and not obj.accepting: + w.append(fd) + if is_r or is_w: + e.append(fd) + if [] == r == w == e: + time.sleep(timeout) + return + + try: + r, w, e = select.select(r, w, e, timeout) + except OSError as err: + if err.args[0] != EINTR: + raise + else: + return + + for fd in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + read(obj) + + for fd in w: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + write(obj) + + for fd in e: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + _exception(obj) + + +def poll2(timeout=0.0, map=None): + # Use the poll() support added to the select module in Python 2.0 + if map is None: # pragma: no cover + map = socket_map + if timeout is not None: + # timeout is in milliseconds + timeout = int(timeout * 1000) + pollster = select.poll() + if map: + for fd, obj in list(map.items()): + flags = 0 + if obj.readable(): + flags |= select.POLLIN | select.POLLPRI + # accepting sockets should not be writable + if obj.writable() and not obj.accepting: + flags |= select.POLLOUT + if flags: + pollster.register(fd, flags) + + try: + r = pollster.poll(timeout) + except OSError as err: + if err.args[0] != EINTR: + raise + r = [] + + for fd, flags in r: + obj = map.get(fd) + if obj is None: # pragma: no cover + continue + readwrite(obj, flags) + + +poll3 = poll2 # Alias for backward compatibility + + +def loop(timeout=30.0, use_poll=False, map=None, count=None): + if map is None: # pragma: no cover + map = socket_map + + if use_poll and hasattr(select, "poll"): + poll_fun = poll2 + else: + poll_fun = poll + + if count is None: # pragma: no cover + while map: + poll_fun(timeout, map) + + else: + while map and count > 0: + poll_fun(timeout, map) + count = count - 1 + + +def compact_traceback(): + t, v, tb = sys.exc_info() + tbinfo = [] + if not tb: # pragma: no cover + raise AssertionError("traceback does not exist") + while tb: + tbinfo.append( + ( + 
tb.tb_frame.f_code.co_filename, + tb.tb_frame.f_code.co_name, + str(tb.tb_lineno), + ) + ) + tb = tb.tb_next + + # just to be safe + del tb + + file, function, line = tbinfo[-1] + info = " ".join(["[%s|%s|%s]" % x for x in tbinfo]) + return (file, function, line), t, v, info + + +class dispatcher: + + debug = False + connected = False + accepting = False + connecting = False + closing = False + addr = None + ignore_log_types = frozenset({"warning"}) + logger = utilities.logger + compact_traceback = staticmethod(compact_traceback) # for testing + + def __init__(self, sock=None, map=None): + if map is None: # pragma: no cover + self._map = socket_map + else: + self._map = map + + self._fileno = None + + if sock: + # Set to nonblocking just to make sure for cases where we + # get a socket from a blocking source. + sock.setblocking(0) + self.set_socket(sock, map) + self.connected = True + # The constructor no longer requires that the socket + # passed be connected. + try: + self.addr = sock.getpeername() + except OSError as err: + if err.args[0] in (ENOTCONN, EINVAL): + # To handle the case where we got an unconnected + # socket. + self.connected = False + else: + # The socket is broken in some unknown way, alert + # the user and remove it from the map (to prevent + # polling of broken sockets). + self.del_channel(map) + raise + else: + self.socket = None + + def __repr__(self): + status = [self.__class__.__module__ + "." + self.__class__.__qualname__] + if self.accepting and self.addr: + status.append("listening") + elif self.connected: + status.append("connected") + if self.addr is not None: + try: + status.append("%s:%d" % self.addr) + except TypeError: # pragma: no cover + status.append(repr(self.addr)) + return "<%s at %#x>" % (" ".join(status), id(self)) + + __str__ = __repr__ + + def add_channel(self, map=None): + # self.log_info('adding channel %s' % self) + if map is None: + map = self._map + map[self._fileno] = self + + def del_channel(self, map=None): + fd = self._fileno + if map is None: + map = self._map + if fd in map: + # self.log_info('closing channel %d:%s' % (fd, self)) + del map[fd] + self._fileno = None + + def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM): + self.family_and_type = family, type + sock = socket.socket(family, type) + sock.setblocking(0) + self.set_socket(sock) + + def set_socket(self, sock, map=None): + self.socket = sock + self._fileno = sock.fileno() + self.add_channel(map) + + def set_reuse_addr(self): + # try to re-use a server port if possible + try: + self.socket.setsockopt( + socket.SOL_SOCKET, + socket.SO_REUSEADDR, + self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1, + ) + except OSError: + pass + + # ================================================== + # predicates for select() + # these are used as filters for the lists of sockets + # to pass to select(). + # ================================================== + + def readable(self): + return True + + def writable(self): + return True + + # ================================================== + # socket object methods. 
+ # ================================================== + + def listen(self, num): + self.accepting = True + if os.name == "nt" and num > 5: # pragma: no cover + num = 5 + return self.socket.listen(num) + + def bind(self, addr): + self.addr = addr + return self.socket.bind(addr) + + def connect(self, address): + self.connected = False + self.connecting = True + err = self.socket.connect_ex(address) + if ( + err in (EINPROGRESS, EALREADY, EWOULDBLOCK) + or err == EINVAL + and os.name == "nt" + ): # pragma: no cover + self.addr = address + return + if err in (0, EISCONN): + self.addr = address + self.handle_connect_event() + else: + raise OSError(err, errorcode[err]) + + def accept(self): + # XXX can return either an address pair or None + try: + conn, addr = self.socket.accept() + except TypeError: + return None + except OSError as why: + if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN): + return None + else: + raise + else: + return conn, addr + + def send(self, data): + try: + result = self.socket.send(data) + return result + except OSError as why: + if why.args[0] == EWOULDBLOCK: + return 0 + elif why.args[0] in _DISCONNECTED: + self.handle_close() + return 0 + else: + raise + + def recv(self, buffer_size): + try: + data = self.socket.recv(buffer_size) + if not data: + # a closed connection is indicated by signaling + # a read condition, and having recv() return 0. + self.handle_close() + return b"" + else: + return data + except OSError as why: + # winsock sometimes raises ENOTCONN + if why.args[0] in _DISCONNECTED: + self.handle_close() + return b"" + else: + raise + + def close(self): + self.connected = False + self.accepting = False + self.connecting = False + self.del_channel() + if self.socket is not None: + try: + self.socket.close() + except OSError as why: + if why.args[0] not in (ENOTCONN, EBADF): + raise + + # log and log_info may be overridden to provide more sophisticated + # logging and warning methods. In general, log is for 'hit' logging + # and 'log_info' is for informational, warning and error logging. + + def log(self, message): + self.logger.log(logging.DEBUG, message) + + def log_info(self, message, type="info"): + severity = { + "info": logging.INFO, + "warning": logging.WARN, + "error": logging.ERROR, + } + self.logger.log(severity.get(type, logging.INFO), message) + + def handle_read_event(self): + if self.accepting: + # accepting sockets are never connected, they "spawn" new + # sockets that are connected + self.handle_accept() + elif not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_read() + else: + self.handle_read() + + def handle_connect_event(self): + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise OSError(err, _strerror(err)) + self.handle_connect() + self.connected = True + self.connecting = False + + def handle_write_event(self): + if self.accepting: + # Accepting sockets shouldn't get a write event. + # We will pretend it didn't happen. 
+ return + + if not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_write() + + def handle_expt_event(self): + # handle_expt_event() is called if there might be an error on the + # socket, or if there is OOB data + # check for the error condition first + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # we can get here when select.select() says that there is an + # exceptional condition on the socket + # since there is an error, we'll go ahead and close the socket + # like we would in a subclassed handle_read() that received no + # data + self.handle_close() + else: + self.handle_expt() + + def handle_error(self): + nil, t, v, tbinfo = self.compact_traceback() + + # sometimes a user repr method will crash. + try: + self_repr = repr(self) + except: # pragma: no cover + self_repr = "<__repr__(self) failed for object at %0x>" % id(self) + + self.log_info( + "uncaptured python exception, closing channel %s (%s:%s %s)" + % (self_repr, t, v, tbinfo), + "error", + ) + self.handle_close() + + def handle_expt(self): + self.log_info("unhandled incoming priority event", "warning") + + def handle_read(self): + self.log_info("unhandled read event", "warning") + + def handle_write(self): + self.log_info("unhandled write event", "warning") + + def handle_connect(self): + self.log_info("unhandled connect event", "warning") + + def handle_accept(self): + pair = self.accept() + if pair is not None: + self.handle_accepted(*pair) + + def handle_accepted(self, sock, addr): + sock.close() + self.log_info("unhandled accepted event", "warning") + + def handle_close(self): + self.log_info("unhandled close event", "warning") + self.close() + + +# --------------------------------------------------------------------------- +# adds simple buffered output capability, useful for simple clients. +# [for more sophisticated usage use asynchat.async_chat] +# --------------------------------------------------------------------------- + + +class dispatcher_with_send(dispatcher): + def __init__(self, sock=None, map=None): + dispatcher.__init__(self, sock, map) + self.out_buffer = b"" + + def initiate_send(self): + num_sent = 0 + num_sent = dispatcher.send(self, self.out_buffer[:65536]) + self.out_buffer = self.out_buffer[num_sent:] + + handle_write = initiate_send + + def writable(self): + return (not self.connected) or len(self.out_buffer) + + def send(self, data): + if self.debug: # pragma: no cover + self.log_info("sending %s" % repr(data)) + self.out_buffer = self.out_buffer + data + self.initiate_send() + + +def close_all(map=None, ignore_all=False): + if map is None: # pragma: no cover + map = socket_map + for x in list(map.values()): # list() FBO py3 + try: + x.close() + except OSError as x: + if x.args[0] == EBADF: + pass + elif not ignore_all: + raise + except _reraised_exceptions: + raise + except: + if not ignore_all: + raise + map.clear() + + +# Asynchronous File I/O: +# +# After a little research (reading man pages on various unixen, and +# digging through the linux kernel), I've determined that select() +# isn't meant for doing asynchronous file i/o. +# Heartening, though - reading linux/mm/filemap.c shows that linux +# supports asynchronous read-ahead. So _MOST_ of the time, the data +# will be sitting in memory for us already when we go to read it. +# +# What other OS's (besides NT) support async file i/o? [VMS?] +# +# Regardless, this is useful for pipes, and stdin/stdout... 
+ +if os.name == "posix": + + class file_wrapper: + # Here we override just enough to make a file + # look like a socket for the purposes of asyncore. + # The passed fd is automatically os.dup()'d + + def __init__(self, fd): + self.fd = os.dup(fd) + + def __del__(self): + if self.fd >= 0: + warnings.warn("unclosed file %r" % self, ResourceWarning) + self.close() + + def recv(self, *args): + return os.read(self.fd, *args) + + def send(self, *args): + return os.write(self.fd, *args) + + def getsockopt(self, level, optname, buflen=None): # pragma: no cover + if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: + return 0 + raise NotImplementedError( + "Only asyncore specific behaviour " "implemented." + ) + + read = recv + write = send + + def close(self): + if self.fd < 0: + return + fd = self.fd + self.fd = -1 + os.close(fd) + + def fileno(self): + return self.fd + + class file_dispatcher(dispatcher): + def __init__(self, fd, map=None): + dispatcher.__init__(self, None, map) + self.connected = True + try: + fd = fd.fileno() + except AttributeError: + pass + self.set_file(fd) + # set it to non-blocking mode + os.set_blocking(fd, False) + + def set_file(self, fd): + self.socket = file_wrapper(fd) + self._fileno = self.socket.fileno() + self.add_channel() diff --git a/requirements.txt b/requirements.txt index 40076f2d4..c061d866e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,4 @@ setuptools -gevent>=21 -gevent-websocket>=0.10.1 lxml>=4.3.0 numpy>=1.12.0 webrtcvad-wheels>=2.0.10
\ No newline at end of file
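
As a quick orientation for the vendored code above, here is a minimal, illustrative sketch (not part of this commit) of how the pieces fit at runtime: waitress's public serve() entry point drives the task and wasyncore machinery added in this diff, and the date helpers from libs/waitress/utilities.py are the ones used when response headers are built. The WSGI app, host, port and thread count below are placeholder values chosen for the example, not Bazarr's actual configuration.

# Illustrative only: serve a trivial WSGI app with the vendored waitress
# package and exercise the HTTP date helpers defined in utilities.py above.
import time

from waitress import serve  # public entry point around the server/task machinery
from waitress.utilities import build_http_date, parse_http_date


def app(environ, start_response):
    # WSGITask.execute() calls this callable and routes the headers through
    # start_response(); with an explicit Content-Length, waitress does not
    # need to fall back to chunked Transfer-Encoding on HTTP/1.1 responses.
    body = b"Hello from waitress\n"
    start_response("200 OK", [("Content-Type", "text/plain"),
                              ("Content-Length", str(len(body)))])
    return [body]


# Round-trip an HTTP date through the helpers shown above: build_http_date()
# formats an epoch timestamp as an RFC 1123-style date, and parse_http_date()
# reads it back to epoch seconds.
now = int(time.time())
assert parse_http_date(build_http_date(now)) == now

if __name__ == "__main__":
    # threads= sizes the task dispatcher's worker pool, whose shutdown logic
    # appears in the first hunk of task.py above; 4 is waitress's documented
    # default.
    serve(app, host="127.0.0.1", port=8080, threads=4)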