author    Louis Vézina <[email protected]>  2020-01-29 20:07:26 -0500
committer Louis Vézina <[email protected]>  2020-01-29 20:07:26 -0500
commit    83c95cc77dfd5ed18b439b1635f95bac129d0ce2 (patch)
tree      ea557727572cf3479a0af6d11434d0b486132eb8
parent    95b8aadb239bdce8d7f7a03e4ab995e56bf4e820 (diff)
download  bazarr-83c95cc77dfd5ed18b439b1635f95bac129d0ce2.tar.gz
          bazarr-83c95cc77dfd5ed18b439b1635f95bac129d0ce2.zip
WIP
-rw-r--r-- libs/bs4/__init__.py 616
-rw-r--r-- libs/bs4/builder/__init__.py 367
-rw-r--r-- libs/bs4/builder/_html5lib.py 426
-rw-r--r-- libs/bs4/builder/_htmlparser.py 350
-rw-r--r-- libs/bs4/builder/_lxml.py 296
-rw-r--r-- libs/bs4/dammit.py 850
-rw-r--r-- libs/bs4/diagnose.py 224
-rw-r--r-- libs/bs4/element.py 1579
-rw-r--r-- libs/bs4/formatter.py 99
-rw-r--r-- libs/bs4/testing.py 992
-rw-r--r-- libs/bs4/tests/__init__.py 1
-rw-r--r-- libs/bs4/tests/test_builder_registry.py 147
-rw-r--r-- libs/bs4/tests/test_docs.py 36
-rw-r--r-- libs/bs4/tests/test_html5lib.py 170
-rw-r--r-- libs/bs4/tests/test_htmlparser.py 47
-rw-r--r-- libs/bs4/tests/test_lxml.py 100
-rw-r--r-- libs/bs4/tests/test_soup.py 567
-rw-r--r-- libs/bs4/tests/test_tree.py 2205
-rw-r--r-- libs/engineio/__init__.py 25
-rw-r--r-- libs/engineio/async_drivers/__init__.py 0
-rw-r--r-- libs/engineio/async_drivers/aiohttp.py 128
-rw-r--r-- libs/engineio/async_drivers/asgi.py 214
-rw-r--r-- libs/engineio/async_drivers/eventlet.py 30
-rw-r--r-- libs/engineio/async_drivers/gevent.py 63
-rw-r--r-- libs/engineio/async_drivers/gevent_uwsgi.py 156
-rw-r--r-- libs/engineio/async_drivers/sanic.py 144
-rw-r--r-- libs/engineio/async_drivers/threading.py 17
-rw-r--r-- libs/engineio/async_drivers/tornado.py 184
-rw-r--r-- libs/engineio/asyncio_client.py 585
-rw-r--r-- libs/engineio/asyncio_server.py 472
-rw-r--r-- libs/engineio/asyncio_socket.py 236
-rw-r--r-- libs/engineio/client.py 680
-rw-r--r-- libs/engineio/exceptions.py 22
-rw-r--r-- libs/engineio/middleware.py 87
-rw-r--r-- libs/engineio/packet.py 92
-rw-r--r-- libs/engineio/payload.py 81
-rw-r--r-- libs/engineio/server.py 675
-rw-r--r-- libs/engineio/socket.py 248
-rw-r--r-- libs/engineio/static_files.py 55
-rw-r--r-- libs/flask_socketio/__init__.py 922
-rw-r--r-- libs/flask_socketio/namespace.py 47
-rw-r--r-- libs/flask_socketio/test_client.py 205
-rw-r--r-- libs/socketio/__init__.py 38
-rw-r--r-- libs/socketio/asgi.py 36
-rw-r--r-- libs/socketio/asyncio_aiopika_manager.py 105
-rw-r--r-- libs/socketio/asyncio_client.py 475
-rw-r--r-- libs/socketio/asyncio_manager.py 58
-rw-r--r-- libs/socketio/asyncio_namespace.py 204
-rw-r--r-- libs/socketio/asyncio_pubsub_manager.py 163
-rw-r--r-- libs/socketio/asyncio_redis_manager.py 107
-rw-r--r-- libs/socketio/asyncio_server.py 526
-rw-r--r-- libs/socketio/base_manager.py 178
-rw-r--r-- libs/socketio/client.py 620
-rw-r--r-- libs/socketio/exceptions.py 30
-rw-r--r-- libs/socketio/kafka_manager.py 63
-rw-r--r-- libs/socketio/kombu_manager.py 122
-rw-r--r-- libs/socketio/middleware.py 42
-rw-r--r-- libs/socketio/namespace.py 191
-rw-r--r-- libs/socketio/packet.py 179
-rw-r--r-- libs/socketio/pubsub_manager.py 154
-rw-r--r-- libs/socketio/redis_manager.py 115
-rw-r--r-- libs/socketio/server.py 730
-rw-r--r-- libs/socketio/tornado.py 11
-rw-r--r-- libs/socketio/zmq_manager.py 111
-rw-r--r-- libs/yaml/__init__.py 402
-rw-r--r-- libs/yaml/composer.py 139
-rw-r--r-- libs/yaml/constructor.py 720
-rw-r--r-- libs/yaml/cyaml.py 101
-rw-r--r-- libs/yaml/dumper.py 62
-rw-r--r-- libs/yaml/emitter.py 1137
-rw-r--r-- libs/yaml/error.py 75
-rw-r--r-- libs/yaml/events.py 86
-rw-r--r-- libs/yaml/loader.py 63
-rw-r--r-- libs/yaml/nodes.py 49
-rw-r--r-- libs/yaml/parser.py 589
-rw-r--r-- libs/yaml/reader.py 185
-rw-r--r-- libs/yaml/representer.py 389
-rw-r--r-- libs/yaml/resolver.py 227
-rw-r--r-- libs/yaml/scanner.py 1435
-rw-r--r-- libs/yaml/serializer.py 111
-rw-r--r-- libs/yaml/tokens.py 104
81 files changed, 24572 insertions(+), 0 deletions(-)
diff --git a/libs/bs4/__init__.py b/libs/bs4/__init__.py
new file mode 100644
index 000000000..95ca229c1
--- /dev/null
+++ b/libs/bs4/__init__.py
@@ -0,0 +1,616 @@
+"""Beautiful Soup
+Elixir and Tonic
+"The Screen-Scraper's Friend"
+http://www.crummy.com/software/BeautifulSoup/
+
+Beautiful Soup uses a pluggable XML or HTML parser to parse a
+(possibly invalid) document into a tree representation. Beautiful Soup
+provides methods and Pythonic idioms that make it easy to navigate,
+search, and modify the parse tree.
+
+Beautiful Soup works with Python 2.7 and up. It works better if lxml
+and/or html5lib is installed.
+
+For more than you ever wanted to know about Beautiful Soup, see the
+documentation:
+http://www.crummy.com/software/BeautifulSoup/bs4/doc/
+
+"""
+
+__author__ = "Leonard Richardson ([email protected])"
+__version__ = "4.8.0"
+__copyright__ = "Copyright (c) 2004-2019 Leonard Richardson"
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+__all__ = ['BeautifulSoup']
+
+import os
+import re
+import sys
+import traceback
+import warnings
+
+from .builder import builder_registry, ParserRejectedMarkup
+from .dammit import UnicodeDammit
+from .element import (
+ CData,
+ Comment,
+ DEFAULT_OUTPUT_ENCODING,
+ Declaration,
+ Doctype,
+ NavigableString,
+ PageElement,
+ ProcessingInstruction,
+ ResultSet,
+ SoupStrainer,
+ Tag,
+ )
+
+# The very first thing we do is give a useful error if someone is
+# running this code under Python 3 without converting it.
+'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'!='You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
+
+class BeautifulSoup(Tag):
+ """
+ This class defines the basic interface called by the tree builders.
+
+ These methods will be called by the parser:
+ reset()
+ feed(markup)
+
+ The tree builder may call these methods from its feed() implementation:
+ handle_starttag(name, attrs) # See note about return value
+ handle_endtag(name)
+ handle_data(data) # Appends to the current data node
+ endData(containerClass=NavigableString) # Ends the current data node
+
+ No matter how complicated the underlying parser is, you should be
+ able to build a tree using 'start tag' events, 'end tag' events,
+ 'data' events, and "done with data" events.
+
+ If you encounter an empty-element tag (aka a self-closing tag,
+ like HTML's <br> tag), call handle_starttag and then
+ handle_endtag.
+ """
+ ROOT_TAG_NAME = '[document]'
+
+ # If the end-user gives no indication which tree builder they
+ # want, look for one with these features.
+ DEFAULT_BUILDER_FEATURES = ['html', 'fast']
+
+ ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'
+
+ NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nThe code that caused this warning is on line %(line_number)s of the file %(filename)s. To get rid of this warning, pass the additional argument 'features=\"%(parser)s\"' to the BeautifulSoup constructor.\n"
+
+ def __init__(self, markup="", features=None, builder=None,
+ parse_only=None, from_encoding=None, exclude_encodings=None,
+ **kwargs):
+ """Constructor.
+
+ :param markup: A string or a file-like object representing
+ markup to be parsed.
+
+ :param features: Desirable features of the parser to be used. This
+ may be the name of a specific parser ("lxml", "lxml-xml",
+ "html.parser", or "html5lib") or it may be the type of markup
+ to be used ("html", "html5", "xml"). It's recommended that you
+ name a specific parser, so that Beautiful Soup gives you the
+ same results across platforms and virtual environments.
+
+ :param builder: A TreeBuilder subclass to instantiate (or
+ instance to use) instead of looking one up based on
+ `features`. You only need to use this if you've implemented a
+ custom TreeBuilder.
+
+ :param parse_only: A SoupStrainer. Only parts of the document
+ matching the SoupStrainer will be considered. This is useful
+ when parsing part of a document that would otherwise be too
+ large to fit into memory.
+
+ :param from_encoding: A string indicating the encoding of the
+ document to be parsed. Pass this in if Beautiful Soup is
+ guessing wrongly about the document's encoding.
+
+ :param exclude_encodings: A list of strings indicating
+ encodings known to be wrong. Pass this in if you don't know
+ the document's encoding but you know Beautiful Soup's guess is
+ wrong.
+
+ :param kwargs: For backwards compatibility purposes, the
+ constructor accepts certain keyword arguments used in
+ Beautiful Soup 3. None of these arguments do anything in
+ Beautiful Soup 4; they will result in a warning and then be ignored.
+
+ Apart from this, any keyword arguments passed into the BeautifulSoup
+ constructor are propagated to the TreeBuilder constructor. This
+ makes it possible to configure a TreeBuilder beyond saying
+ which one to use.
+
+ """
+
+ if 'convertEntities' in kwargs:
+ del kwargs['convertEntities']
+ warnings.warn(
+ "BS4 does not respect the convertEntities argument to the "
+ "BeautifulSoup constructor. Entities are always converted "
+ "to Unicode characters.")
+
+ if 'markupMassage' in kwargs:
+ del kwargs['markupMassage']
+ warnings.warn(
+ "BS4 does not respect the markupMassage argument to the "
+ "BeautifulSoup constructor. The tree builder is responsible "
+ "for any necessary markup massage.")
+
+ if 'smartQuotesTo' in kwargs:
+ del kwargs['smartQuotesTo']
+ warnings.warn(
+ "BS4 does not respect the smartQuotesTo argument to the "
+ "BeautifulSoup constructor. Smart quotes are always converted "
+ "to Unicode characters.")
+
+ if 'selfClosingTags' in kwargs:
+ del kwargs['selfClosingTags']
+ warnings.warn(
+ "BS4 does not respect the selfClosingTags argument to the "
+ "BeautifulSoup constructor. The tree builder is responsible "
+ "for understanding self-closing tags.")
+
+ if 'isHTML' in kwargs:
+ del kwargs['isHTML']
+ warnings.warn(
+ "BS4 does not respect the isHTML argument to the "
+ "BeautifulSoup constructor. Suggest you use "
+ "features='lxml' for HTML and features='lxml-xml' for "
+ "XML.")
+
+ def deprecated_argument(old_name, new_name):
+ if old_name in kwargs:
+ warnings.warn(
+ 'The "%s" argument to the BeautifulSoup constructor '
+ 'has been renamed to "%s."' % (old_name, new_name))
+ value = kwargs[old_name]
+ del kwargs[old_name]
+ return value
+ return None
+
+ parse_only = parse_only or deprecated_argument(
+ "parseOnlyThese", "parse_only")
+
+ from_encoding = from_encoding or deprecated_argument(
+ "fromEncoding", "from_encoding")
+
+ if from_encoding and isinstance(markup, str):
+ warnings.warn("You provided Unicode markup but also provided a value for from_encoding. Your from_encoding will be ignored.")
+ from_encoding = None
+
+ # We need this information to track whether or not the builder
+ # was specified well enough that we can omit the 'you need to
+ # specify a parser' warning.
+ original_builder = builder
+ original_features = features
+
+ if isinstance(builder, type):
+ # A builder class was passed in; it needs to be instantiated.
+ builder_class = builder
+ builder = None
+ elif builder is None:
+ if isinstance(features, str):
+ features = [features]
+ if features is None or len(features) == 0:
+ features = self.DEFAULT_BUILDER_FEATURES
+ builder_class = builder_registry.lookup(*features)
+ if builder_class is None:
+ raise FeatureNotFound(
+ "Couldn't find a tree builder with the features you "
+ "requested: %s. Do you need to install a parser library?"
+ % ",".join(features))
+
+ # At this point either we have a TreeBuilder instance in
+ # builder, or we have a builder_class that we can instantiate
+ # with the remaining **kwargs.
+ if builder is None:
+ builder = builder_class(**kwargs)
+ if not original_builder and not (
+ original_features == builder.NAME or
+ original_features in builder.ALTERNATE_NAMES
+ ):
+ if builder.is_xml:
+ markup_type = "XML"
+ else:
+ markup_type = "HTML"
+
+ # This code adapted from warnings.py so that we get the same line
+ # of code as our warnings.warn() call gets, even if the answer is wrong
+ # (as it may be in a multithreading situation).
+ caller = None
+ try:
+ caller = sys._getframe(1)
+ except ValueError:
+ pass
+ if caller:
+ globals = caller.f_globals
+ line_number = caller.f_lineno
+ else:
+ globals = sys.__dict__
+ line_number = 1
+ filename = globals.get('__file__')
+ if filename:
+ fnl = filename.lower()
+ if fnl.endswith((".pyc", ".pyo")):
+ filename = filename[:-1]
+ if filename:
+ # If there is no filename at all, the user is most likely in a REPL,
+ # and the warning is not necessary.
+ values = dict(
+ filename=filename,
+ line_number=line_number,
+ parser=builder.NAME,
+ markup_type=markup_type
+ )
+ warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % values, stacklevel=2)
+ else:
+ if kwargs:
+ warnings.warn("Keyword arguments to the BeautifulSoup constructor will be ignored. These would normally be passed into the TreeBuilder constructor, but a TreeBuilder instance was passed in as `builder`.")
+
+ self.builder = builder
+ self.is_xml = builder.is_xml
+ self.known_xml = self.is_xml
+ self._namespaces = dict()
+ self.parse_only = parse_only
+
+ self.builder.initialize_soup(self)
+
+ if hasattr(markup, 'read'): # It's a file-type object.
+ markup = markup.read()
+ elif len(markup) <= 256 and (
+ (isinstance(markup, bytes) and not b'<' in markup)
+ or (isinstance(markup, str) and not '<' in markup)
+ ):
+ # Print out warnings for a couple beginner problems
+ # involving passing non-markup to Beautiful Soup.
+ # Beautiful Soup will still parse the input as markup,
+ # just in case that's what the user really wants.
+ if (isinstance(markup, str)
+ and not os.path.supports_unicode_filenames):
+ possible_filename = markup.encode("utf8")
+ else:
+ possible_filename = markup
+ is_file = False
+ try:
+ is_file = os.path.exists(possible_filename)
+ except Exception as e:
+ # This is almost certainly a problem involving
+ # characters not valid in filenames on this
+ # system. Just let it go.
+ pass
+ if is_file:
+ if isinstance(markup, str):
+ markup = markup.encode("utf8")
+ warnings.warn(
+ '"%s" looks like a filename, not markup. You should'
+ ' probably open this file and pass the filehandle into'
+ ' Beautiful Soup.' % markup)
+ self._check_markup_is_url(markup)
+
+ for (self.markup, self.original_encoding, self.declared_html_encoding,
+ self.contains_replacement_characters) in (
+ self.builder.prepare_markup(
+ markup, from_encoding, exclude_encodings=exclude_encodings)):
+ self.reset()
+ try:
+ self._feed()
+ break
+ except ParserRejectedMarkup:
+ pass
+
+ # Clear out the markup and remove the builder's circular
+ # reference to this object.
+ self.markup = None
+ self.builder.soup = None
+
+ def __copy__(self):
+ copy = type(self)(
+ self.encode('utf-8'), builder=self.builder, from_encoding='utf-8'
+ )
+
+ # Although we encoded the tree to UTF-8, that may not have
+ # been the encoding of the original markup. Set the copy's
+ # .original_encoding to reflect the original object's
+ # .original_encoding.
+ copy.original_encoding = self.original_encoding
+ return copy
+
+ def __getstate__(self):
+ # Frequently a tree builder can't be pickled.
+ d = dict(self.__dict__)
+ if 'builder' in d and not self.builder.picklable:
+ d['builder'] = None
+ return d
+
+ @staticmethod
+ def _check_markup_is_url(markup):
+ """
+ Check if markup looks like it's actually a URL and raise a warning
+ if so. Markup can be a str or bytes object.
+ """
+ if isinstance(markup, bytes):
+ space = b' '
+ cant_start_with = (b"http:", b"https:")
+ elif isinstance(markup, str):
+ space = ' '
+ cant_start_with = ("http:", "https:")
+ else:
+ return
+
+ if any(markup.startswith(prefix) for prefix in cant_start_with):
+ if not space in markup:
+ if isinstance(markup, bytes):
+ decoded_markup = markup.decode('utf-8', 'replace')
+ else:
+ decoded_markup = markup
+ warnings.warn(
+ '"%s" looks like a URL. Beautiful Soup is not an'
+ ' HTTP client. You should probably use an HTTP client like'
+ ' requests to get the document behind the URL, and feed'
+ ' that document to Beautiful Soup.' % decoded_markup
+ )
+
+ def _feed(self):
+ # Convert the document to Unicode.
+ self.builder.reset()
+
+ self.builder.feed(self.markup)
+ # Close out any unfinished strings and close all the open tags.
+ self.endData()
+ while self.currentTag.name != self.ROOT_TAG_NAME:
+ self.popTag()
+
+ def reset(self):
+ Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME)
+ self.hidden = 1
+ self.builder.reset()
+ self.current_data = []
+ self.currentTag = None
+ self.tagStack = []
+ self.preserve_whitespace_tag_stack = []
+ self.pushTag(self)
+
+ def new_tag(self, name, namespace=None, nsprefix=None, attrs={}, **kwattrs):
+ """Create a new tag associated with this soup."""
+ kwattrs.update(attrs)
+ return Tag(None, self.builder, name, namespace, nsprefix, kwattrs)
+
+ def new_string(self, s, subclass=NavigableString):
+ """Create a new NavigableString associated with this soup."""
+ return subclass(s)
+
+ def insert_before(self, successor):
+ raise NotImplementedError("BeautifulSoup objects don't support insert_before().")
+
+ def insert_after(self, successor):
+ raise NotImplementedError("BeautifulSoup objects don't support insert_after().")
+
+ def popTag(self):
+ tag = self.tagStack.pop()
+ if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]:
+ self.preserve_whitespace_tag_stack.pop()
+ #print "Pop", tag.name
+ if self.tagStack:
+ self.currentTag = self.tagStack[-1]
+ return self.currentTag
+
+ def pushTag(self, tag):
+ #print "Push", tag.name
+ if self.currentTag is not None:
+ self.currentTag.contents.append(tag)
+ self.tagStack.append(tag)
+ self.currentTag = self.tagStack[-1]
+ if tag.name in self.builder.preserve_whitespace_tags:
+ self.preserve_whitespace_tag_stack.append(tag)
+
+ def endData(self, containerClass=NavigableString):
+ if self.current_data:
+ current_data = ''.join(self.current_data)
+ # If whitespace is not preserved, and this string contains
+ # nothing but ASCII spaces, replace it with a single space
+ # or newline.
+ if not self.preserve_whitespace_tag_stack:
+ strippable = True
+ for i in current_data:
+ if i not in self.ASCII_SPACES:
+ strippable = False
+ break
+ if strippable:
+ if '\n' in current_data:
+ current_data = '\n'
+ else:
+ current_data = ' '
+
+ # Reset the data collector.
+ self.current_data = []
+
+ # Should we add this string to the tree at all?
+ if self.parse_only and len(self.tagStack) <= 1 and \
+ (not self.parse_only.text or \
+ not self.parse_only.search(current_data)):
+ return
+
+ o = containerClass(current_data)
+ self.object_was_parsed(o)
+
+ def object_was_parsed(self, o, parent=None, most_recent_element=None):
+ """Add an object to the parse tree."""
+ if parent is None:
+ parent = self.currentTag
+ if most_recent_element is not None:
+ previous_element = most_recent_element
+ else:
+ previous_element = self._most_recent_element
+
+ next_element = previous_sibling = next_sibling = None
+ if isinstance(o, Tag):
+ next_element = o.next_element
+ next_sibling = o.next_sibling
+ previous_sibling = o.previous_sibling
+ if previous_element is None:
+ previous_element = o.previous_element
+
+ fix = parent.next_element is not None
+
+ o.setup(parent, previous_element, next_element, previous_sibling, next_sibling)
+
+ self._most_recent_element = o
+ parent.contents.append(o)
+
+ # Check if we are inserting into an already parsed node.
+ if fix:
+ self._linkage_fixer(parent)
+
+ def _linkage_fixer(self, el):
+ """Make sure linkage of this fragment is sound."""
+
+ first = el.contents[0]
+ child = el.contents[-1]
+ descendant = child
+
+ if child is first and el.parent is not None:
+ # Parent should be linked to first child
+ el.next_element = child
+ # We are no longer linked to whatever this element is
+ prev_el = child.previous_element
+ if prev_el is not None and prev_el is not el:
+ prev_el.next_element = None
+ # First child should be linked to the parent, and no previous siblings.
+ child.previous_element = el
+ child.previous_sibling = None
+
+ # We have no sibling as we've been appended as the last.
+ child.next_sibling = None
+
+ # This index is a tag, dig deeper for a "last descendant"
+ if isinstance(child, Tag) and child.contents:
+ descendant = child._last_descendant(False)
+
+ # As the final step, link last descendant. It should be linked
+ # to the parent's next sibling (if found), else walk up the chain
+ # and find a parent with a sibling. It should have no next sibling.
+ descendant.next_element = None
+ descendant.next_sibling = None
+ target = el
+ while True:
+ if target is None:
+ break
+ elif target.next_sibling is not None:
+ descendant.next_element = target.next_sibling
+ target.next_sibling.previous_element = child
+ break
+ target = target.parent
+
+ def _popToTag(self, name, nsprefix=None, inclusivePop=True):
+ """Pops the tag stack up to and including the most recent
+ instance of the given tag. If inclusivePop is false, pops the tag
+ stack up to but *not* including the most recent instance of
+ the given tag."""
+ #print "Popping to %s" % name
+ if name == self.ROOT_TAG_NAME:
+ # The BeautifulSoup object itself can never be popped.
+ return
+
+ most_recently_popped = None
+
+ stack_size = len(self.tagStack)
+ for i in range(stack_size - 1, 0, -1):
+ t = self.tagStack[i]
+ if (name == t.name and nsprefix == t.prefix):
+ if inclusivePop:
+ most_recently_popped = self.popTag()
+ break
+ most_recently_popped = self.popTag()
+
+ return most_recently_popped
+
+ def handle_starttag(self, name, namespace, nsprefix, attrs):
+ """Push a start tag on to the stack.
+
+ If this method returns None, the tag was rejected by the
+ SoupStrainer. You should proceed as if the tag had not occurred
+ in the document. For instance, if this was a self-closing tag,
+ don't call handle_endtag.
+ """
+
+ # print "Start tag %s: %s" % (name, attrs)
+ self.endData()
+
+ if (self.parse_only and len(self.tagStack) <= 1
+ and (self.parse_only.text
+ or not self.parse_only.search_tag(name, attrs))):
+ return None
+
+ tag = Tag(self, self.builder, name, namespace, nsprefix, attrs,
+ self.currentTag, self._most_recent_element)
+ if tag is None:
+ return tag
+ if self._most_recent_element is not None:
+ self._most_recent_element.next_element = tag
+ self._most_recent_element = tag
+ self.pushTag(tag)
+ return tag
+
+ def handle_endtag(self, name, nsprefix=None):
+ #print "End tag: " + name
+ self.endData()
+ self._popToTag(name, nsprefix)
+
+ def handle_data(self, data):
+ self.current_data.append(data)
+
+ def decode(self, pretty_print=False,
+ eventual_encoding=DEFAULT_OUTPUT_ENCODING,
+ formatter="minimal"):
+ """Returns a string or Unicode representation of this document.
+ To get Unicode, pass None for encoding."""
+
+ if self.is_xml:
+ # Print the XML declaration
+ encoding_part = ''
+ if eventual_encoding != None:
+ encoding_part = ' encoding="%s"' % eventual_encoding
+ prefix = '<?xml version="1.0"%s?>\n' % encoding_part
+ else:
+ prefix = ''
+ if not pretty_print:
+ indent_level = None
+ else:
+ indent_level = 0
+ return prefix + super(BeautifulSoup, self).decode(
+ indent_level, eventual_encoding, formatter)
+
+# Alias to make it easier to type import: 'from bs4 import _soup'
+_s = BeautifulSoup
+_soup = BeautifulSoup
+
+class BeautifulStoneSoup(BeautifulSoup):
+ """Deprecated interface to an XML parser."""
+
+ def __init__(self, *args, **kwargs):
+ kwargs['features'] = 'xml'
+ warnings.warn(
+ 'The BeautifulStoneSoup class is deprecated. Instead of using '
+ 'it, pass features="xml" into the BeautifulSoup constructor.')
+ super(BeautifulStoneSoup, self).__init__(*args, **kwargs)
+
+
+class StopParsing(Exception):
+ pass
+
+class FeatureNotFound(ValueError):
+ pass
+
+
+#By default, act as an HTML pretty-printer.
+if __name__ == '__main__':
+ import sys
+ soup = BeautifulSoup(sys.stdin)
+ print(soup.prettify())
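
A minimal usage sketch for the module above (illustrative, not part of the commit: the markup string and variable names are invented here). Naming a parser explicitly skips the DEFAULT_BUILDER_FEATURES lookup and the NO_PARSER_SPECIFIED_WARNING path in __init__:

    from bs4 import BeautifulSoup

    # "html.parser" is the stdlib-backed builder registered in bs4.builder.
    soup = BeautifulSoup("<p class='a b'>Hello</p>", "html.parser")
    print(soup.p['class'])    # ['a', 'b'] -- 'class' is a multi-valued attribute
    print(soup.prettify())
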
diff --git a/libs/bs4/builder/__init__.py b/libs/bs4/builder/__init__.py
new file mode 100644
index 000000000..cc497cf0b
--- /dev/null
+++ b/libs/bs4/builder/__init__.py
@@ -0,0 +1,367 @@
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+from collections import defaultdict
+import itertools
+import sys
+from bs4.element import (
+ CharsetMetaAttributeValue,
+ ContentMetaAttributeValue,
+ nonwhitespace_re
+ )
+
+__all__ = [
+ 'HTMLTreeBuilder',
+ 'SAXTreeBuilder',
+ 'TreeBuilder',
+ 'TreeBuilderRegistry',
+ ]
+
+# Some useful features for a TreeBuilder to have.
+FAST = 'fast'
+PERMISSIVE = 'permissive'
+STRICT = 'strict'
+XML = 'xml'
+HTML = 'html'
+HTML_5 = 'html5'
+
+
+class TreeBuilderRegistry(object):
+
+ def __init__(self):
+ self.builders_for_feature = defaultdict(list)
+ self.builders = []
+
+ def register(self, treebuilder_class):
+ """Register a treebuilder based on its advertised features."""
+ for feature in treebuilder_class.features:
+ self.builders_for_feature[feature].insert(0, treebuilder_class)
+ self.builders.insert(0, treebuilder_class)
+
+ def lookup(self, *features):
+ if len(self.builders) == 0:
+ # There are no builders at all.
+ return None
+
+ if len(features) == 0:
+ # They didn't ask for any features. Give them the most
+ # recently registered builder.
+ return self.builders[0]
+
+ # Go down the list of features in order, and eliminate any builders
+ # that don't match every feature.
+ features = list(features)
+ features.reverse()
+ candidates = None
+ candidate_set = None
+ while len(features) > 0:
+ feature = features.pop()
+ we_have_the_feature = self.builders_for_feature.get(feature, [])
+ if len(we_have_the_feature) > 0:
+ if candidates is None:
+ candidates = we_have_the_feature
+ candidate_set = set(candidates)
+ else:
+ # Eliminate any candidates that don't have this feature.
+ candidate_set = candidate_set.intersection(
+ set(we_have_the_feature))
+
+ # The only valid candidates are the ones in candidate_set.
+ # Go through the original list of candidates and pick the first one
+ # that's in candidate_set.
+ if candidate_set is None:
+ return None
+ for candidate in candidates:
+ if candidate in candidate_set:
+ return candidate
+ return None
+
+# The BeautifulSoup class will take feature lists from developers and use them
+# to look up builders in this registry.
+builder_registry = TreeBuilderRegistry()
+
+class TreeBuilder(object):
+ """Turn a document into a Beautiful Soup object tree."""
+
+ NAME = "[Unknown tree builder]"
+ ALTERNATE_NAMES = []
+ features = []
+
+ is_xml = False
+ picklable = False
+ empty_element_tags = None # A tag will be considered an empty-element
+ # tag when and only when it has no contents.
+
+ # A value for these tag/attribute combinations is a space- or
+ # comma-separated list of CDATA, rather than a single CDATA.
+ DEFAULT_CDATA_LIST_ATTRIBUTES = {}
+
+ DEFAULT_PRESERVE_WHITESPACE_TAGS = set()
+
+ USE_DEFAULT = object()
+
+ def __init__(self, multi_valued_attributes=USE_DEFAULT, preserve_whitespace_tags=USE_DEFAULT):
+ """Constructor.
+
+ :param multi_valued_attributes: If this is set to None, the
+ TreeBuilder will not turn any values for attributes like
+ 'class' into lists. Setting this to a dictionary will
+ customize this behavior; look at DEFAULT_CDATA_LIST_ATTRIBUTES
+ for an example.
+
+ Internally, these are called "CDATA list attributes", but that
+ probably doesn't make sense to an end-user, so the argument name
+ is `multi_valued_attributes`.
+
+ :param preserve_whitespace_tags: A set of tags to treat the way
+ <pre> tags are treated in HTML: whitespace inside these tags is
+ preserved rather than collapsed.
+ """
+ self.soup = None
+ if multi_valued_attributes is self.USE_DEFAULT:
+ multi_valued_attributes = self.DEFAULT_CDATA_LIST_ATTRIBUTES
+ self.cdata_list_attributes = multi_valued_attributes
+ if preserve_whitespace_tags is self.USE_DEFAULT:
+ preserve_whitespace_tags = self.DEFAULT_PRESERVE_WHITESPACE_TAGS
+ self.preserve_whitespace_tags = preserve_whitespace_tags
+
+ def initialize_soup(self, soup):
+ """The BeautifulSoup object has been initialized and is now
+ being associated with the TreeBuilder.
+ """
+ self.soup = soup
+
+ def reset(self):
+ pass
+
+ def can_be_empty_element(self, tag_name):
+ """Might a tag with this name be an empty-element tag?
+
+ The final markup may or may not actually present this tag as
+ self-closing.
+
+ For instance: an HTMLBuilder does not consider a <p> tag to be
+ an empty-element tag (it's not in
+ HTMLBuilder.empty_element_tags). This means an empty <p> tag
+ will be presented as "<p></p>", not "<p />".
+
+ The default implementation has no opinion about which tags are
+ empty-element tags, so a tag will be presented as an
+ empty-element tag if and only if it has no contents.
+ "<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will
+ be left alone.
+ """
+ if self.empty_element_tags is None:
+ return True
+ return tag_name in self.empty_element_tags
+
+ def feed(self, markup):
+ raise NotImplementedError()
+
+ def prepare_markup(self, markup, user_specified_encoding=None,
+ document_declared_encoding=None):
+ return markup, None, None, False
+
+ def test_fragment_to_document(self, fragment):
+ """Wrap an HTML fragment to make it look like a document.
+
+ Different parsers do this differently. For instance, lxml
+ introduces an empty <head> tag, and html5lib
+ doesn't. Abstracting this away lets us write simple tests
+ which run HTML fragments through the parser and compare the
+ results against other HTML fragments.
+
+ This method should not be used outside of tests.
+ """
+ return fragment
+
+ def set_up_substitutions(self, tag):
+ return False
+
+ def _replace_cdata_list_attribute_values(self, tag_name, attrs):
+ """Replaces class="foo bar" with class=["foo", "bar"]
+
+ Modifies its input in place.
+ """
+ if not attrs:
+ return attrs
+ if self.cdata_list_attributes:
+ universal = self.cdata_list_attributes.get('*', [])
+ tag_specific = self.cdata_list_attributes.get(
+ tag_name.lower(), None)
+ for attr in list(attrs.keys()):
+ if attr in universal or (tag_specific and attr in tag_specific):
+ # We have a "class"-type attribute whose string
+ # value is a whitespace-separated list of
+ # values. Split it into a list.
+ value = attrs[attr]
+ if isinstance(value, str):
+ values = nonwhitespace_re.findall(value)
+ else:
+ # html5lib sometimes calls setAttributes twice
+ # for the same tag when rearranging the parse
+ # tree. On the second call the attribute value
+ # here is already a list. If this happens,
+ # leave the value alone rather than trying to
+ # split it again.
+ values = value
+ attrs[attr] = values
+ return attrs
+
+class SAXTreeBuilder(TreeBuilder):
+ """A Beautiful Soup treebuilder that listens for SAX events."""
+
+ def feed(self, markup):
+ raise NotImplementedError()
+
+ def close(self):
+ pass
+
+ def startElement(self, name, attrs):
+ attrs = dict((key[1], value) for key, value in list(attrs.items()))
+ #print "Start %s, %r" % (name, attrs)
+ self.soup.handle_starttag(name, attrs)
+
+ def endElement(self, name):
+ #print "End %s" % name
+ self.soup.handle_endtag(name)
+
+ def startElementNS(self, nsTuple, nodeName, attrs):
+ # Throw away (ns, nodeName) for now.
+ self.startElement(nodeName, attrs)
+
+ def endElementNS(self, nsTuple, nodeName):
+ # Throw away (ns, nodeName) for now.
+ self.endElement(nodeName)
+ #handler.endElementNS((ns, node.nodeName), node.nodeName)
+
+ def startPrefixMapping(self, prefix, nodeValue):
+ # Ignore the prefix for now.
+ pass
+
+ def endPrefixMapping(self, prefix):
+ # Ignore the prefix for now.
+ # handler.endPrefixMapping(prefix)
+ pass
+
+ def characters(self, content):
+ self.soup.handle_data(content)
+
+ def startDocument(self):
+ pass
+
+ def endDocument(self):
+ pass
+
+
+class HTMLTreeBuilder(TreeBuilder):
+ """This TreeBuilder knows facts about HTML.
+
+ Such as which tags are empty-element tags.
+ """
+
+ empty_element_tags = set([
+ # These are from HTML5.
+ 'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr',
+
+ # These are from earlier versions of HTML and are removed in HTML5.
+ 'basefont', 'bgsound', 'command', 'frame', 'image', 'isindex', 'nextid', 'spacer'
+ ])
+
+ # The HTML standard defines these as block-level elements. Beautiful
+ # Soup does not treat these elements differently from other elements,
+ # but it may do so eventually, and this information is available if
+ # you need to use it.
+ block_elements = set(["address", "article", "aside", "blockquote", "canvas", "dd", "div", "dl", "dt", "fieldset", "figcaption", "figure", "footer", "form", "h1", "h2", "h3", "h4", "h5", "h6", "header", "hr", "li", "main", "nav", "noscript", "ol", "output", "p", "pre", "section", "table", "tfoot", "ul", "video"])
+
+ # The HTML standard defines these attributes as containing a
+ # space-separated list of values, not a single value. That is,
+ # class="foo bar" means that the 'class' attribute has two values,
+ # 'foo' and 'bar', not the single value 'foo bar'. When we
+ # encounter one of these attributes, we will parse its value into
+ # a list of values if possible. Upon output, the list will be
+ # converted back into a string.
+ DEFAULT_CDATA_LIST_ATTRIBUTES = {
+ "*" : ['class', 'accesskey', 'dropzone'],
+ "a" : ['rel', 'rev'],
+ "link" : ['rel', 'rev'],
+ "td" : ["headers"],
+ "th" : ["headers"],
+ "td" : ["headers"],
+ "form" : ["accept-charset"],
+ "object" : ["archive"],
+
+ # These are HTML5 specific, as are *.accesskey and *.dropzone above.
+ "area" : ["rel"],
+ "icon" : ["sizes"],
+ "iframe" : ["sandbox"],
+ "output" : ["for"],
+ }
+
+ DEFAULT_PRESERVE_WHITESPACE_TAGS = set(['pre', 'textarea'])
+
+ def set_up_substitutions(self, tag):
+ # We are only interested in <meta> tags
+ if tag.name != 'meta':
+ return False
+
+ http_equiv = tag.get('http-equiv')
+ content = tag.get('content')
+ charset = tag.get('charset')
+
+ # We are interested in <meta> tags that say what encoding the
+ # document was originally in. This means HTML 5-style <meta>
+ # tags that provide the "charset" attribute. It also means
+ # HTML 4-style <meta> tags that provide the "content"
+ # attribute and have "http-equiv" set to "content-type".
+ #
+ # In both cases we will replace the value of the appropriate
+ # attribute with a standin object that can take on any
+ # encoding.
+ meta_encoding = None
+ if charset is not None:
+ # HTML 5 style:
+ # <meta charset="utf8">
+ meta_encoding = charset
+ tag['charset'] = CharsetMetaAttributeValue(charset)
+
+ elif (content is not None and http_equiv is not None
+ and http_equiv.lower() == 'content-type'):
+ # HTML 4 style:
+ # <meta http-equiv="content-type" content="text/html; charset=utf8">
+ tag['content'] = ContentMetaAttributeValue(content)
+
+ return (meta_encoding is not None)
+
+def register_treebuilders_from(module):
+ """Copy TreeBuilders from the given module into this module."""
+ # I'm fairly sure this is not the best way to do this.
+ this_module = sys.modules['bs4.builder']
+ for name in module.__all__:
+ obj = getattr(module, name)
+
+ if issubclass(obj, TreeBuilder):
+ setattr(this_module, name, obj)
+ this_module.__all__.append(name)
+ # Register the builder while we're at it.
+ this_module.builder_registry.register(obj)
+
+class ParserRejectedMarkup(Exception):
+ pass
+
+# Builders are registered in reverse order of priority, so that custom
+# builder registrations will take precedence. In general, we want lxml
+# to take precedence over html5lib, because it's faster. And we only
+ # want to use HTMLParser as a last resort.
+from . import _htmlparser
+register_treebuilders_from(_htmlparser)
+try:
+ from . import _html5lib
+ register_treebuilders_from(_html5lib)
+except ImportError:
+ # They don't have html5lib installed.
+ pass
+try:
+ from . import _lxml
+ register_treebuilders_from(_lxml)
+except ImportError:
+ # They don't have lxml installed.
+ pass
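
A small sketch of the feature lookup implemented by TreeBuilderRegistry above; this is the same path that DEFAULT_BUILDER_FEATURES (['html', 'fast']) takes when no parser is named. Which builder wins depends on which optional parsers are installed:

    from bs4.builder import builder_registry

    # Each requested feature narrows the candidate set; among builders that
    # advertise every feature, the most recently registered one wins. A
    # feature that no registered builder advertises is silently skipped.
    builder_class = builder_registry.lookup('html', 'fast')
    print(builder_class.NAME if builder_class else 'no matching builder')

    # With no features at all, lookup() returns the most recently
    # registered builder.
    print(builder_registry.lookup().NAME)
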
diff --git a/libs/bs4/builder/_html5lib.py b/libs/bs4/builder/_html5lib.py
new file mode 100644
index 000000000..090bb61a8
--- /dev/null
+++ b/libs/bs4/builder/_html5lib.py
@@ -0,0 +1,426 @@
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+__all__ = [
+ 'HTML5TreeBuilder',
+ ]
+
+import warnings
+import re
+from bs4.builder import (
+ PERMISSIVE,
+ HTML,
+ HTML_5,
+ HTMLTreeBuilder,
+ )
+from bs4.element import (
+ NamespacedAttribute,
+ nonwhitespace_re,
+)
+import html5lib
+from html5lib.constants import (
+ namespaces,
+ prefixes,
+ )
+from bs4.element import (
+ Comment,
+ Doctype,
+ NavigableString,
+ Tag,
+ )
+
+try:
+ # Pre-0.99999999
+ from html5lib.treebuilders import _base as treebuilder_base
+ new_html5lib = False
+except ImportError as e:
+ # 0.99999999 and up
+ from html5lib.treebuilders import base as treebuilder_base
+ new_html5lib = True
+
+class HTML5TreeBuilder(HTMLTreeBuilder):
+ """Use html5lib to build a tree."""
+
+ NAME = "html5lib"
+
+ features = [NAME, PERMISSIVE, HTML_5, HTML]
+
+ def prepare_markup(self, markup, user_specified_encoding,
+ document_declared_encoding=None, exclude_encodings=None):
+ # Store the user-specified encoding for use later on.
+ self.user_specified_encoding = user_specified_encoding
+
+ # document_declared_encoding and exclude_encodings aren't used
+ # ATM because the html5lib TreeBuilder doesn't use
+ # UnicodeDammit.
+ if exclude_encodings:
+ warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.")
+ yield (markup, None, None, False)
+
+ # These methods are defined by Beautiful Soup.
+ def feed(self, markup):
+ if self.soup.parse_only is not None:
+ warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
+ parser = html5lib.HTMLParser(tree=self.create_treebuilder)
+
+ extra_kwargs = dict()
+ if not isinstance(markup, str):
+ if new_html5lib:
+ extra_kwargs['override_encoding'] = self.user_specified_encoding
+ else:
+ extra_kwargs['encoding'] = self.user_specified_encoding
+ doc = parser.parse(markup, **extra_kwargs)
+
+ # Set the character encoding detected by the tokenizer.
+ if isinstance(markup, str):
+ # We need to special-case this because html5lib sets
+ # charEncoding to UTF-8 if it gets Unicode input.
+ doc.original_encoding = None
+ else:
+ original_encoding = parser.tokenizer.stream.charEncoding[0]
+ if not isinstance(original_encoding, str):
+ # In 0.99999999 and up, the encoding is an html5lib
+ # Encoding object. We want to use a string for compatibility
+ # with other tree builders.
+ original_encoding = original_encoding.name
+ doc.original_encoding = original_encoding
+
+ def create_treebuilder(self, namespaceHTMLElements):
+ self.underlying_builder = TreeBuilderForHtml5lib(
+ namespaceHTMLElements, self.soup)
+ return self.underlying_builder
+
+ def test_fragment_to_document(self, fragment):
+ """See `TreeBuilder`."""
+ return '<html><head></head><body>%s</body></html>' % fragment
+
+
+class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
+
+ def __init__(self, namespaceHTMLElements, soup=None):
+ if soup:
+ self.soup = soup
+ else:
+ from bs4 import BeautifulSoup
+ self.soup = BeautifulSoup("", "html.parser")
+ super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)
+
+ def documentClass(self):
+ self.soup.reset()
+ return Element(self.soup, self.soup, None)
+
+ def insertDoctype(self, token):
+ name = token["name"]
+ publicId = token["publicId"]
+ systemId = token["systemId"]
+
+ doctype = Doctype.for_name_and_ids(name, publicId, systemId)
+ self.soup.object_was_parsed(doctype)
+
+ def elementClass(self, name, namespace):
+ tag = self.soup.new_tag(name, namespace)
+ return Element(tag, self.soup, namespace)
+
+ def commentClass(self, data):
+ return TextNode(Comment(data), self.soup)
+
+ def fragmentClass(self):
+ from bs4 import BeautifulSoup
+ self.soup = BeautifulSoup("", "html.parser")
+ self.soup.name = "[document_fragment]"
+ return Element(self.soup, self.soup, None)
+
+ def appendChild(self, node):
+ # XXX This code is not covered by the BS4 tests.
+ self.soup.append(node.element)
+
+ def getDocument(self):
+ return self.soup
+
+ def getFragment(self):
+ return treebuilder_base.TreeBuilder.getFragment(self).element
+
+ def testSerializer(self, element):
+ from bs4 import BeautifulSoup
+ rv = []
+ doctype_re = re.compile(r'^(.*?)(?: PUBLIC "(.*?)"(?: "(.*?)")?| SYSTEM "(.*?)")?$')
+
+ def serializeElement(element, indent=0):
+ if isinstance(element, BeautifulSoup):
+ pass
+ if isinstance(element, Doctype):
+ m = doctype_re.match(element)
+ if m:
+ name = m.group(1)
+ if m.lastindex > 1:
+ publicId = m.group(2) or ""
+ systemId = m.group(3) or m.group(4) or ""
+ rv.append("""|%s<!DOCTYPE %s "%s" "%s">""" %
+ (' ' * indent, name, publicId, systemId))
+ else:
+ rv.append("|%s<!DOCTYPE %s>" % (' ' * indent, name))
+ else:
+ rv.append("|%s<!DOCTYPE >" % (' ' * indent,))
+ elif isinstance(element, Comment):
+ rv.append("|%s<!-- %s -->" % (' ' * indent, element))
+ elif isinstance(element, NavigableString):
+ rv.append("|%s\"%s\"" % (' ' * indent, element))
+ else:
+ if element.namespace:
+ name = "%s %s" % (prefixes[element.namespace],
+ element.name)
+ else:
+ name = element.name
+ rv.append("|%s<%s>" % (' ' * indent, name))
+ if element.attrs:
+ attributes = []
+ for name, value in list(element.attrs.items()):
+ if isinstance(name, NamespacedAttribute):
+ name = "%s %s" % (prefixes[name.namespace], name.name)
+ if isinstance(value, list):
+ value = " ".join(value)
+ attributes.append((name, value))
+
+ for name, value in sorted(attributes):
+ rv.append('|%s%s="%s"' % (' ' * (indent + 2), name, value))
+ indent += 2
+ for child in element.children:
+ serializeElement(child, indent)
+ serializeElement(element, 0)
+
+ return "\n".join(rv)
+
+class AttrList(object):
+ def __init__(self, element):
+ self.element = element
+ self.attrs = dict(self.element.attrs)
+ def __iter__(self):
+ return list(self.attrs.items()).__iter__()
+ def __setitem__(self, name, value):
+ # If this attribute is a multi-valued attribute for this element,
+ # turn its value into a list.
+ list_attr = self.element.cdata_list_attributes
+ if (name in list_attr['*']
+ or (self.element.name in list_attr
+ and name in list_attr[self.element.name])):
+ # A node that is being cloned may have already undergone
+ # this procedure.
+ if not isinstance(value, list):
+ value = nonwhitespace_re.findall(value)
+ self.element[name] = value
+ def items(self):
+ return list(self.attrs.items())
+ def keys(self):
+ return list(self.attrs.keys())
+ def __len__(self):
+ return len(self.attrs)
+ def __getitem__(self, name):
+ return self.attrs[name]
+ def __contains__(self, name):
+ return name in list(self.attrs.keys())
+
+
+class Element(treebuilder_base.Node):
+ def __init__(self, element, soup, namespace):
+ treebuilder_base.Node.__init__(self, element.name)
+ self.element = element
+ self.soup = soup
+ self.namespace = namespace
+
+ def appendChild(self, node):
+ string_child = child = None
+ if isinstance(node, str):
+ # Some other piece of code decided to pass in a string
+ # instead of creating a TextElement object to contain the
+ # string.
+ string_child = child = node
+ elif isinstance(node, Tag):
+ # Some other piece of code decided to pass in a Tag
+ # instead of creating an Element object to contain the
+ # Tag.
+ child = node
+ elif node.element.__class__ == NavigableString:
+ string_child = child = node.element
+ node.parent = self
+ else:
+ child = node.element
+ node.parent = self
+
+ if not isinstance(child, str) and child.parent is not None:
+ node.element.extract()
+
+ if (string_child is not None and self.element.contents
+ and self.element.contents[-1].__class__ == NavigableString):
+ # We are appending a string onto another string.
+ # TODO This has O(n^2) performance, for input like
+ # "a</a>a</a>a</a>..."
+ old_element = self.element.contents[-1]
+ new_element = self.soup.new_string(old_element + string_child)
+ old_element.replace_with(new_element)
+ self.soup._most_recent_element = new_element
+ else:
+ if isinstance(node, str):
+ # Create a brand new NavigableString from this string.
+ child = self.soup.new_string(node)
+
+ # Tell Beautiful Soup to act as if it parsed this element
+ # immediately after the parent's last descendant. (Or
+ # immediately after the parent, if it has no children.)
+ if self.element.contents:
+ most_recent_element = self.element._last_descendant(False)
+ elif self.element.next_element is not None:
+ # Something from further ahead in the parse tree is
+ # being inserted into this earlier element. This is
+ # very annoying because it means an expensive search
+ # for the last element in the tree.
+ most_recent_element = self.soup._last_descendant()
+ else:
+ most_recent_element = self.element
+
+ self.soup.object_was_parsed(
+ child, parent=self.element,
+ most_recent_element=most_recent_element)
+
+ def getAttributes(self):
+ if isinstance(self.element, Comment):
+ return {}
+ return AttrList(self.element)
+
+ def setAttributes(self, attributes):
+
+ if attributes is not None and len(attributes) > 0:
+
+ converted_attributes = []
+ for name, value in list(attributes.items()):
+ if isinstance(name, tuple):
+ new_name = NamespacedAttribute(*name)
+ del attributes[name]
+ attributes[new_name] = value
+
+ self.soup.builder._replace_cdata_list_attribute_values(
+ self.name, attributes)
+ for name, value in list(attributes.items()):
+ self.element[name] = value
+
+ # The attributes may contain variables that need substitution.
+ # Call set_up_substitutions manually.
+ #
+ # The Tag constructor called this method when the Tag was created,
+ # but we just set/changed the attributes, so call it again.
+ self.soup.builder.set_up_substitutions(self.element)
+ attributes = property(getAttributes, setAttributes)
+
+ def insertText(self, data, insertBefore=None):
+ text = TextNode(self.soup.new_string(data), self.soup)
+ if insertBefore:
+ self.insertBefore(text, insertBefore)
+ else:
+ self.appendChild(text)
+
+ def insertBefore(self, node, refNode):
+ index = self.element.index(refNode.element)
+ if (node.element.__class__ == NavigableString and self.element.contents
+ and self.element.contents[index-1].__class__ == NavigableString):
+ # (See comments in appendChild)
+ old_node = self.element.contents[index-1]
+ new_str = self.soup.new_string(old_node + node.element)
+ old_node.replace_with(new_str)
+ else:
+ self.element.insert(index, node.element)
+ node.parent = self
+
+ def removeChild(self, node):
+ node.element.extract()
+
+ def reparentChildren(self, new_parent):
+ """Move all of this tag's children into another tag."""
+ # print "MOVE", self.element.contents
+ # print "FROM", self.element
+ # print "TO", new_parent.element
+
+ element = self.element
+ new_parent_element = new_parent.element
+ # Determine what this tag's next_element will be once all the children
+ # are removed.
+ final_next_element = element.next_sibling
+
+ new_parents_last_descendant = new_parent_element._last_descendant(False, False)
+ if len(new_parent_element.contents) > 0:
+ # The new parent already contains children. We will be
+ # appending this tag's children to the end.
+ new_parents_last_child = new_parent_element.contents[-1]
+ new_parents_last_descendant_next_element = new_parents_last_descendant.next_element
+ else:
+ # The new parent contains no children.
+ new_parents_last_child = None
+ new_parents_last_descendant_next_element = new_parent_element.next_element
+
+ to_append = element.contents
+ if len(to_append) > 0:
+ # Set the first child's previous_element and previous_sibling
+ # to elements within the new parent
+ first_child = to_append[0]
+ if new_parents_last_descendant is not None:
+ first_child.previous_element = new_parents_last_descendant
+ else:
+ first_child.previous_element = new_parent_element
+ first_child.previous_sibling = new_parents_last_child
+ if new_parents_last_descendant is not None:
+ new_parents_last_descendant.next_element = first_child
+ else:
+ new_parent_element.next_element = first_child
+ if new_parents_last_child is not None:
+ new_parents_last_child.next_sibling = first_child
+
+ # Find the very last element being moved. It is now the
+ # parent's last descendant. It has no .next_sibling and
+ # its .next_element is whatever the previous last
+ # descendant had.
+ last_childs_last_descendant = to_append[-1]._last_descendant(False, True)
+
+ last_childs_last_descendant.next_element = new_parents_last_descendant_next_element
+ if new_parents_last_descendant_next_element is not None:
+ # TODO: This code has no test coverage and I'm not sure
+ # how to get html5lib to go through this path, but it's
+ # just the other side of the previous line.
+ new_parents_last_descendant_next_element.previous_element = last_childs_last_descendant
+ last_childs_last_descendant.next_sibling = None
+
+ for child in to_append:
+ child.parent = new_parent_element
+ new_parent_element.contents.append(child)
+
+ # Now that this element has no children, change its .next_element.
+ element.contents = []
+ element.next_element = final_next_element
+
+ # print "DONE WITH MOVE"
+ # print "FROM", self.element
+ # print "TO", new_parent_element
+
+ def cloneNode(self):
+ tag = self.soup.new_tag(self.element.name, self.namespace)
+ node = Element(tag, self.soup, self.namespace)
+ for key,value in self.attributes:
+ node.attributes[key] = value
+ return node
+
+ def hasContent(self):
+ return self.element.contents
+
+ def getNameTuple(self):
+ if self.namespace == None:
+ return namespaces["html"], self.name
+ else:
+ return self.namespace, self.name
+
+ nameTuple = property(getNameTuple)
+
+class TextNode(Element):
+ def __init__(self, element, soup):
+ treebuilder_base.Node.__init__(self, None)
+ self.element = element
+ self.soup = soup
+
+ def cloneNode(self):
+ raise NotImplementedError
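
A sketch of the html5lib builder above in use (assumes the optional html5lib package is installed; the markup is illustrative):

    from bs4 import BeautifulSoup

    # html5lib parses the way a browser does: the unclosed <p> tags are
    # repaired, and an <html>/<head>/<body> skeleton is always present.
    soup = BeautifulSoup("<p>one<p>two", "html5lib")
    print(soup.body)    # <body><p>one</p><p>two</p></body>
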
diff --git a/libs/bs4/builder/_htmlparser.py b/libs/bs4/builder/_htmlparser.py
new file mode 100644
index 000000000..ea549c356
--- /dev/null
+++ b/libs/bs4/builder/_htmlparser.py
@@ -0,0 +1,350 @@
+# encoding: utf-8
+"""Use the HTMLParser library to parse HTML files that aren't too bad."""
+
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+__all__ = [
+ 'HTMLParserTreeBuilder',
+ ]
+
+from html.parser import HTMLParser
+
+try:
+ from html.parser import HTMLParseError
+except ImportError as e:
+ # HTMLParseError is removed in Python 3.5. Since it can never be
+ # thrown in 3.5, we can just define our own class as a placeholder.
+ class HTMLParseError(Exception):
+ pass
+
+import sys
+import warnings
+
+# Starting in Python 3.2, the HTMLParser constructor takes a 'strict'
+# argument, which we'd like to set to False. Unfortunately,
+# http://bugs.python.org/issue13273 makes strict=True a better bet
+# before Python 3.2.3.
+#
+# At the end of this file, we monkeypatch HTMLParser so that
+# strict=True works well on Python 3.2.2.
+major, minor, release = sys.version_info[:3]
+CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3
+CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3
+CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4
+
+
+from bs4.element import (
+ CData,
+ Comment,
+ Declaration,
+ Doctype,
+ ProcessingInstruction,
+ )
+from bs4.dammit import EntitySubstitution, UnicodeDammit
+
+from bs4.builder import (
+ HTML,
+ HTMLTreeBuilder,
+ STRICT,
+ )
+
+
+HTMLPARSER = 'html.parser'
+
+class BeautifulSoupHTMLParser(HTMLParser):
+
+ def __init__(self, *args, **kwargs):
+ HTMLParser.__init__(self, *args, **kwargs)
+
+ # Keep a list of empty-element tags that were encountered
+ # without an explicit closing tag. If we encounter a closing tag
+ # of this type, we'll associate it with one of those entries.
+ #
+ # This isn't a stack because we don't care about the
+ # order. It's a list of closing tags we've already handled and
+ # will ignore, assuming they ever show up.
+ self.already_closed_empty_element = []
+
+ def error(self, msg):
+ """In Python 3, HTMLParser subclasses must implement error(), although this
+ requirement doesn't appear to be documented.
+
+ In Python 2, HTMLParser implements error() as raising an exception.
+
+ In any event, this method is called only on very strange markup and our best strategy
+ is to pretend it didn't happen and keep going.
+ """
+ warnings.warn(msg)
+
+ def handle_startendtag(self, name, attrs):
+ # This is only called when the markup looks like
+ # <tag/>.
+
+ # handle_empty_element=False tells handle_starttag not to close
+ # the tag just because its name matches a known empty-element
+ # tag. We know that this is an empty-element tag and we want to
+ # call handle_endtag ourselves.
+ tag = self.handle_starttag(name, attrs, handle_empty_element=False)
+ self.handle_endtag(name)
+
+ def handle_starttag(self, name, attrs, handle_empty_element=True):
+ # XXX namespace
+ attr_dict = {}
+ for key, value in attrs:
+ # Change None attribute values to the empty string
+ # for consistency with the other tree builders.
+ if value is None:
+ value = ''
+ attr_dict[key] = value
+ #print "START", name
+ tag = self.soup.handle_starttag(name, None, None, attr_dict)
+ if tag and tag.is_empty_element and handle_empty_element:
+ # Unlike other parsers, html.parser doesn't send separate end tag
+ # events for empty-element tags. (It's handled in
+ # handle_startendtag, but only if the original markup looked like
+ # <tag/>.)
+ #
+ # So we need to call handle_endtag() ourselves. Since we
+ # know the start event is identical to the end event, we
+ # don't want handle_endtag() to cross off any previous end
+ # events for tags of this name.
+ self.handle_endtag(name, check_already_closed=False)
+
+ # But we might encounter an explicit closing tag for this tag
+ # later on. If so, we want to ignore it.
+ self.already_closed_empty_element.append(name)
+
+ def handle_endtag(self, name, check_already_closed=True):
+ #print "END", name
+ if check_already_closed and name in self.already_closed_empty_element:
+ # This is a redundant end tag for an empty-element tag.
+ # We've already called handle_endtag() for it, so just
+ # check it off the list.
+ # print "ALREADY CLOSED", name
+ self.already_closed_empty_element.remove(name)
+ else:
+ self.soup.handle_endtag(name)
+
+ def handle_data(self, data):
+ self.soup.handle_data(data)
+
+ def handle_charref(self, name):
+ # XXX workaround for a bug in HTMLParser. Remove this once
+ # it's fixed in all supported versions.
+ # http://bugs.python.org/issue13633
+ if name.startswith('x'):
+ real_name = int(name.lstrip('x'), 16)
+ elif name.startswith('X'):
+ real_name = int(name.lstrip('X'), 16)
+ else:
+ real_name = int(name)
+
+ data = None
+ if real_name < 256:
+ # HTML numeric entities are supposed to reference Unicode
+ # code points, but sometimes they reference code points in
+ # some other encoding (ahem, Windows-1252). E.g. &#147;
+ # instead of &#8220; for LEFT DOUBLE QUOTATION MARK. This
+ # code tries to detect this situation and compensate.
+ for encoding in (self.soup.original_encoding, 'windows-1252'):
+ if not encoding:
+ continue
+ try:
+ data = bytearray([real_name]).decode(encoding)
+ except UnicodeDecodeError as e:
+ pass
+ if not data:
+ try:
+ data = chr(real_name)
+ except (ValueError, OverflowError) as e:
+ pass
+ data = data or "\N{REPLACEMENT CHARACTER}"
+ self.handle_data(data)
+
+ def handle_entityref(self, name):
+ character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name)
+ if character is not None:
+ data = character
+ else:
+ # If this were XML, it would be ambiguous whether "&foo"
+ # was a character entity reference with a missing
+ # semicolon or the literal string "&foo". Since this is
+ # HTML, we have a complete list of all character entity references,
+ # and this one wasn't found, so assume it's the literal string "&foo".
+ data = "&%s" % name
+ self.handle_data(data)
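+
+ # Illustrative note (an addition for this review, not in the upstream
+ # file): with soup.original_encoding == 'windows-1252', the markup
+ # "&#147;" reaches handle_charref('147'), whose fallback loop above
+ # decodes byte 0x93 to '\u201c'; handle_entityref('amp') emits '&'
+ # from HTML_ENTITY_TO_CHARACTER.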
+
+ def handle_comment(self, data):
+ self.soup.endData()
+ self.soup.handle_data(data)
+ self.soup.endData(Comment)
+
+ def handle_decl(self, data):
+ self.soup.endData()
+ if data.startswith("DOCTYPE "):
+ data = data[len("DOCTYPE "):]
+ elif data == 'DOCTYPE':
+ # i.e. "<!DOCTYPE>"
+ data = ''
+ self.soup.handle_data(data)
+ self.soup.endData(Doctype)
+
+ def unknown_decl(self, data):
+ if data.upper().startswith('CDATA['):
+ cls = CData
+ data = data[len('CDATA['):]
+ else:
+ cls = Declaration
+ self.soup.endData()
+ self.soup.handle_data(data)
+ self.soup.endData(cls)
+
+ def handle_pi(self, data):
+ self.soup.endData()
+ self.soup.handle_data(data)
+ self.soup.endData(ProcessingInstruction)
+
+
+class HTMLParserTreeBuilder(HTMLTreeBuilder):
+
+ is_xml = False
+ picklable = True
+ NAME = HTMLPARSER
+ features = [NAME, HTML, STRICT]
+
+ def __init__(self, parser_args=None, parser_kwargs=None, **kwargs):
+ super(HTMLParserTreeBuilder, self).__init__(**kwargs)
+ parser_args = parser_args or []
+ parser_kwargs = parser_kwargs or {}
+ if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
+ parser_kwargs['strict'] = False
+ if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
+ parser_kwargs['convert_charrefs'] = False
+ self.parser_args = (parser_args, parser_kwargs)
+
+ def prepare_markup(self, markup, user_specified_encoding=None,
+ document_declared_encoding=None, exclude_encodings=None):
+ """
+ :return: A 4-tuple (markup, original encoding, encoding
+ declared within markup, whether any characters had to be
+ replaced with REPLACEMENT CHARACTER).
+ """
+ if isinstance(markup, str):
+ yield (markup, None, None, False)
+ return
+
+ try_encodings = [user_specified_encoding, document_declared_encoding]
+ dammit = UnicodeDammit(markup, try_encodings, is_html=True,
+ exclude_encodings=exclude_encodings)
+ yield (dammit.markup, dammit.original_encoding,
+ dammit.declared_html_encoding,
+ dammit.contains_replacement_characters)
+
+ def feed(self, markup):
+ args, kwargs = self.parser_args
+ parser = BeautifulSoupHTMLParser(*args, **kwargs)
+ parser.soup = self.soup
+ try:
+ parser.feed(markup)
+ parser.close()
+ except HTMLParseError as e:
+ warnings.warn(RuntimeWarning(
+ "Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
+ raise e
+ parser.already_closed_empty_element = []
+
+# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
+# 3.2.3 code. This ensures they don't treat markup like <p></p> as a
+# string.
+#
+# XXX This code can be removed once most Python 3 users are on 3.2.3.
+if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT:
+ import re
+ attrfind_tolerant = re.compile(
+ r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*'
+ r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?')
+ HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant
+
+ locatestarttagend = re.compile(r"""
+ <[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
+ (?:\s+ # whitespace before attribute name
+ (?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
+ (?:\s*=\s* # value indicator
+ (?:'[^']*' # LITA-enclosed value
+ |\"[^\"]*\" # LIT-enclosed value
+ |[^'\">\s]+ # bare value
+ )
+ )?
+ )
+ )*
+ \s* # trailing whitespace
+""", re.VERBOSE)
+ BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend
+
+ from html.parser import tagfind, attrfind
+
+ def parse_starttag(self, i):
+ self.__starttag_text = None
+ endpos = self.check_for_whole_start_tag(i)
+ if endpos < 0:
+ return endpos
+ rawdata = self.rawdata
+ self.__starttag_text = rawdata[i:endpos]
+
+ # Now parse the data between i+1 and j into a tag and attrs
+ attrs = []
+ match = tagfind.match(rawdata, i+1)
+ assert match, 'unexpected call to parse_starttag()'
+ k = match.end()
+ self.lasttag = tag = rawdata[i+1:k].lower()
+ while k < endpos:
+ if self.strict:
+ m = attrfind.match(rawdata, k)
+ else:
+ m = attrfind_tolerant.match(rawdata, k)
+ if not m:
+ break
+ attrname, rest, attrvalue = m.group(1, 2, 3)
+ if not rest:
+ attrvalue = None
+ elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
+ attrvalue[:1] == '"' == attrvalue[-1:]:
+ attrvalue = attrvalue[1:-1]
+ if attrvalue:
+ attrvalue = self.unescape(attrvalue)
+ attrs.append((attrname.lower(), attrvalue))
+ k = m.end()
+
+ end = rawdata[k:endpos].strip()
+ if end not in (">", "/>"):
+ lineno, offset = self.getpos()
+ if "\n" in self.__starttag_text:
+ lineno = lineno + self.__starttag_text.count("\n")
+ offset = len(self.__starttag_text) \
+ - self.__starttag_text.rfind("\n")
+ else:
+ offset = offset + len(self.__starttag_text)
+ if self.strict:
+ self.error("junk characters in start tag: %r"
+ % (rawdata[k:endpos][:20],))
+ self.handle_data(rawdata[i:endpos])
+ return endpos
+ if end.endswith('/>'):
+ # XHTML-style empty tag: <span attr="value" />
+ self.handle_startendtag(tag, attrs)
+ else:
+ self.handle_starttag(tag, attrs)
+ if tag in self.CDATA_CONTENT_ELEMENTS:
+ self.set_cdata_mode(tag)
+ return endpos
+
+ def set_cdata_mode(self, elem):
+ self.cdata_elem = elem.lower()
+ self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)
+
+ BeautifulSoupHTMLParser.parse_starttag = parse_starttag
+ BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode
+
+ CONSTRUCTOR_TAKES_STRICT = True
diff --git a/libs/bs4/builder/_lxml.py b/libs/bs4/builder/_lxml.py
new file mode 100644
index 000000000..a490e2301
--- /dev/null
+++ b/libs/bs4/builder/_lxml.py
@@ -0,0 +1,296 @@
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+__all__ = [
+ 'LXMLTreeBuilderForXML',
+ 'LXMLTreeBuilder',
+ ]
+
+try:
+ from collections.abc import Callable # Python 3.6
+except ImportError as e:
+ from collections import Callable
+
+from io import BytesIO
+from io import StringIO
+from lxml import etree
+from bs4.element import (
+ Comment,
+ Doctype,
+ NamespacedAttribute,
+ ProcessingInstruction,
+ XMLProcessingInstruction,
+)
+from bs4.builder import (
+ FAST,
+ HTML,
+ HTMLTreeBuilder,
+ PERMISSIVE,
+ ParserRejectedMarkup,
+ TreeBuilder,
+ XML)
+from bs4.dammit import EncodingDetector
+
+LXML = 'lxml'
+
+def _invert(d):
+ "Invert a dictionary."
+ return dict((v,k) for k, v in list(d.items()))
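+# For example (illustrative, not from upstream):
+# _invert({'xml': 'http://www.w3.org/XML/1998/namespace'}) returns
+# {'http://www.w3.org/XML/1998/namespace': 'xml'}, the direction
+# needed when looking up the prefix for a namespace URL.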
+
+class LXMLTreeBuilderForXML(TreeBuilder):
+ DEFAULT_PARSER_CLASS = etree.XMLParser
+
+ is_xml = True
+ processing_instruction_class = XMLProcessingInstruction
+
+ NAME = "lxml-xml"
+ ALTERNATE_NAMES = ["xml"]
+
+ # Well, it's permissive by XML parser standards.
+ features = [NAME, LXML, XML, FAST, PERMISSIVE]
+
+ CHUNK_SIZE = 512
+
+ # This namespace mapping is specified in the XML Namespace
+ # standard.
+ DEFAULT_NSMAPS = dict(xml='http://www.w3.org/XML/1998/namespace')
+
+ DEFAULT_NSMAPS_INVERTED = _invert(DEFAULT_NSMAPS)
+
+ def initialize_soup(self, soup):
+ """Let the BeautifulSoup object know about the standard namespace
+ mapping.
+ """
+ super(LXMLTreeBuilderForXML, self).initialize_soup(soup)
+ self._register_namespaces(self.DEFAULT_NSMAPS)
+
+ def _register_namespaces(self, mapping):
+ """Let the BeautifulSoup object know about namespaces encountered
+ while parsing the document.
+
+ This might be useful later on when creating CSS selectors.
+ """
+ for key, value in list(mapping.items()):
+ if key and key not in self.soup._namespaces:
+ # Let the BeautifulSoup object know about a new namespace.
+ # If there are multiple namespaces defined with the same
+ # prefix, the first one in the document takes precedence.
+ self.soup._namespaces[key] = value
+
+ def default_parser(self, encoding):
+ # This can either return a parser object or a class, which
+ # will be instantiated with default arguments.
+ if self._default_parser is not None:
+ return self._default_parser
+ return etree.XMLParser(
+ target=self, strip_cdata=False, recover=True, encoding=encoding)
+
+ def parser_for(self, encoding):
+ # Use the default parser.
+ parser = self.default_parser(encoding)
+
+ if isinstance(parser, Callable):
+ # Instantiate the parser with default arguments
+ parser = parser(target=self, strip_cdata=False, encoding=encoding)
+ return parser
+
+ def __init__(self, parser=None, empty_element_tags=None, **kwargs):
+ # TODO: Issue a warning if parser is present but not a
+ # callable, since that means there's no way to create new
+ # parsers for different encodings.
+ self._default_parser = parser
+ if empty_element_tags is not None:
+ self.empty_element_tags = set(empty_element_tags)
+ self.soup = None
+ self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED]
+ super(LXMLTreeBuilderForXML, self).__init__(**kwargs)
+
+ def _getNsTag(self, tag):
+ # Split the namespace URL out of a fully-qualified lxml tag
+ # name. Copied from lxml's src/lxml/sax.py.
+ if tag[0] == '{':
+ return tuple(tag[1:].split('}', 1))
+ else:
+ return (None, tag)
+
+ def prepare_markup(self, markup, user_specified_encoding=None,
+ exclude_encodings=None,
+ document_declared_encoding=None):
+ """
+ :yield: A series of 4-tuples.
+ (markup, encoding, declared encoding,
+ has undergone character replacement)
+
+ Each 4-tuple represents a strategy for parsing the document.
+ """
+ # Instead of using UnicodeDammit to convert the bytestring to
+ # Unicode using different encodings, use EncodingDetector to
+ # iterate over the encodings, and tell lxml to try to parse
+ # the document as each one in turn.
+ is_html = not self.is_xml
+ if is_html:
+ self.processing_instruction_class = ProcessingInstruction
+ else:
+ self.processing_instruction_class = XMLProcessingInstruction
+
+ if isinstance(markup, str):
+ # We were given Unicode. Maybe lxml can parse Unicode on
+ # this system?
+ yield markup, None, document_declared_encoding, False
+
+ if isinstance(markup, str):
+ # No, apparently not. Convert the Unicode to UTF-8 and
+ # tell lxml to parse it as UTF-8.
+ yield (markup.encode("utf8"), "utf8",
+ document_declared_encoding, False)
+
+ try_encodings = [user_specified_encoding, document_declared_encoding]
+ detector = EncodingDetector(
+ markup, try_encodings, is_html, exclude_encodings)
+ for encoding in detector.encodings:
+ yield (detector.markup, encoding, document_declared_encoding, False)
+
+ def feed(self, markup):
+ if isinstance(markup, bytes):
+ markup = BytesIO(markup)
+ elif isinstance(markup, str):
+ markup = StringIO(markup)
+
+ # Call feed() at least once, even if the markup is empty,
+ # or the parser won't be initialized.
+ data = markup.read(self.CHUNK_SIZE)
+ try:
+ self.parser = self.parser_for(self.soup.original_encoding)
+ self.parser.feed(data)
+ while len(data) != 0:
+ # Now call feed() on the rest of the data, chunk by chunk.
+ data = markup.read(self.CHUNK_SIZE)
+ if len(data) != 0:
+ self.parser.feed(data)
+ self.parser.close()
+ except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
+ raise ParserRejectedMarkup(str(e))
+
+ def close(self):
+ self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED]
+
+ def start(self, name, attrs, nsmap={}):
+ # Make sure attrs is a mutable dict--lxml may send an immutable dictproxy.
+ attrs = dict(attrs)
+ nsprefix = None
+ # Invert each namespace map as it comes in.
+ if len(nsmap) == 0 and len(self.nsmaps) > 1:
+ # There are no new namespaces for this tag, but
+ # non-default namespaces are in play, so we need a
+ # separate tag stack to know when they end.
+ self.nsmaps.append(None)
+ elif len(nsmap) > 0:
+ # A new namespace mapping has come into play.
+
+ # First, let the BeautifulSoup object know about it.
+ self._register_namespaces(nsmap)
+
+ # Then, add it to our running list of inverted namespace
+ # mappings.
+ self.nsmaps.append(_invert(nsmap))
+
+ # Also treat the namespace mapping as a set of attributes on the
+ # tag, so we can recreate it later.
+ attrs = attrs.copy()
+ for prefix, namespace in list(nsmap.items()):
+ attribute = NamespacedAttribute(
+ "xmlns", prefix, "http://www.w3.org/2000/xmlns/")
+ attrs[attribute] = namespace
+
+ # Namespaces are in play. Find any attributes that came in
+ # from lxml with namespaces attached to their names, and
+ # turn them into NamespacedAttribute objects.
+ new_attrs = {}
+ for attr, value in list(attrs.items()):
+ namespace, attr = self._getNsTag(attr)
+ if namespace is None:
+ new_attrs[attr] = value
+ else:
+ nsprefix = self._prefix_for_namespace(namespace)
+ attr = NamespacedAttribute(nsprefix, attr, namespace)
+ new_attrs[attr] = value
+ attrs = new_attrs
+
+ namespace, name = self._getNsTag(name)
+ nsprefix = self._prefix_for_namespace(namespace)
+ self.soup.handle_starttag(name, namespace, nsprefix, attrs)
+
+ def _prefix_for_namespace(self, namespace):
+ """Find the currently active prefix for the given namespace."""
+ if namespace is None:
+ return None
+ for inverted_nsmap in reversed(self.nsmaps):
+ if inverted_nsmap is not None and namespace in inverted_nsmap:
+ return inverted_nsmap[namespace]
+ return None
+
+ def end(self, name):
+ self.soup.endData()
+ completed_tag = self.soup.tagStack[-1]
+ namespace, name = self._getNsTag(name)
+ nsprefix = None
+ if namespace is not None:
+ for inverted_nsmap in reversed(self.nsmaps):
+ if inverted_nsmap is not None and namespace in inverted_nsmap:
+ nsprefix = inverted_nsmap[namespace]
+ break
+ self.soup.handle_endtag(name, nsprefix)
+ if len(self.nsmaps) > 1:
+ # This tag, or one of its parents, introduced a namespace
+ # mapping, so pop it off the stack.
+ self.nsmaps.pop()
+
+ def pi(self, target, data):
+ self.soup.endData()
+ self.soup.handle_data(target + ' ' + data)
+ self.soup.endData(self.processing_instruction_class)
+
+ def data(self, content):
+ self.soup.handle_data(content)
+
+ def doctype(self, name, pubid, system):
+ self.soup.endData()
+ doctype = Doctype.for_name_and_ids(name, pubid, system)
+ self.soup.object_was_parsed(doctype)
+
+ def comment(self, content):
+ "Handle comments as Comment objects."
+ self.soup.endData()
+ self.soup.handle_data(content)
+ self.soup.endData(Comment)
+
+ def test_fragment_to_document(self, fragment):
+ """See `TreeBuilder`."""
+ return '<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment
+
+
+class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
+
+ NAME = LXML
+ ALTERNATE_NAMES = ["lxml-html"]
+
+ features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE]
+ is_xml = False
+ processing_instruction_class = ProcessingInstruction
+
+ def default_parser(self, encoding):
+ return etree.HTMLParser
+
+ def feed(self, markup):
+ encoding = self.soup.original_encoding
+ try:
+ self.parser = self.parser_for(encoding)
+ self.parser.feed(markup)
+ self.parser.close()
+ except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
+ raise ParserRejectedMarkup(str(e))
+
+
+ def test_fragment_to_document(self, fragment):
+ """See `TreeBuilder`."""
+ return '<html><body>%s</body></html>' % fragment
diff --git a/libs/bs4/dammit.py b/libs/bs4/dammit.py
new file mode 100644
index 000000000..c7ac4d431
--- /dev/null
+++ b/libs/bs4/dammit.py
@@ -0,0 +1,850 @@
+# -*- coding: utf-8 -*-
+"""Beautiful Soup bonus library: Unicode, Dammit
+
+This library converts a bytestream to Unicode through any means
+necessary. It is heavily based on code from Mark Pilgrim's Universal
+Feed Parser. It works best on XML and HTML, but it does not rewrite the
+XML or HTML to reflect a new encoding; that's the tree builder's job.
+"""
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+import codecs
+from html.entities import codepoint2name
+import re
+import logging
+import string
+
+# Import a library to autodetect character encodings.
+chardet_type = None
+try:
+ # First try the fast C implementation.
+ # PyPI package: cchardet
+ import cchardet
+ def chardet_dammit(s):
+ return cchardet.detect(s)['encoding']
+except ImportError:
+ try:
+ # Fall back to the pure Python implementation
+ # Debian package: python-chardet
+ # PyPI package: chardet
+ import chardet
+ def chardet_dammit(s):
+ return chardet.detect(s)['encoding']
+ #import chardet.constants
+ #chardet.constants._debug = 1
+ except ImportError:
+ # No chardet available.
+ def chardet_dammit(s):
+ return None
+
+# Available from http://cjkpython.i18n.org/.
+try:
+ import iconv_codec
+except ImportError:
+ pass
+
+xml_encoding_re = re.compile(
+ '^<\\?.*encoding=[\'"](.*?)[\'"].*\\?>'.encode(), re.I)
+html_meta_re = re.compile(
+ '<\\s*meta[^>]+charset\\s*=\\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I)
+
+class EntitySubstitution(object):
+
+ """Substitute XML or HTML entities for the corresponding characters."""
+
+ def _populate_class_variables():
+ lookup = {}
+ reverse_lookup = {}
+ characters_for_re = []
+
+ # &apos; is an XHTML and HTML 5 entity, but not an HTML 4 entity.
+ # We don't want to use it, but we want to recognize it on the way in.
+ #
+ # TODO: Ideally we would be able to recognize all HTML 5 named
+ # entities, but that's a little tricky.
+ extra = [(39, 'apos')]
+ for codepoint, name in list(codepoint2name.items()) + extra:
+ character = chr(codepoint)
+ if codepoint not in (34, 39):
+ # There's no point in turning the quotation mark into
+ # &quot; or the single quote into &apos;, unless it
+ # happens within an attribute value, which is handled
+ # elsewhere.
+ characters_for_re.append(character)
+ lookup[character] = name
+ # But we do want to recognize those entities on the way in and
+ # convert them to Unicode characters.
+ reverse_lookup[name] = character
+ re_definition = "[%s]" % "".join(characters_for_re)
+ return lookup, reverse_lookup, re.compile(re_definition)
+ (CHARACTER_TO_HTML_ENTITY, HTML_ENTITY_TO_CHARACTER,
+ CHARACTER_TO_HTML_ENTITY_RE) = _populate_class_variables()
+
+ CHARACTER_TO_XML_ENTITY = {
+ "'": "apos",
+ '"': "quot",
+ "&": "amp",
+ "<": "lt",
+ ">": "gt",
+ }
+
+ BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|"
+ "&(?!#\\d+;|#x[0-9a-fA-F]+;|\\w+;)"
+ ")")
+
+ AMPERSAND_OR_BRACKET = re.compile("([<>&])")
+
+ @classmethod
+ def _substitute_html_entity(cls, matchobj):
+ entity = cls.CHARACTER_TO_HTML_ENTITY.get(matchobj.group(0))
+ return "&%s;" % entity
+
+ @classmethod
+ def _substitute_xml_entity(cls, matchobj):
+ """Used with a regular expression to substitute the
+ appropriate XML entity for an XML special character."""
+ entity = cls.CHARACTER_TO_XML_ENTITY[matchobj.group(0)]
+ return "&%s;" % entity
+
+ @classmethod
+ def quoted_attribute_value(cls, value):
+ """Make a value into a quoted XML attribute, possibly escaping it.
+
+ Most strings will be quoted using double quotes.
+
+ Bob's Bar -> "Bob's Bar"
+
+ If a string contains double quotes, it will be quoted using
+ single quotes.
+
+ Welcome to "my bar" -> 'Welcome to "my bar"'
+
+ If a string contains both single and double quotes, the
+ double quotes will be escaped, and the string will be quoted
+ using double quotes.
+
+ Welcome to "Bob's Bar" -> "Welcome to &quot;Bob's bar&quot;
+ """
+ quote_with = '"'
+ if '"' in value:
+ if "'" in value:
+ # The string contains both single and double
+ # quotes. Turn the double quotes into
+ # entities. We quote the double quotes rather than
+ # the single quotes because the entity name is
+ # "&quot;" whether this is HTML or XML. If we
+ # quoted the single quotes, we'd have to decide
+ # between &apos; and &squot;.
+ replace_with = "&quot;"
+ value = value.replace('"', replace_with)
+ else:
+ # There are double quotes but no single quotes.
+ # We can use single quotes to quote the attribute.
+ quote_with = "'"
+ return quote_with + value + quote_with
+
+ @classmethod
+ def substitute_xml(cls, value, make_quoted_attribute=False):
+ """Substitute XML entities for special XML characters.
+
+ :param value: A string to be substituted. The less-than sign
+ will become &lt;, the greater-than sign will become &gt;,
+ and any ampersands will become &amp;. If you want ampersands
+ that appear to be part of an entity definition to be left
+ alone, use substitute_xml_containing_entities() instead.
+
+ :param make_quoted_attribute: If True, then the string will be
+ quoted, as befits an attribute value.
+ """
+ # Escape angle brackets and ampersands.
+ value = cls.AMPERSAND_OR_BRACKET.sub(
+ cls._substitute_xml_entity, value)
+
+ if make_quoted_attribute:
+ value = cls.quoted_attribute_value(value)
+ return value
+
+ @classmethod
+ def substitute_xml_containing_entities(
+ cls, value, make_quoted_attribute=False):
+ """Substitute XML entities for special XML characters.
+
+ :param value: A string to be substituted. The less-than sign will
+ become &lt;, the greater-than sign will become &gt;, and any
+ ampersands that are not part of an entity definition will
+ become &amp;.
+
+ :param make_quoted_attribute: If True, then the string will be
+ quoted, as befits an attribute value.
+ """
+ # Escape angle brackets, and ampersands that aren't part of
+ # entities.
+ value = cls.BARE_AMPERSAND_OR_BRACKET.sub(
+ cls._substitute_xml_entity, value)
+
+ if make_quoted_attribute:
+ value = cls.quoted_attribute_value(value)
+ return value
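+ # Illustrative sketch (not from upstream) of how the two methods
+ # treat ampersands differently:
+ # EntitySubstitution.substitute_xml("AT&amp;T")
+ # -> "AT&amp;amp;T" (every ampersand is escaped)
+ # EntitySubstitution.substitute_xml_containing_entities("AT&amp;T")
+ # -> "AT&amp;T" (the existing entity is left alone)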
+
+ @classmethod
+ def substitute_html(cls, s):
+ """Replace certain Unicode characters with named HTML entities.
+
+ This differs from data.encode(encoding, 'xmlcharrefreplace')
+ in that the goal is to make the result more readable (to those
+ with ASCII displays) rather than to recover from
+ errors. There's absolutely nothing wrong with a UTF-8 string
+ containing a LATIN SMALL LETTER E WITH ACUTE, but replacing that
+ character with "&eacute;" will make it more readable to some
+ people.
+ """
+ return cls.CHARACTER_TO_HTML_ENTITY_RE.sub(
+ cls._substitute_html_entity, s)
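+ # Illustrative (not from upstream): substitute_html("caf\xe9")
+ # returns "caf&eacute;", since U+00E9 has a named entity in
+ # codepoint2name.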
+
+
+class EncodingDetector:
+ """Suggests a number of possible encodings for a bytestring.
+
+ Order of precedence:
+
+ 1. Encodings you specifically tell EncodingDetector to try first
+ (the override_encodings argument to the constructor).
+
+ 2. An encoding declared within the bytestring itself, either in an
+ XML declaration (if the bytestring is to be interpreted as an XML
+ document), or in a <meta> tag (if the bytestring is to be
+ interpreted as an HTML document.)
+
+ 3. An encoding detected through textual analysis by chardet,
+ cchardet, or a similar external library.
+
+ 4. UTF-8.
+
+ 5. Windows-1252.
+ """
+ def __init__(self, markup, override_encodings=None, is_html=False,
+ exclude_encodings=None):
+ self.override_encodings = override_encodings or []
+ exclude_encodings = exclude_encodings or []
+ self.exclude_encodings = set([x.lower() for x in exclude_encodings])
+ self.chardet_encoding = None
+ self.is_html = is_html
+ self.declared_encoding = None
+
+ # First order of business: strip a byte-order mark.
+ self.markup, self.sniffed_encoding = self.strip_byte_order_mark(markup)
+
+ def _usable(self, encoding, tried):
+ if encoding is not None:
+ encoding = encoding.lower()
+ if encoding in self.exclude_encodings:
+ return False
+ if encoding not in tried:
+ tried.add(encoding)
+ return True
+ return False
+
+ @property
+ def encodings(self):
+ """Yield a number of encodings that might work for this markup."""
+ tried = set()
+ for e in self.override_encodings:
+ if self._usable(e, tried):
+ yield e
+
+ # Did the document originally start with a byte-order mark
+ # that indicated its encoding?
+ if self._usable(self.sniffed_encoding, tried):
+ yield self.sniffed_encoding
+
+ # Look within the document for an XML or HTML encoding
+ # declaration.
+ if self.declared_encoding is None:
+ self.declared_encoding = self.find_declared_encoding(
+ self.markup, self.is_html)
+ if self._usable(self.declared_encoding, tried):
+ yield self.declared_encoding
+
+ # Use third-party character set detection to guess at the
+ # encoding.
+ if self.chardet_encoding is None:
+ self.chardet_encoding = chardet_dammit(self.markup)
+ if self._usable(self.chardet_encoding, tried):
+ yield self.chardet_encoding
+
+ # As a last-ditch effort, try utf-8 and windows-1252.
+ for e in ('utf-8', 'windows-1252'):
+ if self._usable(e, tried):
+ yield e
+
+ @classmethod
+ def strip_byte_order_mark(cls, data):
+ """If a byte-order mark is present, strip it and return the encoding it implies."""
+ encoding = None
+ if isinstance(data, str):
+ # Unicode data cannot have a byte-order mark.
+ return data, encoding
+ if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \
+ and (data[2:4] != b'\x00\x00'):
+ encoding = 'utf-16be'
+ data = data[2:]
+ elif (len(data) >= 4) and (data[:2] == b'\xff\xfe') \
+ and (data[2:4] != b'\x00\x00'):
+ encoding = 'utf-16le'
+ data = data[2:]
+ elif data[:3] == b'\xef\xbb\xbf':
+ encoding = 'utf-8'
+ data = data[3:]
+ elif data[:4] == b'\x00\x00\xfe\xff':
+ encoding = 'utf-32be'
+ data = data[4:]
+ elif data[:4] == b'\xff\xfe\x00\x00':
+ encoding = 'utf-32le'
+ data = data[4:]
+ return data, encoding
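+ # A quick sketch of the behavior (illustrative, not from upstream):
+ # EncodingDetector.strip_byte_order_mark(b'\xef\xbb\xbfhello')
+ # -> (b'hello', 'utf-8')
+ # EncodingDetector.strip_byte_order_mark('hello')
+ # -> ('hello', None) # Unicode data never carries a BOM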
+
+ @classmethod
+ def find_declared_encoding(cls, markup, is_html=False, search_entire_document=False):
+ """Given a document, tries to find its declared encoding.
+
+ An XML encoding is declared at the beginning of the document.
+
+ An HTML encoding is declared in a <meta> tag, hopefully near the
+ beginning of the document.
+ """
+ if search_entire_document:
+ xml_endpos = html_endpos = len(markup)
+ else:
+ xml_endpos = 1024
+ html_endpos = max(2048, int(len(markup) * 0.05))
+
+ declared_encoding = None
+ declared_encoding_match = xml_encoding_re.search(markup, endpos=xml_endpos)
+ if not declared_encoding_match and is_html:
+ declared_encoding_match = html_meta_re.search(markup, endpos=html_endpos)
+ if declared_encoding_match is not None:
+ declared_encoding = declared_encoding_match.groups()[0].decode(
+ 'ascii', 'replace')
+ if declared_encoding:
+ return declared_encoding.lower()
+ return None
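+ # Illustrative (not from upstream):
+ # EncodingDetector.find_declared_encoding(
+ # b'<?xml version="1.0" encoding="ISO-8859-1"?>') -> 'iso-8859-1'
+ # EncodingDetector.find_declared_encoding(
+ # b'<meta charset="utf-8">', is_html=True) -> 'utf-8'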
+
+class UnicodeDammit:
+ """A class for detecting the encoding of a *ML document and
+ converting it to a Unicode string. If the source encoding is
+ windows-1252, can replace MS smart quotes with their HTML or XML
+ equivalents."""
+
+ # This dictionary maps commonly seen values for "charset" in HTML
+ # meta tags to the corresponding Python codec names. It only covers
+ # values that aren't in Python's aliases and can't be determined
+ # by the heuristics in find_codec.
+ CHARSET_ALIASES = {"macintosh": "mac-roman",
+ "x-sjis": "shift-jis"}
+
+ ENCODINGS_WITH_SMART_QUOTES = [
+ "windows-1252",
+ "iso-8859-1",
+ "iso-8859-2",
+ ]
+
+ def __init__(self, markup, override_encodings=[],
+ smart_quotes_to=None, is_html=False, exclude_encodings=[]):
+ self.smart_quotes_to = smart_quotes_to
+ self.tried_encodings = []
+ self.contains_replacement_characters = False
+ self.is_html = is_html
+ self.log = logging.getLogger(__name__)
+ self.detector = EncodingDetector(
+ markup, override_encodings, is_html, exclude_encodings)
+
+ # Short-circuit if the data is in Unicode to begin with.
+ if isinstance(markup, str) or markup == '':
+ self.markup = markup
+ self.unicode_markup = str(markup)
+ self.original_encoding = None
+ return
+
+ # The encoding detector may have stripped a byte-order mark.
+ # Use the stripped markup from this point on.
+ self.markup = self.detector.markup
+
+ u = None
+ for encoding in self.detector.encodings:
+ u = self._convert_from(encoding)
+ if u is not None:
+ break
+
+ if not u:
+ # None of the encodings worked. As an absolute last resort,
+ # try them again with character replacement.
+
+ for encoding in self.detector.encodings:
+ if encoding != "ascii":
+ u = self._convert_from(encoding, "replace")
+ if u is not None:
+ self.log.warning(
+ "Some characters could not be decoded, and were "
+ "replaced with REPLACEMENT CHARACTER."
+ )
+ self.contains_replacement_characters = True
+ break
+
+ # If none of that worked, we could at this point force it to
+ # ASCII, but that would destroy so much data that I think
+ # giving up is better.
+ self.unicode_markup = u
+ if not u:
+ self.original_encoding = None
+
+ def _sub_ms_char(self, match):
+ """Changes a MS smart quote character to an XML or HTML
+ entity, or an ASCII character."""
+ orig = match.group(1)
+ if self.smart_quotes_to == 'ascii':
+ sub = self.MS_CHARS_TO_ASCII.get(orig).encode()
+ else:
+ sub = self.MS_CHARS.get(orig)
+ if type(sub) == tuple:
+ if self.smart_quotes_to == 'xml':
+ sub = '&#x'.encode() + sub[1].encode() + ';'.encode()
+ else:
+ sub = '&'.encode() + sub[0].encode() + ';'.encode()
+ else:
+ sub = sub.encode()
+ return sub
+
+ def _convert_from(self, proposed, errors="strict"):
+ proposed = self.find_codec(proposed)
+ if not proposed or (proposed, errors) in self.tried_encodings:
+ return None
+ self.tried_encodings.append((proposed, errors))
+ markup = self.markup
+ # Convert smart quotes to HTML if coming from an encoding
+ # that might have them.
+ if (self.smart_quotes_to is not None
+ and proposed in self.ENCODINGS_WITH_SMART_QUOTES):
+ smart_quotes_re = b"([\x80-\x9f])"
+ smart_quotes_compiled = re.compile(smart_quotes_re)
+ markup = smart_quotes_compiled.sub(self._sub_ms_char, markup)
+
+ try:
+ #print "Trying to convert document to %s (errors=%s)" % (
+ # proposed, errors)
+ u = self._to_unicode(markup, proposed, errors)
+ self.markup = u
+ self.original_encoding = proposed
+ except Exception as e:
+ #print "That didn't work!"
+ #print e
+ return None
+ #print "Correct encoding: %s" % proposed
+ return self.markup
+
+ def _to_unicode(self, data, encoding, errors="strict"):
+ '''Given a string and its encoding, decodes the string into Unicode.
+ 'encoding' is a string recognized by encodings.aliases.'''
+ return str(data, encoding, errors)
+
+ @property
+ def declared_html_encoding(self):
+ if not self.is_html:
+ return None
+ return self.detector.declared_encoding
+
+ def find_codec(self, charset):
+ value = (self._codec(self.CHARSET_ALIASES.get(charset, charset))
+ or (charset and self._codec(charset.replace("-", "")))
+ or (charset and self._codec(charset.replace("-", "_")))
+ or (charset and charset.lower())
+ or charset
+ )
+ if value:
+ return value.lower()
+ return None
+
+ def _codec(self, charset):
+ if not charset:
+ return charset
+ codec = None
+ try:
+ codecs.lookup(charset)
+ codec = charset
+ except (LookupError, ValueError):
+ pass
+ return codec
+
+
+ # A partial mapping of Windows-1252 code points to HTML entities/XML numeric entities.
+ MS_CHARS = {b'\x80': ('euro', '20AC'),
+ b'\x81': ' ',
+ b'\x82': ('sbquo', '201A'),
+ b'\x83': ('fnof', '192'),
+ b'\x84': ('bdquo', '201E'),
+ b'\x85': ('hellip', '2026'),
+ b'\x86': ('dagger', '2020'),
+ b'\x87': ('Dagger', '2021'),
+ b'\x88': ('circ', '2C6'),
+ b'\x89': ('permil', '2030'),
+ b'\x8A': ('Scaron', '160'),
+ b'\x8B': ('lsaquo', '2039'),
+ b'\x8C': ('OElig', '152'),
+ b'\x8D': '?',
+ b'\x8E': ('#x17D', '17D'),
+ b'\x8F': '?',
+ b'\x90': '?',
+ b'\x91': ('lsquo', '2018'),
+ b'\x92': ('rsquo', '2019'),
+ b'\x93': ('ldquo', '201C'),
+ b'\x94': ('rdquo', '201D'),
+ b'\x95': ('bull', '2022'),
+ b'\x96': ('ndash', '2013'),
+ b'\x97': ('mdash', '2014'),
+ b'\x98': ('tilde', '2DC'),
+ b'\x99': ('trade', '2122'),
+ b'\x9a': ('scaron', '161'),
+ b'\x9b': ('rsaquo', '203A'),
+ b'\x9c': ('oelig', '153'),
+ b'\x9d': '?',
+ b'\x9e': ('#x17E', '17E'),
+ b'\x9f': ('Yuml', '178'),}
+
+ # A parochial partial mapping of Windows-1252/ISO-Latin-1 to
+ # ASCII. Contains horrors like stripping diacritical marks to turn
+ # á into a, but also contains non-horrors like turning “ into ".
+ MS_CHARS_TO_ASCII = {
+ b'\x80' : 'EUR',
+ b'\x81' : ' ',
+ b'\x82' : ',',
+ b'\x83' : 'f',
+ b'\x84' : ',,',
+ b'\x85' : '...',
+ b'\x86' : '+',
+ b'\x87' : '++',
+ b'\x88' : '^',
+ b'\x89' : '%',
+ b'\x8a' : 'S',
+ b'\x8b' : '<',
+ b'\x8c' : 'OE',
+ b'\x8d' : '?',
+ b'\x8e' : 'Z',
+ b'\x8f' : '?',
+ b'\x90' : '?',
+ b'\x91' : "'",
+ b'\x92' : "'",
+ b'\x93' : '"',
+ b'\x94' : '"',
+ b'\x95' : '*',
+ b'\x96' : '-',
+ b'\x97' : '--',
+ b'\x98' : '~',
+ b'\x99' : '(TM)',
+ b'\x9a' : 's',
+ b'\x9b' : '>',
+ b'\x9c' : 'oe',
+ b'\x9d' : '?',
+ b'\x9e' : 'z',
+ b'\x9f' : 'Y',
+ b'\xa0' : ' ',
+ b'\xa1' : '!',
+ b'\xa2' : 'c',
+ b'\xa3' : 'GBP',
+ b'\xa4' : '$', #This approximation is especially parochial--this is the
+ #generic currency symbol.
+ b'\xa5' : 'YEN',
+ b'\xa6' : '|',
+ b'\xa7' : 'S',
+ b'\xa8' : '..',
+ b'\xa9' : '(C)',
+ b'\xaa' : '(th)',
+ b'\xab' : '<<',
+ b'\xac' : '!',
+ b'\xad' : ' ',
+ b'\xae' : '(R)',
+ b'\xaf' : '-',
+ b'\xb0' : 'o',
+ b'\xb1' : '+-',
+ b'\xb2' : '2',
+ b'\xb3' : '3',
+ b'\xb4' : "'",
+ b'\xb5' : 'u',
+ b'\xb6' : 'P',
+ b'\xb7' : '*',
+ b'\xb8' : ',',
+ b'\xb9' : '1',
+ b'\xba' : '(th)',
+ b'\xbb' : '>>',
+ b'\xbc' : '1/4',
+ b'\xbd' : '1/2',
+ b'\xbe' : '3/4',
+ b'\xbf' : '?',
+ b'\xc0' : 'A',
+ b'\xc1' : 'A',
+ b'\xc2' : 'A',
+ b'\xc3' : 'A',
+ b'\xc4' : 'A',
+ b'\xc5' : 'A',
+ b'\xc6' : 'AE',
+ b'\xc7' : 'C',
+ b'\xc8' : 'E',
+ b'\xc9' : 'E',
+ b'\xca' : 'E',
+ b'\xcb' : 'E',
+ b'\xcc' : 'I',
+ b'\xcd' : 'I',
+ b'\xce' : 'I',
+ b'\xcf' : 'I',
+ b'\xd0' : 'D',
+ b'\xd1' : 'N',
+ b'\xd2' : 'O',
+ b'\xd3' : 'O',
+ b'\xd4' : 'O',
+ b'\xd5' : 'O',
+ b'\xd6' : 'O',
+ b'\xd7' : '*',
+ b'\xd8' : 'O',
+ b'\xd9' : 'U',
+ b'\xda' : 'U',
+ b'\xdb' : 'U',
+ b'\xdc' : 'U',
+ b'\xdd' : 'Y',
+ b'\xde' : 'b',
+ b'\xdf' : 'B',
+ b'\xe0' : 'a',
+ b'\xe1' : 'a',
+ b'\xe2' : 'a',
+ b'\xe3' : 'a',
+ b'\xe4' : 'a',
+ b'\xe5' : 'a',
+ b'\xe6' : 'ae',
+ b'\xe7' : 'c',
+ b'\xe8' : 'e',
+ b'\xe9' : 'e',
+ b'\xea' : 'e',
+ b'\xeb' : 'e',
+ b'\xec' : 'i',
+ b'\xed' : 'i',
+ b'\xee' : 'i',
+ b'\xef' : 'i',
+ b'\xf0' : 'o',
+ b'\xf1' : 'n',
+ b'\xf2' : 'o',
+ b'\xf3' : 'o',
+ b'\xf4' : 'o',
+ b'\xf5' : 'o',
+ b'\xf6' : 'o',
+ b'\xf7' : '/',
+ b'\xf8' : 'o',
+ b'\xf9' : 'u',
+ b'\xfa' : 'u',
+ b'\xfb' : 'u',
+ b'\xfc' : 'u',
+ b'\xfd' : 'y',
+ b'\xfe' : 'b',
+ b'\xff' : 'y',
+ }
+
+ # A map used when removing rogue Windows-1252/ISO-8859-1
+ # characters in otherwise UTF-8 documents.
+ #
+ # Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in
+ # Windows-1252.
+ WINDOWS_1252_TO_UTF8 = {
+ 0x80 : b'\xe2\x82\xac', # €
+ 0x82 : b'\xe2\x80\x9a', # ‚
+ 0x83 : b'\xc6\x92', # ƒ
+ 0x84 : b'\xe2\x80\x9e', # „
+ 0x85 : b'\xe2\x80\xa6', # …
+ 0x86 : b'\xe2\x80\xa0', # †
+ 0x87 : b'\xe2\x80\xa1', # ‡
+ 0x88 : b'\xcb\x86', # ˆ
+ 0x89 : b'\xe2\x80\xb0', # ‰
+ 0x8a : b'\xc5\xa0', # Š
+ 0x8b : b'\xe2\x80\xb9', # ‹
+ 0x8c : b'\xc5\x92', # Œ
+ 0x8e : b'\xc5\xbd', # Ž
+ 0x91 : b'\xe2\x80\x98', # ‘
+ 0x92 : b'\xe2\x80\x99', # ’
+ 0x93 : b'\xe2\x80\x9c', # “
+ 0x94 : b'\xe2\x80\x9d', # ”
+ 0x95 : b'\xe2\x80\xa2', # •
+ 0x96 : b'\xe2\x80\x93', # –
+ 0x97 : b'\xe2\x80\x94', # —
+ 0x98 : b'\xcb\x9c', # ˜
+ 0x99 : b'\xe2\x84\xa2', # ™
+ 0x9a : b'\xc5\xa1', # š
+ 0x9b : b'\xe2\x80\xba', # ›
+ 0x9c : b'\xc5\x93', # œ
+ 0x9e : b'\xc5\xbe', # ž
+ 0x9f : b'\xc5\xb8', # Ÿ
+ 0xa0 : b'\xc2\xa0', #  
+ 0xa1 : b'\xc2\xa1', # ¡
+ 0xa2 : b'\xc2\xa2', # ¢
+ 0xa3 : b'\xc2\xa3', # £
+ 0xa4 : b'\xc2\xa4', # ¤
+ 0xa5 : b'\xc2\xa5', # ¥
+ 0xa6 : b'\xc2\xa6', # ¦
+ 0xa7 : b'\xc2\xa7', # §
+ 0xa8 : b'\xc2\xa8', # ¨
+ 0xa9 : b'\xc2\xa9', # ©
+ 0xaa : b'\xc2\xaa', # ª
+ 0xab : b'\xc2\xab', # «
+ 0xac : b'\xc2\xac', # ¬
+ 0xad : b'\xc2\xad', # ­
+ 0xae : b'\xc2\xae', # ®
+ 0xaf : b'\xc2\xaf', # ¯
+ 0xb0 : b'\xc2\xb0', # °
+ 0xb1 : b'\xc2\xb1', # ±
+ 0xb2 : b'\xc2\xb2', # ²
+ 0xb3 : b'\xc2\xb3', # ³
+ 0xb4 : b'\xc2\xb4', # ´
+ 0xb5 : b'\xc2\xb5', # µ
+ 0xb6 : b'\xc2\xb6', # ¶
+ 0xb7 : b'\xc2\xb7', # ·
+ 0xb8 : b'\xc2\xb8', # ¸
+ 0xb9 : b'\xc2\xb9', # ¹
+ 0xba : b'\xc2\xba', # º
+ 0xbb : b'\xc2\xbb', # »
+ 0xbc : b'\xc2\xbc', # ¼
+ 0xbd : b'\xc2\xbd', # ½
+ 0xbe : b'\xc2\xbe', # ¾
+ 0xbf : b'\xc2\xbf', # ¿
+ 0xc0 : b'\xc3\x80', # À
+ 0xc1 : b'\xc3\x81', # Á
+ 0xc2 : b'\xc3\x82', # Â
+ 0xc3 : b'\xc3\x83', # Ã
+ 0xc4 : b'\xc3\x84', # Ä
+ 0xc5 : b'\xc3\x85', # Å
+ 0xc6 : b'\xc3\x86', # Æ
+ 0xc7 : b'\xc3\x87', # Ç
+ 0xc8 : b'\xc3\x88', # È
+ 0xc9 : b'\xc3\x89', # É
+ 0xca : b'\xc3\x8a', # Ê
+ 0xcb : b'\xc3\x8b', # Ë
+ 0xcc : b'\xc3\x8c', # Ì
+ 0xcd : b'\xc3\x8d', # Í
+ 0xce : b'\xc3\x8e', # Î
+ 0xcf : b'\xc3\x8f', # Ï
+ 0xd0 : b'\xc3\x90', # Ð
+ 0xd1 : b'\xc3\x91', # Ñ
+ 0xd2 : b'\xc3\x92', # Ò
+ 0xd3 : b'\xc3\x93', # Ó
+ 0xd4 : b'\xc3\x94', # Ô
+ 0xd5 : b'\xc3\x95', # Õ
+ 0xd6 : b'\xc3\x96', # Ö
+ 0xd7 : b'\xc3\x97', # ×
+ 0xd8 : b'\xc3\x98', # Ø
+ 0xd9 : b'\xc3\x99', # Ù
+ 0xda : b'\xc3\x9a', # Ú
+ 0xdb : b'\xc3\x9b', # Û
+ 0xdc : b'\xc3\x9c', # Ü
+ 0xdd : b'\xc3\x9d', # Ý
+ 0xde : b'\xc3\x9e', # Þ
+ 0xdf : b'\xc3\x9f', # ß
+ 0xe0 : b'\xc3\xa0', # à
+ 0xe1 : b'\xc3\xa1', # á
+ 0xe2 : b'\xc3\xa2', # â
+ 0xe3 : b'\xc3\xa3', # ã
+ 0xe4 : b'\xc3\xa4', # ä
+ 0xe5 : b'\xc3\xa5', # å
+ 0xe6 : b'\xc3\xa6', # æ
+ 0xe7 : b'\xc3\xa7', # ç
+ 0xe8 : b'\xc3\xa8', # è
+ 0xe9 : b'\xc3\xa9', # é
+ 0xea : b'\xc3\xaa', # ê
+ 0xeb : b'\xc3\xab', # ë
+ 0xec : b'\xc3\xac', # ì
+ 0xed : b'\xc3\xad', # í
+ 0xee : b'\xc3\xae', # î
+ 0xef : b'\xc3\xaf', # ï
+ 0xf0 : b'\xc3\xb0', # ð
+ 0xf1 : b'\xc3\xb1', # ñ
+ 0xf2 : b'\xc3\xb2', # ò
+ 0xf3 : b'\xc3\xb3', # ó
+ 0xf4 : b'\xc3\xb4', # ô
+ 0xf5 : b'\xc3\xb5', # õ
+ 0xf6 : b'\xc3\xb6', # ö
+ 0xf7 : b'\xc3\xb7', # ÷
+ 0xf8 : b'\xc3\xb8', # ø
+ 0xf9 : b'\xc3\xb9', # ù
+ 0xfa : b'\xc3\xba', # ú
+ 0xfb : b'\xc3\xbb', # û
+ 0xfc : b'\xc3\xbc', # ü
+ 0xfd : b'\xc3\xbd', # ý
+ 0xfe : b'\xc3\xbe', # þ
+ }
+
+ MULTIBYTE_MARKERS_AND_SIZES = [
+ (0xc2, 0xdf, 2), # 2-byte characters start with a byte C2-DF
+ (0xe0, 0xef, 3), # 3-byte characters start with E0-EF
+ (0xf0, 0xf4, 4), # 4-byte characters start with F0-F4
+ ]
+
+ FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0]
+ LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1]
+
+ @classmethod
+ def detwingle(cls, in_bytes, main_encoding="utf8",
+ embedded_encoding="windows-1252"):
+ """Fix characters from one encoding embedded in some other encoding.
+
+ Currently the only situation supported is Windows-1252 (or its
+ subset ISO-8859-1), embedded in UTF-8.
+
+ The input must be a bytestring. If you've already converted
+ the document to Unicode, you're too late.
+
+ The output is a bytestring in which `embedded_encoding`
+ characters have been converted to their `main_encoding`
+ equivalents.
+ """
+ if embedded_encoding.replace('_', '-').lower() not in (
+ 'windows-1252', 'windows_1252'):
+ raise NotImplementedError(
+ "Windows-1252 and ISO-8859-1 are the only currently supported "
+ "embedded encodings.")
+
+ if main_encoding.lower() not in ('utf8', 'utf-8'):
+ raise NotImplementedError(
+ "UTF-8 is the only currently supported main encoding.")
+
+ byte_chunks = []
+
+ chunk_start = 0
+ pos = 0
+ while pos < len(in_bytes):
+ byte = in_bytes[pos]
+ if not isinstance(byte, int):
+ # Python 2.x
+ byte = ord(byte)
+ if (byte >= cls.FIRST_MULTIBYTE_MARKER
+ and byte <= cls.LAST_MULTIBYTE_MARKER):
+ # This is the start of a UTF-8 multibyte character. Skip
+ # to the end.
+ for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES:
+ if byte >= start and byte <= end:
+ pos += size
+ break
+ elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8:
+ # We found a Windows-1252 character!
+ # Save the string up to this point as a chunk.
+ byte_chunks.append(in_bytes[chunk_start:pos])
+
+ # Now translate the Windows-1252 character into UTF-8
+ # and add it as another, one-byte chunk.
+ byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte])
+ pos += 1
+ chunk_start = pos
+ else:
+ # Go on to the next character.
+ pos += 1
+ if chunk_start == 0:
+ # The string is unchanged.
+ return in_bytes
+ else:
+ # Store the final chunk.
+ byte_chunks.append(in_bytes[chunk_start:])
+ return b''.join(byte_chunks)
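+ # Illustrative (not from upstream), adapted to the use case this
+ # method targets:
+ # utf8 = "\N{SNOWMAN}".encode("utf8")
+ # win = "\N{LEFT DOUBLE QUOTATION MARK}Hi\N{RIGHT DOUBLE QUOTATION MARK}".encode("windows_1252")
+ # UnicodeDammit.detwingle(utf8 + win).decode("utf8") succeeds,
+ # whereas (utf8 + win).decode("utf8") raises UnicodeDecodeError.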
+
diff --git a/libs/bs4/diagnose.py b/libs/bs4/diagnose.py
new file mode 100644
index 000000000..b5f6e6c8b
--- /dev/null
+++ b/libs/bs4/diagnose.py
@@ -0,0 +1,224 @@
+"""Diagnostic functions, mainly for use when doing tech support."""
+
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+import cProfile
+from io import StringIO
+from html.parser import HTMLParser
+import bs4
+from bs4 import BeautifulSoup, __version__
+from bs4.builder import builder_registry
+
+import os
+import pstats
+import random
+import tempfile
+import time
+import traceback
+import sys
+
+def diagnose(data):
+ """Diagnostic suite for isolating common problems."""
+ print("Diagnostic running on Beautiful Soup %s" % __version__)
+ print("Python version %s" % sys.version)
+
+ basic_parsers = ["html.parser", "html5lib", "lxml"]
+ # Iterate over a copy: removing items from a list while iterating
+ # over it can silently skip elements.
+ for name in list(basic_parsers):
+ for builder in builder_registry.builders:
+ if name in builder.features:
+ break
+ else:
+ basic_parsers.remove(name)
+ print((
+ "I noticed that %s is not installed. Installing it may help." %
+ name))
+
+ if 'lxml' in basic_parsers:
+ basic_parsers.append("lxml-xml")
+ try:
+ from lxml import etree
+ print("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION)))
+ except ImportError as e:
+ print (
+ "lxml is not installed or couldn't be imported.")
+
+
+ if 'html5lib' in basic_parsers:
+ try:
+ import html5lib
+ print("Found html5lib version %s" % html5lib.__version__)
+ except ImportError as e:
+ print (
+ "html5lib is not installed or couldn't be imported.")
+
+ if hasattr(data, 'read'):
+ data = data.read()
+ elif data.startswith("http:") or data.startswith("https:"):
+ print('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data)
+ print("You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup.")
+ return
+ else:
+ try:
+ if os.path.exists(data):
+ print('"%s" looks like a filename. Reading data from the file.' % data)
+ with open(data) as fp:
+ data = fp.read()
+ except ValueError:
+ # This can happen on some platforms when the 'filename' is
+ # too long. Assume it's data and not a filename.
+ pass
+ print()
+
+ for parser in basic_parsers:
+ print("Trying to parse your markup with %s" % parser)
+ success = False
+ try:
+ soup = BeautifulSoup(data, features=parser)
+ success = True
+ except Exception as e:
+ print("%s could not parse the markup." % parser)
+ traceback.print_exc()
+ if success:
+ print("Here's what %s did with the markup:" % parser)
+ print(soup.prettify())
+
+ print("-" * 80)
+
+def lxml_trace(data, html=True, **kwargs):
+ """Print out the lxml events that occur during parsing.
+
+ This lets you see how lxml parses a document when no Beautiful
+ Soup code is running.
+ """
+ from lxml import etree
+ for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
+ print(("%s, %4s, %s" % (event, element.tag, element.text)))
+
+class AnnouncingParser(HTMLParser):
+ """Announces HTMLParser parse events, without doing anything else."""
+
+ def _p(self, s):
+ print(s)
+
+ def handle_starttag(self, name, attrs):
+ self._p("%s START" % name)
+
+ def handle_endtag(self, name):
+ self._p("%s END" % name)
+
+ def handle_data(self, data):
+ self._p("%s DATA" % data)
+
+ def handle_charref(self, name):
+ self._p("%s CHARREF" % name)
+
+ def handle_entityref(self, name):
+ self._p("%s ENTITYREF" % name)
+
+ def handle_comment(self, data):
+ self._p("%s COMMENT" % data)
+
+ def handle_decl(self, data):
+ self._p("%s DECL" % data)
+
+ def unknown_decl(self, data):
+ self._p("%s UNKNOWN-DECL" % data)
+
+ def handle_pi(self, data):
+ self._p("%s PI" % data)
+
+def htmlparser_trace(data):
+ """Print out the HTMLParser events that occur during parsing.
+
+ This lets you see how HTMLParser parses a document when no
+ Beautiful Soup code is running.
+ """
+ parser = AnnouncingParser()
+ parser.feed(data)
+
+_vowels = "aeiou"
+_consonants = "bcdfghjklmnpqrstvwxyz"
+
+def rword(length=5):
+ "Generate a random word-like string."
+ s = ''
+ for i in range(length):
+ if i % 2 == 0:
+ t = _consonants
+ else:
+ t = _vowels
+ s += random.choice(t)
+ return s
+
+def rsentence(length=4):
+ "Generate a random sentence-like string."
+ return " ".join(rword(random.randint(4,9)) for i in list(range(length)))
+
+def rdoc(num_elements=1000):
+ """Randomly generate an invalid HTML document."""
+ tag_names = ['p', 'div', 'span', 'i', 'b', 'script', 'table']
+ elements = []
+ for i in range(num_elements):
+ choice = random.randint(0,3)
+ if choice == 0:
+ # New tag.
+ tag_name = random.choice(tag_names)
+ elements.append("<%s>" % tag_name)
+ elif choice == 1:
+ elements.append(rsentence(random.randint(1,4)))
+ elif choice == 2:
+ # Close a tag.
+ tag_name = random.choice(tag_names)
+ elements.append("</%s>" % tag_name)
+ return "<html>" + "\n".join(elements) + "</html>"
+
+def benchmark_parsers(num_elements=100000):
+ """Very basic head-to-head performance benchmark."""
+ print("Comparative parser benchmark on Beautiful Soup %s" % __version__)
+ data = rdoc(num_elements)
+ print("Generated a large invalid HTML document (%d bytes)." % len(data))
+
+ for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
+ success = False
+ try:
+ a = time.time()
+ soup = BeautifulSoup(data, parser)
+ b = time.time()
+ success = True
+ except Exception as e:
+ print("%s could not parse the markup." % parser)
+ traceback.print_exc()
+ if success:
+ print("BS4+%s parsed the markup in %.2fs." % (parser, b-a))
+
+ from lxml import etree
+ a = time.time()
+ etree.HTML(data)
+ b = time.time()
+ print("Raw lxml parsed the markup in %.2fs." % (b-a))
+
+ import html5lib
+ parser = html5lib.HTMLParser()
+ a = time.time()
+ parser.parse(data)
+ b = time.time()
+ print("Raw html5lib parsed the markup in %.2fs." % (b-a))
+
+def profile(num_elements=100000, parser="lxml"):
+
+ filehandle = tempfile.NamedTemporaryFile()
+ filename = filehandle.name
+
+ data = rdoc(num_elements)
+ vars = dict(bs4=bs4, data=data, parser=parser)
+ cProfile.runctx('bs4.BeautifulSoup(data, parser)', vars, vars, filename)
+
+ stats = pstats.Stats(filename)
+ # stats.strip_dirs()
+ stats.sort_stats("cumulative")
+ stats.print_stats('_html5lib|bs4', 50)
+
+if __name__ == '__main__':
+ diagnose(sys.stdin.read())
diff --git a/libs/bs4/element.py b/libs/bs4/element.py
new file mode 100644
index 000000000..f16b1663e
--- /dev/null
+++ b/libs/bs4/element.py
@@ -0,0 +1,1579 @@
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+try:
+ from collections.abc import Callable # Python 3.6
+except ImportError as e:
+ from collections import Callable
+import re
+import sys
+import warnings
+try:
+ import soupsieve
+except ImportError as e:
+ soupsieve = None
+ warnings.warn(
+ 'The soupsieve package is not installed. CSS selectors cannot be used.'
+ )
+
+from bs4.formatter import (
+ Formatter,
+ HTMLFormatter,
+ XMLFormatter,
+)
+
+DEFAULT_OUTPUT_ENCODING = "utf-8"
+PY3K = (sys.version_info[0] > 2)
+
+nonwhitespace_re = re.compile(r"\S+")
+
+# NOTE: This isn't used as of 4.7.0. I'm leaving it for a little bit on
+# the off chance someone imported it for their own use.
+whitespace_re = re.compile(r"\s+")
+
+def _alias(attr):
+ """Alias one attribute name to another for backward compatibility"""
+ @property
+ def alias(self):
+ return getattr(self, attr)
+
+ @alias.setter
+ def alias(self, value):
+ return setattr(self, attr, value)
+ return alias
+
+
+class NamespacedAttribute(str):
+
+ def __new__(cls, prefix, name, namespace=None):
+ if name is None:
+ obj = str.__new__(cls, prefix)
+ elif prefix is None:
+ # Not really namespaced.
+ obj = str.__new__(cls, name)
+ else:
+ obj = str.__new__(cls, prefix + ":" + name)
+ obj.prefix = prefix
+ obj.name = name
+ obj.namespace = namespace
+ return obj
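+ # Illustrative (not from upstream): a NamespacedAttribute is a str
+ # with extra bookkeeping, e.g.
+ # a = NamespacedAttribute('xlink', 'href',
+ # 'http://www.w3.org/1999/xlink')
+ # a == 'xlink:href'; a.prefix == 'xlink'; a.namespace is the URL.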
+
+class AttributeValueWithCharsetSubstitution(str):
+ """A stand-in object for a character encoding specified in HTML."""
+
+class CharsetMetaAttributeValue(AttributeValueWithCharsetSubstitution):
+ """A generic stand-in for the value of a meta tag's 'charset' attribute.
+
+ When Beautiful Soup parses the markup '<meta charset="utf8">', the
+ value of the 'charset' attribute will be one of these objects.
+ """
+
+ def __new__(cls, original_value):
+ obj = str.__new__(cls, original_value)
+ obj.original_value = original_value
+ return obj
+
+ def encode(self, encoding):
+ return encoding
+
+
+class ContentMetaAttributeValue(AttributeValueWithCharsetSubstitution):
+ """A generic stand-in for the value of a meta tag's 'content' attribute.
+
+ When Beautiful Soup parses the markup:
+ <meta http-equiv="content-type" content="text/html; charset=utf8">
+
+ The value of the 'content' attribute will be one of these objects.
+ """
+
+ CHARSET_RE = re.compile(r"((^|;)\s*charset=)([^;]*)", re.M)
+
+ def __new__(cls, original_value):
+ match = cls.CHARSET_RE.search(original_value)
+ if match is None:
+ # No substitution necessary.
+ return str.__new__(str, original_value)
+
+ obj = str.__new__(cls, original_value)
+ obj.original_value = original_value
+ return obj
+
+ def encode(self, encoding):
+ def rewrite(match):
+ return match.group(1) + encoding
+ return self.CHARSET_RE.sub(rewrite, self.original_value)
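+ # Illustrative (not from upstream): when a document is re-encoded,
+ # ContentMetaAttributeValue('text/html; charset=utf8').encode('big5')
+ # returns 'text/html; charset=big5', so the serialized <meta> tag
+ # advertises the encoding actually used for output.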
+
+
+class PageElement(object):
+ """Contains the navigational information for some part of the page
+ (either a tag or a piece of text)"""
+
+ def setup(self, parent=None, previous_element=None, next_element=None,
+ previous_sibling=None, next_sibling=None):
+ """Sets up the initial relations between this element and
+ other elements."""
+ self.parent = parent
+
+ self.previous_element = previous_element
+ if previous_element is not None:
+ self.previous_element.next_element = self
+
+ self.next_element = next_element
+ if self.next_element is not None:
+ self.next_element.previous_element = self
+
+ self.next_sibling = next_sibling
+ if self.next_sibling is not None:
+ self.next_sibling.previous_sibling = self
+
+ if (previous_sibling is None
+ and self.parent is not None and self.parent.contents):
+ previous_sibling = self.parent.contents[-1]
+
+ self.previous_sibling = previous_sibling
+ if previous_sibling is not None:
+ self.previous_sibling.next_sibling = self
+
+ def format_string(self, s, formatter):
+ """Format the given string using the given formatter."""
+ if formatter is None:
+ return s
+ if not isinstance(formatter, Formatter):
+ formatter = self.formatter_for_name(formatter)
+ output = formatter.substitute(s)
+ return output
+
+ def formatter_for_name(self, formatter):
+ """Look up or create a Formatter for the given identifier,
+ if necessary.
+
+ :param formatter: Can be a Formatter object (used as-is), a
+ function (used as the entity substitution hook for an
+ XMLFormatter or HTMLFormatter), or a string (used to look up
+ an XMLFormatter or HTMLFormatter in the appropriate registry).
+ """
+ if isinstance(formatter, Formatter):
+ return formatter
+ if self._is_xml:
+ c = XMLFormatter
+ else:
+ c = HTMLFormatter
+ if callable(formatter):
+ return c(entity_substitution=formatter)
+ return c.REGISTRY[formatter]
+
+ @property
+ def _is_xml(self):
+ """Is this element part of an XML tree or an HTML tree?
+
+ This is used in formatter_for_name, when deciding whether an
+ XMLFormatter or HTMLFormatter is more appropriate. It can be
+ inefficient, but it should be called very rarely.
+ """
+ if self.known_xml is not None:
+ # Most of the time we will have determined this when the
+ # document is parsed.
+ return self.known_xml
+
+ # Otherwise, it's likely that this element was created by
+ # direct invocation of the constructor from within the user's
+ # Python code.
+ if self.parent is None:
+ # This is the top-level object. It should have .known_xml set
+ # from tree creation. If not, take a guess--BS is usually
+ # used on HTML markup.
+ return getattr(self, 'is_xml', False)
+ return self.parent._is_xml
+
+ nextSibling = _alias("next_sibling") # BS3
+ previousSibling = _alias("previous_sibling") # BS3
+
+ def replace_with(self, replace_with):
+ if self.parent is None:
+ raise ValueError(
+ "Cannot replace one element with another when the "
+ "element to be replaced is not part of a tree.")
+ if replace_with is self:
+ return
+ if replace_with is self.parent:
+ raise ValueError("Cannot replace a Tag with its parent.")
+ old_parent = self.parent
+ my_index = self.parent.index(self)
+ self.extract()
+ old_parent.insert(my_index, replace_with)
+ return self
+ replaceWith = replace_with # BS3
+
+ def unwrap(self):
+ my_parent = self.parent
+ if self.parent is None:
+ raise ValueError(
+ "Cannot replace an element with its contents when that"
+ "element is not part of a tree.")
+ my_index = self.parent.index(self)
+ self.extract()
+ for child in reversed(self.contents[:]):
+ my_parent.insert(my_index, child)
+ return self
+ replace_with_children = unwrap
+ replaceWithChildren = unwrap # BS3
+
+ def wrap(self, wrap_inside):
+ me = self.replace_with(wrap_inside)
+ wrap_inside.append(me)
+ return wrap_inside
+
+ def extract(self):
+ """Destructively rips this element out of the tree."""
+ if self.parent is not None:
+ del self.parent.contents[self.parent.index(self)]
+
+ #Find the two elements that would be next to each other if
+ #this element (and any children) hadn't been parsed. Connect
+ #the two.
+ last_child = self._last_descendant()
+ next_element = last_child.next_element
+
+ if (self.previous_element is not None and
+ self.previous_element is not next_element):
+ self.previous_element.next_element = next_element
+ if next_element is not None and next_element is not self.previous_element:
+ next_element.previous_element = self.previous_element
+ self.previous_element = None
+ last_child.next_element = None
+
+ self.parent = None
+ if (self.previous_sibling is not None
+ and self.previous_sibling is not self.next_sibling):
+ self.previous_sibling.next_sibling = self.next_sibling
+ if (self.next_sibling is not None
+ and self.next_sibling is not self.previous_sibling):
+ self.next_sibling.previous_sibling = self.previous_sibling
+ self.previous_sibling = self.next_sibling = None
+ return self
+
+ def _last_descendant(self, is_initialized=True, accept_self=True):
+ "Finds the last element beneath this object to be parsed."
+ if is_initialized and self.next_sibling is not None:
+ last_child = self.next_sibling.previous_element
+ else:
+ last_child = self
+ while isinstance(last_child, Tag) and last_child.contents:
+ last_child = last_child.contents[-1]
+ if not accept_self and last_child is self:
+ last_child = None
+ return last_child
+ # BS3: Not part of the API!
+ _lastRecursiveChild = _last_descendant
+
+ def insert(self, position, new_child):
+ if new_child is None:
+ raise ValueError("Cannot insert None into a tag.")
+ if new_child is self:
+ raise ValueError("Cannot insert a tag into itself.")
+ if (isinstance(new_child, str)
+ and not isinstance(new_child, NavigableString)):
+ new_child = NavigableString(new_child)
+
+ from bs4 import BeautifulSoup
+ if isinstance(new_child, BeautifulSoup):
+ # We don't want to end up with a situation where one BeautifulSoup
+ # object contains another. Insert the children one at a time.
+ for subchild in list(new_child.contents):
+ self.insert(position, subchild)
+ position += 1
+ return
+ position = min(position, len(self.contents))
+ if hasattr(new_child, 'parent') and new_child.parent is not None:
+ # We're 'inserting' an element that's already one
+ # of this object's children.
+ if new_child.parent is self:
+ current_index = self.index(new_child)
+ if current_index < position:
+ # We're moving this element further down the list
+ # of this object's children. That means that when
+ # we extract this element, our target index will
+ # jump down one.
+ position -= 1
+ new_child.extract()
+
+ new_child.parent = self
+ previous_child = None
+ if position == 0:
+ new_child.previous_sibling = None
+ new_child.previous_element = self
+ else:
+ previous_child = self.contents[position - 1]
+ new_child.previous_sibling = previous_child
+ new_child.previous_sibling.next_sibling = new_child
+ new_child.previous_element = previous_child._last_descendant(False)
+ if new_child.previous_element is not None:
+ new_child.previous_element.next_element = new_child
+
+ new_childs_last_element = new_child._last_descendant(False)
+
+ if position >= len(self.contents):
+ new_child.next_sibling = None
+
+ parent = self
+ parents_next_sibling = None
+ while parents_next_sibling is None and parent is not None:
+ parents_next_sibling = parent.next_sibling
+ parent = parent.parent
+ if parents_next_sibling is not None:
+ # We found the element that comes next in the document.
+ break
+ if parents_next_sibling is not None:
+ new_childs_last_element.next_element = parents_next_sibling
+ else:
+ # The last element of this tag is the last element in
+ # the document.
+ new_childs_last_element.next_element = None
+ else:
+ next_child = self.contents[position]
+ new_child.next_sibling = next_child
+ if new_child.next_sibling is not None:
+ new_child.next_sibling.previous_sibling = new_child
+ new_childs_last_element.next_element = next_child
+
+ if new_childs_last_element.next_element is not None:
+ new_childs_last_element.next_element.previous_element = new_childs_last_element
+ self.contents.insert(position, new_child)
+
+ def append(self, tag):
+ """Appends the given tag to the contents of this tag."""
+ self.insert(len(self.contents), tag)
+
+ def extend(self, tags):
+ """Appends the given tags to the contents of this tag."""
+ for tag in tags:
+ self.append(tag)
+
+ def insert_before(self, *args):
+ """Makes the given element(s) the immediate predecessor of this one.
+
+ The elements will have the same parent, and the given elements
+ will be immediately before this one.
+ """
+ parent = self.parent
+ if parent is None:
+ raise ValueError(
+ "Element has no parent, so 'before' has no meaning.")
+ if any(x is self for x in args):
+ raise ValueError("Can't insert an element before itself.")
+ for predecessor in args:
+ # Extract first so that the index won't be screwed up if they
+ # are siblings.
+ if isinstance(predecessor, PageElement):
+ predecessor.extract()
+ index = parent.index(self)
+ parent.insert(index, predecessor)
+
+ def insert_after(self, *args):
+ """Makes the given element(s) the immediate successor of this one.
+
+ The elements will have the same parent, and the given elements
+ will be immediately after this one.
+ """
+ # Do all error checking before modifying the tree.
+ parent = self.parent
+ if parent is None:
+ raise ValueError(
+ "Element has no parent, so 'after' has no meaning.")
+ if any(x is self for x in args):
+ raise ValueError("Can't insert an element after itself.")
+
+ offset = 0
+ for successor in args:
+ # Extract first so that the index won't be screwed up if they
+ # are siblings.
+ if isinstance(successor, PageElement):
+ successor.extract()
+ index = parent.index(self)
+ parent.insert(index+1+offset, successor)
+ offset += 1
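+
+    # A minimal usage sketch (hypothetical markup): insert_before() and
+    # insert_after() splice elements in as siblings of this one, extracting
+    # them from any previous position first.
+    #
+    #   soup = BeautifulSoup("<p><b>one</b></p>", "html.parser")
+    #   soup.b.insert_after(" two")
+    #   str(soup.p)  # '<p><b>one</b> two</p>'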
+
+ def find_next(self, name=None, attrs={}, text=None, **kwargs):
+ """Returns the first item that matches the given criteria and
+ appears after this Tag in the document."""
+ return self._find_one(self.find_all_next, name, attrs, text, **kwargs)
+ findNext = find_next # BS3
+
+ def find_all_next(self, name=None, attrs={}, text=None, limit=None,
+ **kwargs):
+ """Returns all items that match the given criteria and appear
+ after this Tag in the document."""
+ return self._find_all(name, attrs, text, limit, self.next_elements,
+ **kwargs)
+ findAllNext = find_all_next # BS3
+
+ def find_next_sibling(self, name=None, attrs={}, text=None, **kwargs):
+ """Returns the closest sibling to this Tag that matches the
+ given criteria and appears after this Tag in the document."""
+ return self._find_one(self.find_next_siblings, name, attrs, text,
+ **kwargs)
+ findNextSibling = find_next_sibling # BS3
+
+ def find_next_siblings(self, name=None, attrs={}, text=None, limit=None,
+ **kwargs):
+ """Returns the siblings of this Tag that match the given
+ criteria and appear after this Tag in the document."""
+ return self._find_all(name, attrs, text, limit,
+ self.next_siblings, **kwargs)
+ findNextSiblings = find_next_siblings # BS3
+ fetchNextSiblings = find_next_siblings # BS2
+
+ def find_previous(self, name=None, attrs={}, text=None, **kwargs):
+ """Returns the first item that matches the given criteria and
+ appears before this Tag in the document."""
+ return self._find_one(
+ self.find_all_previous, name, attrs, text, **kwargs)
+ findPrevious = find_previous # BS3
+
+ def find_all_previous(self, name=None, attrs={}, text=None, limit=None,
+ **kwargs):
+ """Returns all items that match the given criteria and appear
+ before this Tag in the document."""
+ return self._find_all(name, attrs, text, limit, self.previous_elements,
+ **kwargs)
+ findAllPrevious = find_all_previous # BS3
+ fetchPrevious = find_all_previous # BS2
+
+ def find_previous_sibling(self, name=None, attrs={}, text=None, **kwargs):
+ """Returns the closest sibling to this Tag that matches the
+ given criteria and appears before this Tag in the document."""
+ return self._find_one(self.find_previous_siblings, name, attrs, text,
+ **kwargs)
+ findPreviousSibling = find_previous_sibling # BS3
+
+ def find_previous_siblings(self, name=None, attrs={}, text=None,
+ limit=None, **kwargs):
+ """Returns the siblings of this Tag that match the given
+ criteria and appear before this Tag in the document."""
+ return self._find_all(name, attrs, text, limit,
+ self.previous_siblings, **kwargs)
+ findPreviousSiblings = find_previous_siblings # BS3
+ fetchPreviousSiblings = find_previous_siblings # BS2
+
+ def find_parent(self, name=None, attrs={}, **kwargs):
+ """Returns the closest parent of this Tag that matches the given
+ criteria."""
+ # NOTE: We can't use _find_one because findParents takes a different
+ # set of arguments.
+ r = None
+ l = self.find_parents(name, attrs, 1, **kwargs)
+ if l:
+ r = l[0]
+ return r
+ findParent = find_parent # BS3
+
+ def find_parents(self, name=None, attrs={}, limit=None, **kwargs):
+ """Returns the parents of this Tag that match the given
+ criteria."""
+
+ return self._find_all(name, attrs, None, limit, self.parents,
+ **kwargs)
+ findParents = find_parents # BS3
+ fetchParents = find_parents # BS2
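+
+    # A minimal usage sketch (hypothetical markup): the directional find_*
+    # methods above walk one of the generators defined below and filter the
+    # results with _find_all().
+    #
+    #   soup = BeautifulSoup("<a>1</a><b>2</b><c>3</c>", "html.parser")
+    #   soup.a.find_next("c").string          # '3'
+    #   soup.a.find_next_sibling("b").string  # '2'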
+
+ @property
+ def next(self):
+ return self.next_element
+
+ @property
+ def previous(self):
+ return self.previous_element
+
+ #These methods do the real heavy lifting.
+
+ def _find_one(self, method, name, attrs, text, **kwargs):
+ r = None
+ l = method(name, attrs, text, 1, **kwargs)
+ if l:
+ r = l[0]
+ return r
+
+ def _find_all(self, name, attrs, text, limit, generator, **kwargs):
+ "Iterates over a generator looking for things that match."
+
+ if text is None and 'string' in kwargs:
+ text = kwargs['string']
+ del kwargs['string']
+
+ if isinstance(name, SoupStrainer):
+ strainer = name
+ else:
+ strainer = SoupStrainer(name, attrs, text, **kwargs)
+
+ if text is None and not limit and not attrs and not kwargs:
+ if name is True or name is None:
+ # Optimization to find all tags.
+ result = (element for element in generator
+ if isinstance(element, Tag))
+ return ResultSet(strainer, result)
+ elif isinstance(name, str):
+ # Optimization to find all tags with a given name.
+ if name.count(':') == 1:
+ # This is a name with a prefix. If this is a namespace-aware document,
+ # we need to match the local name against tag.name. If not,
+ # we need to match the fully-qualified name against tag.name.
+ prefix, local_name = name.split(':', 1)
+ else:
+ prefix = None
+ local_name = name
+                result = (element for element in generator
+                          if isinstance(element, Tag)
+                          and (
+                              element.name == name
+                              or (
+                                  element.name == local_name
+                                  and (prefix is None
+                                       or element.prefix == prefix)
+                              )
+                          )
+                )
+ return ResultSet(strainer, result)
+ results = ResultSet(strainer)
+ while True:
+ try:
+ i = next(generator)
+ except StopIteration:
+ break
+ if i:
+ found = strainer.search(i)
+ if found:
+ results.append(found)
+ if limit and len(results) >= limit:
+ break
+ return results
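+
+    # A minimal usage sketch (hypothetical markup): _find_all() backs
+    # find_all() and friends; 'string' is accepted as an alias for the
+    # older 'text' argument.
+    #
+    #   soup = BeautifulSoup("<b>x</b><b>y</b>", "html.parser")
+    #   [t.string for t in soup.find_all("b")]  # ['x', 'y']
+    #   soup.find_all(string="x")               # ['x']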
+
+ #These generators can be used to navigate starting from both
+ #NavigableStrings and Tags.
+ @property
+ def next_elements(self):
+ i = self.next_element
+ while i is not None:
+ yield i
+ i = i.next_element
+
+ @property
+ def next_siblings(self):
+ i = self.next_sibling
+ while i is not None:
+ yield i
+ i = i.next_sibling
+
+ @property
+ def previous_elements(self):
+ i = self.previous_element
+ while i is not None:
+ yield i
+ i = i.previous_element
+
+ @property
+ def previous_siblings(self):
+ i = self.previous_sibling
+ while i is not None:
+ yield i
+ i = i.previous_sibling
+
+ @property
+ def parents(self):
+ i = self.parent
+ while i is not None:
+ yield i
+ i = i.parent
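+
+    # A minimal usage sketch (hypothetical markup): each property above is a
+    # lazy generator that walks the tree in one direction.
+    #
+    #   soup = BeautifulSoup("<html><body><p>hi</p></body></html>",
+    #                        "html.parser")
+    #   [t.name for t in soup.p.parents]  # ['body', 'html', '[document]']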
+
+ # Old non-property versions of the generators, for backwards
+ # compatibility with BS3.
+ def nextGenerator(self):
+ return self.next_elements
+
+ def nextSiblingGenerator(self):
+ return self.next_siblings
+
+ def previousGenerator(self):
+ return self.previous_elements
+
+ def previousSiblingGenerator(self):
+ return self.previous_siblings
+
+ def parentGenerator(self):
+ return self.parents
+
+
+class NavigableString(str, PageElement):
+
+ PREFIX = ''
+ SUFFIX = ''
+
+ # We can't tell just by looking at a string whether it's contained
+ # in an XML document or an HTML document.
+
+ known_xml = None
+
+ def __new__(cls, value):
+ """Create a new NavigableString.
+
+ When unpickling a NavigableString, this method is called with
+ the string in DEFAULT_OUTPUT_ENCODING. That encoding needs to be
+ passed in to the superclass's __new__ or the superclass won't know
+ how to handle non-ASCII characters.
+ """
+ if isinstance(value, str):
+ u = str.__new__(cls, value)
+ else:
+ u = str.__new__(cls, value, DEFAULT_OUTPUT_ENCODING)
+ u.setup()
+ return u
+
+ def __copy__(self):
+ """A copy of a NavigableString has the same contents and class
+ as the original, but it is not connected to the parse tree.
+ """
+ return type(self)(self)
+
+ def __getnewargs__(self):
+ return (str(self),)
+
+ def __getattr__(self, attr):
+ """text.string gives you text. This is for backwards
+ compatibility for Navigable*String, but for CData* it lets you
+ get the string without the CData wrapper."""
+ if attr == 'string':
+ return self
+ else:
+ raise AttributeError(
+ "'%s' object has no attribute '%s'" % (
+ self.__class__.__name__, attr))
+
+ def output_ready(self, formatter="minimal"):
+ """Run the string through the provided formatter."""
+ output = self.format_string(self, formatter)
+ return self.PREFIX + output + self.SUFFIX
+
+ @property
+ def name(self):
+ return None
+
+ @name.setter
+ def name(self, name):
+ raise AttributeError("A NavigableString cannot be given a name.")
+
+class PreformattedString(NavigableString):
+ """A NavigableString not subject to the normal formatting rules.
+
+ The string will be passed into the formatter (to trigger side effects),
+ but the return value will be ignored.
+ """
+
+ def output_ready(self, formatter=None):
+ """CData strings are passed into the formatter, purely
+ for any side effects. The return value is ignored.
+ """
+ if formatter is not None:
+ ignore = self.format_string(self, formatter)
+ return self.PREFIX + self + self.SUFFIX
+
+class CData(PreformattedString):
+
+ PREFIX = '<![CDATA['
+ SUFFIX = ']]>'
+
+class ProcessingInstruction(PreformattedString):
+ """A SGML processing instruction."""
+
+ PREFIX = '<?'
+ SUFFIX = '>'
+
+class XMLProcessingInstruction(ProcessingInstruction):
+ """An XML processing instruction."""
+ PREFIX = '<?'
+ SUFFIX = '?>'
+
+class Comment(PreformattedString):
+
+ PREFIX = '<!--'
+ SUFFIX = '-->'
+
+
+class Declaration(PreformattedString):
+ PREFIX = '<?'
+ SUFFIX = '?>'
+
+
+class Doctype(PreformattedString):
+
+ @classmethod
+ def for_name_and_ids(cls, name, pub_id, system_id):
+ value = name or ''
+        if pub_id is not None:
+            value += ' PUBLIC "%s"' % pub_id
+            if system_id is not None:
+                value += ' "%s"' % system_id
+        elif system_id is not None:
+            value += ' SYSTEM "%s"' % system_id
+
+ return Doctype(value)
+
+ PREFIX = '<!DOCTYPE '
+ SUFFIX = '>\n'
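+
+# A minimal usage sketch: Doctype.for_name_and_ids("html", None, None)
+# produces the Doctype "html", which renders as '<!DOCTYPE html>\n'.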
+
+
+class Tag(PageElement):
+
+ """Represents a found HTML tag with its attributes and contents."""
+
+ def __init__(self, parser=None, builder=None, name=None, namespace=None,
+ prefix=None, attrs=None, parent=None, previous=None,
+ is_xml=None):
+ "Basic constructor."
+
+ if parser is None:
+ self.parser_class = None
+ else:
+ # We don't actually store the parser object: that lets extracted
+ # chunks be garbage-collected.
+ self.parser_class = parser.__class__
+ if name is None:
+ raise ValueError("No value provided for new tag's name.")
+ self.name = name
+ self.namespace = namespace
+ self.prefix = prefix
+ if attrs is None:
+ attrs = {}
+ elif attrs:
+ if builder is not None and builder.cdata_list_attributes:
+ attrs = builder._replace_cdata_list_attribute_values(
+ self.name, attrs)
+ else:
+ attrs = dict(attrs)
+ else:
+ attrs = dict(attrs)
+
+ # If possible, determine ahead of time whether this tag is an
+ # XML tag.
+ if builder:
+ self.known_xml = builder.is_xml
+ else:
+ self.known_xml = is_xml
+ self.attrs = attrs
+ self.contents = []
+ self.setup(parent, previous)
+ self.hidden = False
+
+ if builder is None:
+ # In the absence of a TreeBuilder, assume this tag is nothing
+ # special.
+ self.can_be_empty_element = False
+ self.cdata_list_attributes = None
+ else:
+ # Set up any substitutions for this tag, such as the charset in a META tag.
+ builder.set_up_substitutions(self)
+
+ # Ask the TreeBuilder whether this tag might be an empty-element tag.
+ self.can_be_empty_element = builder.can_be_empty_element(name)
+
+ # Keep track of the list of attributes of this tag that
+ # might need to be treated as a list.
+ #
+ # For performance reasons, we store the whole data structure
+ # rather than asking the question of every tag. Asking would
+ # require building a new data structure every time, and
+ # (unlike can_be_empty_element), we almost never need
+ # to check this.
+ self.cdata_list_attributes = builder.cdata_list_attributes
+
+ # Keep track of the names that might cause this tag to be treated as a
+ # whitespace-preserved tag.
+ self.preserve_whitespace_tags = builder.preserve_whitespace_tags
+
+ parserClass = _alias("parser_class") # BS3
+
+ def __copy__(self):
+ """A copy of a Tag is a new Tag, unconnected to the parse tree.
+ Its contents are a copy of the old Tag's contents.
+ """
+ clone = type(self)(None, self.builder, self.name, self.namespace,
+ self.prefix, self.attrs, is_xml=self._is_xml)
+ for attr in ('can_be_empty_element', 'hidden'):
+ setattr(clone, attr, getattr(self, attr))
+ for child in self.contents:
+ clone.append(child.__copy__())
+ return clone
+
+ @property
+ def is_empty_element(self):
+ """Is this tag an empty-element tag? (aka a self-closing tag)
+
+ A tag that has contents is never an empty-element tag.
+
+ A tag that has no contents may or may not be an empty-element
+ tag. It depends on the builder used to create the tag. If the
+ builder has a designated list of empty-element tags, then only
+ a tag whose name shows up in that list is considered an
+ empty-element tag.
+
+ If the builder has no designated list of empty-element tags,
+ then any tag with no contents is an empty-element tag.
+ """
+ return len(self.contents) == 0 and self.can_be_empty_element
+ isSelfClosing = is_empty_element # BS3
+
+ @property
+ def string(self):
+ """Convenience property to get the single string within this tag.
+
+ :Return: If this tag has a single string child, return value
+ is that string. If this tag has no children, or more than one
+ child, return value is None. If this tag has one child tag,
+ return value is the 'string' attribute of the child tag,
+ recursively.
+ """
+ if len(self.contents) != 1:
+ return None
+ child = self.contents[0]
+ if isinstance(child, NavigableString):
+ return child
+ return child.string
+
+ @string.setter
+ def string(self, string):
+ self.clear()
+ self.append(string.__class__(string))
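+
+    # A minimal usage sketch (hypothetical markup):
+    #
+    #   soup = BeautifulSoup("<b>old</b>", "html.parser")
+    #   soup.b.string          # 'old'
+    #   soup.b.string = "new"  # replaces all children with a single string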
+
+ def _all_strings(self, strip=False, types=(NavigableString, CData)):
+ """Yield all strings of certain classes, possibly stripping them.
+
+ By default, yields only NavigableString and CData objects. So
+ no comments, processing instructions, etc.
+ """
+ for descendant in self.descendants:
+ if (
+ (types is None and not isinstance(descendant, NavigableString))
+ or
+ (types is not None and type(descendant) not in types)):
+ continue
+ if strip:
+ descendant = descendant.strip()
+ if len(descendant) == 0:
+ continue
+ yield descendant
+
+ strings = property(_all_strings)
+
+ @property
+ def stripped_strings(self):
+ for string in self._all_strings(True):
+ yield string
+
+ def get_text(self, separator="", strip=False,
+ types=(NavigableString, CData)):
+ """
+ Get all child strings, concatenated using the given separator.
+ """
+ return separator.join([s for s in self._all_strings(
+ strip, types=types)])
+ getText = get_text
+ text = property(get_text)
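+
+    # A minimal usage sketch (hypothetical markup):
+    #
+    #   soup = BeautifulSoup("<p>Hello <b>world</b></p>", "html.parser")
+    #   soup.p.get_text(" ", strip=True)  # 'Hello world'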
+
+ def decompose(self):
+ """Recursively destroys the contents of this tree."""
+ self.extract()
+ i = self
+ while i is not None:
+ next = i.next_element
+ i.__dict__.clear()
+ i.contents = []
+ i = next
+
+ def clear(self, decompose=False):
+ """
+ Extract all children. If decompose is True, decompose instead.
+ """
+ if decompose:
+ for element in self.contents[:]:
+ if isinstance(element, Tag):
+ element.decompose()
+ else:
+ element.extract()
+ else:
+ for element in self.contents[:]:
+ element.extract()
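+
+    # A minimal usage sketch (hypothetical markup): clear() empties a tag in
+    # place; decompose() removes the tag itself and frees its subtree.
+    #
+    #   soup = BeautifulSoup("<div><p>x</p></div>", "html.parser")
+    #   soup.p.decompose()
+    #   str(soup)  # '<div></div>'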
+
+ def smooth(self):
+ """Smooth out this element's children by consolidating consecutive strings.
+
+ This makes pretty-printed output look more natural following a
+ lot of operations that modified the tree.
+ """
+ # Mark the first position of every pair of children that need
+ # to be consolidated. Do this rather than making a copy of
+ # self.contents, since in most cases very few strings will be
+ # affected.
+ marked = []
+ for i, a in enumerate(self.contents):
+ if isinstance(a, Tag):
+ # Recursively smooth children.
+ a.smooth()
+ if i == len(self.contents)-1:
+                # This is the last item in .contents; there is no
+                # following sibling to consolidate it with.
+ continue
+ b = self.contents[i+1]
+ if (isinstance(a, NavigableString)
+ and isinstance(b, NavigableString)
+ and not isinstance(a, PreformattedString)
+ and not isinstance(b, PreformattedString)
+ ):
+ marked.append(i)
+
+ # Go over the marked positions in reverse order, so that
+ # removing items from .contents won't affect the remaining
+ # positions.
+ for i in reversed(marked):
+ a = self.contents[i]
+ b = self.contents[i+1]
+ b.extract()
+ n = NavigableString(a+b)
+ a.replace_with(n)
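+
+    # A minimal usage sketch (hypothetical markup): tree surgery can leave
+    # adjacent strings behind, which smooth() consolidates.
+    #
+    #   soup = BeautifulSoup("<p>a</p>", "html.parser")
+    #   soup.p.append("b")    # contents are now ['a', 'b']
+    #   soup.p.smooth()       # contents consolidated to ['ab']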
+
+ def index(self, element):
+ """
+ Find the index of a child by identity, not value. Avoids issues with
+ tag.contents.index(element) getting the index of equal elements.
+ """
+ for i, child in enumerate(self.contents):
+ if child is element:
+ return i
+ raise ValueError("Tag.index: element not in tag")
+
+ def get(self, key, default=None):
+ """Returns the value of the 'key' attribute for the tag, or
+ the value given for 'default' if it doesn't have that
+ attribute."""
+ return self.attrs.get(key, default)
+
+ def get_attribute_list(self, key, default=None):
+ """The same as get(), but always returns a list."""
+ value = self.get(key, default)
+ if not isinstance(value, list):
+ value = [value]
+ return value
+
+ def has_attr(self, key):
+ return key in self.attrs
+
+ def __hash__(self):
+ return str(self).__hash__()
+
+ def __getitem__(self, key):
+ """tag[key] returns the value of the 'key' attribute for the tag,
+ and throws an exception if it's not there."""
+ return self.attrs[key]
+
+ def __iter__(self):
+ "Iterating over a tag iterates over its contents."
+ return iter(self.contents)
+
+ def __len__(self):
+ "The length of a tag is the length of its list of contents."
+ return len(self.contents)
+
+ def __contains__(self, x):
+ return x in self.contents
+
+ def __bool__(self):
+ "A tag is non-None even if it has no contents."
+ return True
+
+ def __setitem__(self, key, value):
+ """Setting tag[key] sets the value of the 'key' attribute for the
+ tag."""
+ self.attrs[key] = value
+
+ def __delitem__(self, key):
+ "Deleting tag[key] deletes all 'key' attributes for the tag."
+ self.attrs.pop(key, None)
+
+ def __call__(self, *args, **kwargs):
+ """Calling a tag like a function is the same as calling its
+        find_all() method. E.g. tag('a') returns a list of all the A tags
+ found within this tag."""
+ return self.find_all(*args, **kwargs)
+
+ def __getattr__(self, tag):
+ #print "Getattr %s.%s" % (self.__class__, tag)
+ if len(tag) > 3 and tag.endswith('Tag'):
+ # BS3: soup.aTag -> "soup.find("a")
+ tag_name = tag[:-3]
+ warnings.warn(
+ '.%(name)sTag is deprecated, use .find("%(name)s") instead. If you really were looking for a tag called %(name)sTag, use .find("%(name)sTag")' % dict(
+ name=tag_name
+ )
+ )
+ return self.find(tag_name)
+ # We special case contents to avoid recursion.
+ elif not tag.startswith("__") and not tag == "contents":
+ return self.find(tag)
+ raise AttributeError(
+ "'%s' object has no attribute '%s'" % (self.__class__, tag))
+
+ def __eq__(self, other):
+ """Returns true iff this tag has the same name, the same attributes,
+ and the same contents (recursively) as the given tag."""
+ if self is other:
+ return True
+ if (not hasattr(other, 'name') or
+ not hasattr(other, 'attrs') or
+ not hasattr(other, 'contents') or
+ self.name != other.name or
+ self.attrs != other.attrs or
+ len(self) != len(other)):
+ return False
+ for i, my_child in enumerate(self.contents):
+ if my_child != other.contents[i]:
+ return False
+ return True
+
+ def __ne__(self, other):
+ """Returns true iff this tag is not identical to the other tag,
+ as defined in __eq__."""
+ return not self == other
+
+ def __repr__(self, encoding="unicode-escape"):
+ """Renders this tag as a string."""
+ if PY3K:
+ # "The return value must be a string object", i.e. Unicode
+ return self.decode()
+ else:
+ # "The return value must be a string object", i.e. a bytestring.
+ # By convention, the return value of __repr__ should also be
+ # an ASCII string.
+ return self.encode(encoding)
+
+ def __unicode__(self):
+ return self.decode()
+
+ def __str__(self):
+ if PY3K:
+ return self.decode()
+ else:
+ return self.encode()
+
+ if PY3K:
+ __str__ = __repr__ = __unicode__
+
+ def encode(self, encoding=DEFAULT_OUTPUT_ENCODING,
+ indent_level=None, formatter="minimal",
+ errors="xmlcharrefreplace"):
+ # Turn the data structure into Unicode, then encode the
+ # Unicode.
+ u = self.decode(indent_level, encoding, formatter)
+ return u.encode(encoding, errors)
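+
+    # A minimal usage sketch (hypothetical markup):
+    #
+    #   soup = BeautifulSoup("<b>caf\u00e9</b>", "html.parser")
+    #   soup.b.encode("utf-8")  # b'<b>caf\xc3\xa9</b>'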
+
+ def decode(self, indent_level=None,
+ eventual_encoding=DEFAULT_OUTPUT_ENCODING,
+ formatter="minimal"):
+ """Returns a Unicode representation of this tag and its contents.
+
+ :param eventual_encoding: The tag is destined to be
+ encoded into this encoding. This method is _not_
+ responsible for performing that encoding. This information
+ is passed in so that it can be substituted in if the
+ document contains a <META> tag that mentions the document's
+ encoding.
+ """
+
+ # First off, turn a non-Formatter `formatter` into a Formatter
+ # object. This will stop the lookup from happening over and
+ # over again.
+ if not isinstance(formatter, Formatter):
+ formatter = self.formatter_for_name(formatter)
+ attributes = formatter.attributes(self)
+ attrs = []
+ for key, val in attributes:
+ if val is None:
+ decoded = key
+ else:
+ if isinstance(val, list) or isinstance(val, tuple):
+ val = ' '.join(val)
+ elif not isinstance(val, str):
+ val = str(val)
+ elif (
+ isinstance(val, AttributeValueWithCharsetSubstitution)
+ and eventual_encoding is not None
+ ):
+ val = val.encode(eventual_encoding)
+
+ text = formatter.attribute_value(val)
+ decoded = (
+ str(key) + '='
+ + formatter.quoted_attribute_value(text))
+ attrs.append(decoded)
+ close = ''
+ closeTag = ''
+
+ prefix = ''
+ if self.prefix:
+ prefix = self.prefix + ":"
+
+ if self.is_empty_element:
+ close = formatter.void_element_close_prefix or ''
+ else:
+ closeTag = '</%s%s>' % (prefix, self.name)
+
+ pretty_print = self._should_pretty_print(indent_level)
+ space = ''
+ indent_space = ''
+ if indent_level is not None:
+ indent_space = (' ' * (indent_level - 1))
+ if pretty_print:
+ space = indent_space
+ indent_contents = indent_level + 1
+ else:
+ indent_contents = None
+ contents = self.decode_contents(
+ indent_contents, eventual_encoding, formatter
+ )
+
+ if self.hidden:
+ # This is the 'document root' object.
+ s = contents
+ else:
+ s = []
+ attribute_string = ''
+ if attrs:
+ attribute_string = ' ' + ' '.join(attrs)
+ if indent_level is not None:
+ # Even if this particular tag is not pretty-printed,
+ # we should indent up to the start of the tag.
+ s.append(indent_space)
+ s.append('<%s%s%s%s>' % (
+ prefix, self.name, attribute_string, close))
+ if pretty_print:
+ s.append("\n")
+ s.append(contents)
+ if pretty_print and contents and contents[-1] != "\n":
+ s.append("\n")
+ if pretty_print and closeTag:
+ s.append(space)
+ s.append(closeTag)
+ if indent_level is not None and closeTag and self.next_sibling:
+ # Even if this particular tag is not pretty-printed,
+ # we're now done with the tag, and we should add a
+ # newline if appropriate.
+ s.append("\n")
+ s = ''.join(s)
+ return s
+
+ def _should_pretty_print(self, indent_level):
+ """Should this tag be pretty-printed?"""
+ return (
+ indent_level is not None
+ and self.name not in self.preserve_whitespace_tags
+ )
+
+ def prettify(self, encoding=None, formatter="minimal"):
+ if encoding is None:
+ return self.decode(True, formatter=formatter)
+ else:
+ return self.encode(encoding, True, formatter=formatter)
+
+ def decode_contents(self, indent_level=None,
+ eventual_encoding=DEFAULT_OUTPUT_ENCODING,
+ formatter="minimal"):
+ """Renders the contents of this tag as a Unicode string.
+
+ :param indent_level: Each line of the rendering will be
+ indented this many spaces.
+
+ :param eventual_encoding: The tag is destined to be
+ encoded into this encoding. decode_contents() is _not_
+ responsible for performing that encoding. This information
+ is passed in so that it can be substituted in if the
+ document contains a <META> tag that mentions the document's
+ encoding.
+
+ :param formatter: A Formatter object, or a string naming one of
+ the standard Formatters.
+ """
+ # First off, turn a string formatter into a Formatter object. This
+ # will stop the lookup from happening over and over again.
+ if not isinstance(formatter, Formatter):
+ formatter = self.formatter_for_name(formatter)
+
+ pretty_print = (indent_level is not None)
+ s = []
+ for c in self:
+ text = None
+ if isinstance(c, NavigableString):
+ text = c.output_ready(formatter)
+ elif isinstance(c, Tag):
+ s.append(c.decode(indent_level, eventual_encoding,
+ formatter))
+ preserve_whitespace = (
+ self.preserve_whitespace_tags and self.name in self.preserve_whitespace_tags
+ )
+ if text and indent_level and not preserve_whitespace:
+ text = text.strip()
+ if text:
+ if pretty_print and not preserve_whitespace:
+ s.append(" " * (indent_level - 1))
+ s.append(text)
+ if pretty_print and not preserve_whitespace:
+ s.append("\n")
+ return ''.join(s)
+
+ def encode_contents(
+ self, indent_level=None, encoding=DEFAULT_OUTPUT_ENCODING,
+ formatter="minimal"):
+ """Renders the contents of this tag as a bytestring.
+
+ :param indent_level: Each line of the rendering will be
+ indented this many spaces.
+
+        :param encoding: The bytestring will be in this encoding.
+
+ :param formatter: The output formatter responsible for converting
+ entities to Unicode characters.
+ """
+
+ contents = self.decode_contents(indent_level, encoding, formatter)
+ return contents.encode(encoding)
+
+ # Old method for BS3 compatibility
+ def renderContents(self, encoding=DEFAULT_OUTPUT_ENCODING,
+ prettyPrint=False, indentLevel=0):
+ if not prettyPrint:
+ indentLevel = None
+ return self.encode_contents(
+ indent_level=indentLevel, encoding=encoding)
+
+ #Soup methods
+
+ def find(self, name=None, attrs={}, recursive=True, text=None,
+ **kwargs):
+ """Return only the first child of this Tag matching the given
+ criteria."""
+ r = None
+ l = self.find_all(name, attrs, recursive, text, 1, **kwargs)
+ if l:
+ r = l[0]
+ return r
+ findChild = find
+
+ def find_all(self, name=None, attrs={}, recursive=True, text=None,
+ limit=None, **kwargs):
+ """Extracts a list of Tag objects that match the given
+ criteria. You can specify the name of the Tag and any
+ attributes you want the Tag to have.
+
+ The value of a key-value pair in the 'attrs' map can be a
+ string, a list of strings, a regular expression object, or a
+ callable that takes a string and returns whether or not the
+ string matches for some custom definition of 'matches'. The
+ same is true of the tag name."""
+
+ generator = self.descendants
+ if not recursive:
+ generator = self.children
+ return self._find_all(name, attrs, text, limit, generator, **kwargs)
+ findAll = find_all # BS3
+ findChildren = find_all # BS2
+
+ #Generator methods
+ @property
+ def children(self):
+ # return iter() to make the purpose of the method clear
+ return iter(self.contents) # XXX This seems to be untested.
+
+ @property
+ def descendants(self):
+ if not len(self.contents):
+ return
+ stopNode = self._last_descendant().next_element
+ current = self.contents[0]
+ while current is not stopNode:
+ yield current
+ current = current.next_element
+
+ # CSS selector code
+ def select_one(self, selector, namespaces=None, **kwargs):
+ """Perform a CSS selection operation on the current element."""
+ value = self.select(selector, namespaces, 1, **kwargs)
+ if value:
+ return value[0]
+ return None
+
+ def select(self, selector, namespaces=None, limit=None, **kwargs):
+ """Perform a CSS selection operation on the current element.
+
+ This uses the SoupSieve library.
+
+ :param selector: A string containing a CSS selector.
+
+ :param namespaces: A dictionary mapping namespace prefixes
+ used in the CSS selector to namespace URIs. By default,
+ Beautiful Soup will use the prefixes it encountered while
+ parsing the document.
+
+ :param limit: After finding this number of results, stop looking.
+
+ :param kwargs: Any extra arguments you'd like to pass in to
+ soupsieve.select().
+ """
+ if namespaces is None:
+ namespaces = self._namespaces
+
+ if limit is None:
+ limit = 0
+ if soupsieve is None:
+ raise NotImplementedError(
+ "Cannot execute CSS selectors because the soupsieve package is not installed."
+ )
+
+ return soupsieve.select(selector, self, namespaces, limit, **kwargs)
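+
+    # A minimal usage sketch (requires the soupsieve package; markup is
+    # hypothetical):
+    #
+    #   soup = BeautifulSoup('<div><p class="x">hi</p></div>', "html.parser")
+    #   soup.select("div > p.x")       # [<p class="x">hi</p>]
+    #   soup.select_one("p.x").string  # 'hi'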
+
+ # Old names for backwards compatibility
+ def childGenerator(self):
+ return self.children
+
+ def recursiveChildGenerator(self):
+ return self.descendants
+
+ def has_key(self, key):
+ """This was kind of misleading because has_key() (attributes)
+        was different from __contains__ (contents). has_key() is gone in
+ Python 3, anyway."""
+ warnings.warn('has_key is deprecated. Use has_attr("%s") instead.' % (
+ key))
+ return self.has_attr(key)
+
+# Next, a couple classes to represent queries and their results.
+class SoupStrainer(object):
+ """Encapsulates a number of ways of matching a markup element (tag or
+ text)."""
+
+ def __init__(self, name=None, attrs={}, text=None, **kwargs):
+ self.name = self._normalize_search_value(name)
+ if not isinstance(attrs, dict):
+ # Treat a non-dict value for attrs as a search for the 'class'
+ # attribute.
+ kwargs['class'] = attrs
+ attrs = None
+
+ if 'class_' in kwargs:
+ # Treat class_="foo" as a search for the 'class'
+ # attribute, overriding any non-dict value for attrs.
+ kwargs['class'] = kwargs['class_']
+ del kwargs['class_']
+
+ if kwargs:
+ if attrs:
+ attrs = attrs.copy()
+ attrs.update(kwargs)
+ else:
+ attrs = kwargs
+ normalized_attrs = {}
+ for key, value in list(attrs.items()):
+ normalized_attrs[key] = self._normalize_search_value(value)
+
+ self.attrs = normalized_attrs
+ self.text = self._normalize_search_value(text)
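+
+    # A minimal usage sketch (hypothetical markup): a SoupStrainer can be
+    # passed as the 'name' argument to the find_* methods, or to the
+    # BeautifulSoup constructor as parse_only to restrict what gets parsed.
+    #
+    #   only_links = SoupStrainer("a", href=True)
+    #   BeautifulSoup('<p><a href="/x">x</a></p>', "html.parser",
+    #                 parse_only=only_links)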
+
+ def _normalize_search_value(self, value):
+ # Leave it alone if it's a Unicode string, a callable, a
+ # regular expression, a boolean, or None.
+ if (isinstance(value, str) or isinstance(value, Callable) or hasattr(value, 'match')
+ or isinstance(value, bool) or value is None):
+ return value
+
+ # If it's a bytestring, convert it to Unicode, treating it as UTF-8.
+ if isinstance(value, bytes):
+ return value.decode("utf8")
+
+ # If it's listlike, convert it into a list of strings.
+ if hasattr(value, '__iter__'):
+ new_value = []
+ for v in value:
+ if (hasattr(v, '__iter__') and not isinstance(v, bytes)
+ and not isinstance(v, str)):
+ # This is almost certainly the user's mistake. In the
+ # interests of avoiding infinite loops, we'll let
+ # it through as-is rather than doing a recursive call.
+ new_value.append(v)
+ else:
+ new_value.append(self._normalize_search_value(v))
+ return new_value
+
+ # Otherwise, convert it into a Unicode string.
+ # The unicode(str()) thing is so this will do the same thing on Python 2
+ # and Python 3.
+ return str(str(value))
+
+ def __str__(self):
+ if self.text:
+ return self.text
+ else:
+ return "%s|%s" % (self.name, self.attrs)
+
+ def search_tag(self, markup_name=None, markup_attrs={}):
+ found = None
+ markup = None
+ if isinstance(markup_name, Tag):
+ markup = markup_name
+ markup_attrs = markup
+ call_function_with_tag_data = (
+ isinstance(self.name, Callable)
+ and not isinstance(markup_name, Tag))
+
+ if ((not self.name)
+ or call_function_with_tag_data
+ or (markup and self._matches(markup, self.name))
+ or (not markup and self._matches(markup_name, self.name))):
+ if call_function_with_tag_data:
+ match = self.name(markup_name, markup_attrs)
+ else:
+ match = True
+ markup_attr_map = None
+ for attr, match_against in list(self.attrs.items()):
+ if not markup_attr_map:
+ if hasattr(markup_attrs, 'get'):
+ markup_attr_map = markup_attrs
+ else:
+ markup_attr_map = {}
+ for k, v in markup_attrs:
+ markup_attr_map[k] = v
+ attr_value = markup_attr_map.get(attr)
+ if not self._matches(attr_value, match_against):
+ match = False
+ break
+ if match:
+ if markup:
+ found = markup
+ else:
+ found = markup_name
+ if found and self.text and not self._matches(found.string, self.text):
+ found = None
+ return found
+ searchTag = search_tag
+
+ def search(self, markup):
+ # print 'looking for %s in %s' % (self, markup)
+ found = None
+ # If given a list of items, scan it for a text element that
+ # matches.
+ if hasattr(markup, '__iter__') and not isinstance(markup, (Tag, str)):
+ for element in markup:
+ if isinstance(element, NavigableString) \
+ and self.search(element):
+ found = element
+ break
+ # If it's a Tag, make sure its name or attributes match.
+ # Don't bother with Tags if we're searching for text.
+ elif isinstance(markup, Tag):
+ if not self.text or self.name or self.attrs:
+ found = self.search_tag(markup)
+ # If it's text, make sure the text matches.
+ elif isinstance(markup, NavigableString) or \
+ isinstance(markup, str):
+ if not self.name and not self.attrs and self._matches(markup, self.text):
+ found = markup
+ else:
+ raise Exception(
+ "I don't know how to match against a %s" % markup.__class__)
+ return found
+
+ def _matches(self, markup, match_against, already_tried=None):
+ # print u"Matching %s against %s" % (markup, match_against)
+ result = False
+ if isinstance(markup, list) or isinstance(markup, tuple):
+ # This should only happen when searching a multi-valued attribute
+ # like 'class'.
+ for item in markup:
+ if self._matches(item, match_against):
+ return True
+ # We didn't match any particular value of the multivalue
+ # attribute, but maybe we match the attribute value when
+ # considered as a string.
+ if self._matches(' '.join(markup), match_against):
+ return True
+ return False
+
+ if match_against is True:
+ # True matches any non-None value.
+ return markup is not None
+
+ if isinstance(match_against, Callable):
+ return match_against(markup)
+
+ # Custom callables take the tag as an argument, but all
+ # other ways of matching match the tag name as a string.
+ original_markup = markup
+ if isinstance(markup, Tag):
+ markup = markup.name
+
+ # Ensure that `markup` is either a Unicode string, or None.
+ markup = self._normalize_search_value(markup)
+
+ if markup is None:
+ # None matches None, False, an empty string, an empty list, and so on.
+ return not match_against
+
+ if (hasattr(match_against, '__iter__')
+ and not isinstance(match_against, str)):
+ # We're asked to match against an iterable of items.
+            # The markup must match at least one item in the
+ # iterable. We'll try each one in turn.
+ #
+ # To avoid infinite recursion we need to keep track of
+ # items we've already seen.
+ if not already_tried:
+ already_tried = set()
+ for item in match_against:
+ if item.__hash__:
+ key = item
+ else:
+ key = id(item)
+ if key in already_tried:
+ continue
+ else:
+ already_tried.add(key)
+ if self._matches(original_markup, item, already_tried):
+ return True
+ else:
+ return False
+
+ # Beyond this point we might need to run the test twice: once against
+ # the tag's name and once against its prefixed name.
+ match = False
+
+ if not match and isinstance(match_against, str):
+ # Exact string match
+ match = markup == match_against
+
+ if not match and hasattr(match_against, 'search'):
+ # Regexp match
+ return match_against.search(markup)
+
+ if (not match
+ and isinstance(original_markup, Tag)
+ and original_markup.prefix):
+ # Try the whole thing again with the prefixed tag name.
+ return self._matches(
+ original_markup.prefix + ':' + original_markup.name, match_against
+ )
+
+ return match
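+
+    # Matching semantics, summarized: match_against may be True (matches any
+    # non-None markup), a callable, a string (exact match), an object with a
+    # .search() method such as a compiled regular expression, or an iterable
+    # of any of these.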
+
+
+class ResultSet(list):
+ """A ResultSet is just a list that keeps track of the SoupStrainer
+ that created it."""
+ def __init__(self, source, result=()):
+ super(ResultSet, self).__init__(result)
+ self.source = source
+
+ def __getattr__(self, key):
+ raise AttributeError(
+ "ResultSet object has no attribute '%s'. You're probably treating a list of items like a single item. Did you call find_all() when you meant to call find()?" % key
+ )
diff --git a/libs/bs4/formatter.py b/libs/bs4/formatter.py
new file mode 100644
index 000000000..7dbaa3850
--- /dev/null
+++ b/libs/bs4/formatter.py
@@ -0,0 +1,99 @@
+from bs4.dammit import EntitySubstitution
+
+class Formatter(EntitySubstitution):
+ """Describes a strategy to use when outputting a parse tree to a string.
+
+ Some parts of this strategy come from the distinction between
+ HTML4, HTML5, and XML. Others are configurable by the user.
+ """
+ # Registries of XML and HTML formatters.
+ XML_FORMATTERS = {}
+ HTML_FORMATTERS = {}
+
+ HTML = 'html'
+ XML = 'xml'
+
+ HTML_DEFAULTS = dict(
+ cdata_containing_tags=set(["script", "style"]),
+ )
+
+ def _default(self, language, value, kwarg):
+ if value is not None:
+ return value
+ if language == self.XML:
+ return set()
+ return self.HTML_DEFAULTS[kwarg]
+
+ def __init__(
+ self, language=None, entity_substitution=None,
+ void_element_close_prefix='/', cdata_containing_tags=None,
+ ):
+ """
+
+ :param void_element_close_prefix: By default, represent void
+ elements as <tag/> rather than <tag>
+ """
+ self.language = language
+ self.entity_substitution = entity_substitution
+ self.void_element_close_prefix = void_element_close_prefix
+ self.cdata_containing_tags = self._default(
+ language, cdata_containing_tags, 'cdata_containing_tags'
+ )
+
+ def substitute(self, ns):
+ """Process a string that needs to undergo entity substitution."""
+ if not self.entity_substitution:
+ return ns
+ from .element import NavigableString
+ if (isinstance(ns, NavigableString)
+ and ns.parent is not None
+ and ns.parent.name in self.cdata_containing_tags):
+ # Do nothing.
+ return ns
+ # Substitute.
+ return self.entity_substitution(ns)
+
+ def attribute_value(self, value):
+ """Process the value of an attribute."""
+ return self.substitute(value)
+
+ def attributes(self, tag):
+ """Reorder a tag's attributes however you want."""
+ return sorted(tag.attrs.items())
+
+
+class HTMLFormatter(Formatter):
+ REGISTRY = {}
+ def __init__(self, *args, **kwargs):
+        super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs)
+
+
+class XMLFormatter(Formatter):
+ REGISTRY = {}
+ def __init__(self, *args, **kwargs):
+        super(XMLFormatter, self).__init__(self.XML, *args, **kwargs)
+
+
+# Set up aliases for the default formatters.
+HTMLFormatter.REGISTRY['html'] = HTMLFormatter(
+ entity_substitution=EntitySubstitution.substitute_html
+)
+HTMLFormatter.REGISTRY["html5"] = HTMLFormatter(
+ entity_substitution=EntitySubstitution.substitute_html,
+ void_element_close_prefix = None
+)
+HTMLFormatter.REGISTRY["minimal"] = HTMLFormatter(
+ entity_substitution=EntitySubstitution.substitute_xml
+)
+HTMLFormatter.REGISTRY[None] = HTMLFormatter(
+ entity_substitution=None
+)
+XMLFormatter.REGISTRY["html"] = XMLFormatter(
+ entity_substitution=EntitySubstitution.substitute_html
+)
+XMLFormatter.REGISTRY["minimal"] = XMLFormatter(
+ entity_substitution=EntitySubstitution.substitute_xml
+)
+XMLFormatter.REGISTRY[None] = Formatter(
+    Formatter.XML, entity_substitution=None
+)
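+
+# A minimal usage sketch: Tag.decode(formatter=...) resolves string names
+# against these registries, so formatter="html5" renders void elements as
+# <br> (void_element_close_prefix=None), while the default "minimal" renders
+# them as <br/> and escapes only bare &, <, and >.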
diff --git a/libs/bs4/testing.py b/libs/bs4/testing.py
new file mode 100644
index 000000000..cc9966601
--- /dev/null
+++ b/libs/bs4/testing.py
@@ -0,0 +1,992 @@
+# encoding: utf-8
+"""Helper classes for tests."""
+
+# Use of this source code is governed by the MIT license.
+__license__ = "MIT"
+
+import pickle
+import copy
+import functools
+import unittest
+from unittest import TestCase
+from bs4 import BeautifulSoup
+from bs4.element import (
+ CharsetMetaAttributeValue,
+ Comment,
+ ContentMetaAttributeValue,
+ Doctype,
+ SoupStrainer,
+ Tag
+)
+
+from bs4.builder import HTMLParserTreeBuilder
+default_builder = HTMLParserTreeBuilder
+
+BAD_DOCUMENT = """A bare string
+<!DOCTYPE xsl:stylesheet SYSTEM "htmlent.dtd">
+<!DOCTYPE xsl:stylesheet PUBLIC "htmlent.dtd">
+<div><![CDATA[A CDATA section where it doesn't belong]]></div>
+<div><svg><![CDATA[HTML5 does allow CDATA sections in SVG]]></svg></div>
+<div>A <meta> tag</div>
+<div>A <br> tag that supposedly has contents.</br></div>
+<div>AT&T</div>
+<div><textarea>Within a textarea, markup like <b> tags and <&<&amp; should be treated as literal</textarea></div>
+<div><script>if (i < 2) { alert("<b>Markup within script tags should be treated as literal.</b>"); }</script></div>
+<div>This numeric entity is missing the final semicolon: <x t="pi&#241ata"></div>
+<div><a href="http://example.com/</a> that attribute value never got closed</div>
+<div><a href="foo</a>, </a><a href="bar">that attribute value was closed by the subsequent tag</a></div>
+<! This document starts with a bogus declaration ><div>a</div>
+<div>This document contains <!an incomplete declaration <div>(do you see it?)</div>
+<div>This document ends with <!an incomplete declaration
+<div><a style={height:21px;}>That attribute value was bogus</a></div>
+<! DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">The doctype is invalid because it contains extra whitespace
+<div><table><td nowrap>That boolean attribute had no value</td></table></div>
+<div>Here's a nonexistent entity: &#foo; (do you see it?)</div>
+<div>This document ends before the entity finishes: &gt
+<div><p>Paragraphs shouldn't contain block display elements, but this one does: <dl><dt>you see?</dt></p>
+<b b="20" a="1" b="10" a="2" a="3" a="4">Multiple values for the same attribute.</b>
+<div><table><tr><td>Here's a table</td></tr></table></div>
+<div><table id="1"><tr><td>Here's a nested table:<table id="2"><tr><td>foo</td></tr></table></td></div>
+<div>This tag contains nothing but whitespace: <b> </b></div>
+<div><blockquote><p><b>This p tag is cut off by</blockquote></p>the end of the blockquote tag</div>
+<div><table><div>This table contains bare markup</div></table></div>
+<div><div id="1">\n <a href="link1">This link is never closed.\n</div>\n<div id="2">\n <div id="3">\n <a href="link2">This link is closed.</a>\n </div>\n</div></div>
+<div>This document contains a <!DOCTYPE surprise>surprise doctype</div>
+<div><a><B><Cd><EFG>Mixed case tags are folded to lowercase</efg></CD></b></A></div>
+<div><our\u2603>Tag name contains Unicode characters</our\u2603></div>
+<div><a \u2603="snowman">Attribute name contains Unicode characters</a></div>
+<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
+"""
+
+
+class SoupTest(unittest.TestCase):
+
+ @property
+ def default_builder(self):
+ return default_builder
+
+ def soup(self, markup, **kwargs):
+ """Build a Beautiful Soup object from markup."""
+ builder = kwargs.pop('builder', self.default_builder)
+ return BeautifulSoup(markup, builder=builder, **kwargs)
+
+ def document_for(self, markup, **kwargs):
+ """Turn an HTML fragment into a document.
+
+ The details depend on the builder.
+ """
+ return self.default_builder(**kwargs).test_fragment_to_document(markup)
+
+ def assertSoupEquals(self, to_parse, compare_parsed_to=None):
+ builder = self.default_builder
+ obj = BeautifulSoup(to_parse, builder=builder)
+ if compare_parsed_to is None:
+ compare_parsed_to = to_parse
+
+ self.assertEqual(obj.decode(), self.document_for(compare_parsed_to))
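+
+    # A sketch of intended use (hypothetical subclass): parser-specific test
+    # suites subclass SoupTest, override default_builder, and inherit these
+    # helpers.
+    #
+    #   class MyParserTest(SoupTest):
+    #       def test_roundtrip(self):
+    #           self.assertSoupEquals("<b>x</b>")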
+
+ def assertConnectedness(self, element):
+ """Ensure that next_element and previous_element are properly
+ set for all descendants of the given element.
+ """
+ earlier = None
+ for e in element.descendants:
+ if earlier:
+ self.assertEqual(e, earlier.next_element)
+ self.assertEqual(earlier, e.previous_element)
+ earlier = e
+
+ def linkage_validator(self, el, _recursive_call=False):
+ """Ensure proper linkage throughout the document."""
+ descendant = None
+ # Document element should have no previous element or previous sibling.
+ # It also shouldn't have a next sibling.
+ if el.parent is None:
+ assert el.previous_element is None,\
+ "Bad previous_element\nNODE: {}\nPREV: {}\nEXPECTED: {}".format(
+ el, el.previous_element, None
+ )
+ assert el.previous_sibling is None,\
+ "Bad previous_sibling\nNODE: {}\nPREV: {}\nEXPECTED: {}".format(
+ el, el.previous_sibling, None
+ )
+ assert el.next_sibling is None,\
+ "Bad next_sibling\nNODE: {}\nNEXT: {}\nEXPECTED: {}".format(
+ el, el.next_sibling, None
+ )
+
+ idx = 0
+ child = None
+ last_child = None
+ last_idx = len(el.contents) - 1
+ for child in el.contents:
+ descendant = None
+
+            # The parent's next_element should link to its first child,
+            # and that child should have no previous sibling.
+ if idx == 0:
+ if el.parent is not None:
+ assert el.next_element is child,\
+ "Bad next_element\nNODE: {}\nNEXT: {}\nEXPECTED: {}".format(
+ el, el.next_element, child
+ )
+ assert child.previous_element is el,\
+ "Bad previous_element\nNODE: {}\nPREV: {}\nEXPECTED: {}".format(
+ child, child.previous_element, el
+ )
+ assert child.previous_sibling is None,\
+ "Bad previous_sibling\nNODE: {}\nPREV {}\nEXPECTED: {}".format(
+ child, child.previous_sibling, None
+ )
+
+            # If this is not the first child, the previous child should link
+            # to it as a sibling, and its previous element should be the
+            # previous child or that child's last bubbled-up descendant.
+ else:
+ assert child.previous_sibling is el.contents[idx - 1],\
+ "Bad previous_sibling\nNODE: {}\nPREV {}\nEXPECTED {}".format(
+ child, child.previous_sibling, el.contents[idx - 1]
+ )
+ assert el.contents[idx - 1].next_sibling is child,\
+ "Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
+ el.contents[idx - 1], el.contents[idx - 1].next_sibling, child
+ )
+
+ if last_child is not None:
+ assert child.previous_element is last_child,\
+ "Bad previous_element\nNODE: {}\nPREV {}\nEXPECTED {}\nCONTENTS {}".format(
+ child, child.previous_element, last_child, child.parent.contents
+ )
+ assert last_child.next_element is child,\
+ "Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
+ last_child, last_child.next_element, child
+ )
+
+ if isinstance(child, Tag) and child.contents:
+ descendant = self.linkage_validator(child, True)
+ # A bubbled up descendant should have no next siblings
+ assert descendant.next_sibling is None,\
+ "Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
+ descendant, descendant.next_sibling, None
+ )
+
+ # Mark last child as either the bubbled up descendant or the current child
+ if descendant is not None:
+ last_child = descendant
+ else:
+ last_child = child
+
+            # If this is the last child, there are no next siblings
+ if idx == last_idx:
+ assert child.next_sibling is None,\
+ "Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
+ child, child.next_sibling, None
+ )
+ idx += 1
+
+ child = descendant if descendant is not None else child
+ if child is None:
+ child = el
+
+ if not _recursive_call and child is not None:
+ target = el
+ while True:
+ if target is None:
+ assert child.next_element is None, \
+ "Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
+ child, child.next_element, None
+ )
+ break
+ elif target.next_sibling is not None:
+ assert child.next_element is target.next_sibling, \
+ "Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format(
+ child, child.next_element, target.next_sibling
+ )
+ break
+ target = target.parent
+
+ # We are done, so nothing to return
+ return None
+ else:
+ # Return the child to the recursive caller
+ return child
+
+
+class HTMLTreeBuilderSmokeTest(object):
+
+ """A basic test of a treebuilder's competence.
+
+ Any HTML treebuilder, present or future, should be able to pass
+ these tests. With invalid markup, there's room for interpretation,
+ and different parsers can handle it differently. But with the
+ markup in these tests, there's not much room for interpretation.
+ """
+
+ def test_empty_element_tags(self):
+ """Verify that all HTML4 and HTML5 empty element (aka void element) tags
+ are handled correctly.
+ """
+ for name in [
+ 'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr',
+ 'spacer', 'frame'
+ ]:
+ soup = self.soup("")
+ new_tag = soup.new_tag(name)
+ self.assertEqual(True, new_tag.is_empty_element)
+
+ def test_pickle_and_unpickle_identity(self):
+ # Pickling a tree, then unpickling it, yields a tree identical
+ # to the original.
+ tree = self.soup("<a><b>foo</a>")
+ dumped = pickle.dumps(tree, 2)
+ loaded = pickle.loads(dumped)
+ self.assertEqual(loaded.__class__, BeautifulSoup)
+ self.assertEqual(loaded.decode(), tree.decode())
+
+ def assertDoctypeHandled(self, doctype_fragment):
+ """Assert that a given doctype string is handled correctly."""
+ doctype_str, soup = self._document_with_doctype(doctype_fragment)
+
+ # Make sure a Doctype object was created.
+ doctype = soup.contents[0]
+ self.assertEqual(doctype.__class__, Doctype)
+ self.assertEqual(doctype, doctype_fragment)
+ self.assertEqual(str(soup)[:len(doctype_str)], doctype_str)
+
+ # Make sure that the doctype was correctly associated with the
+ # parse tree and that the rest of the document parsed.
+ self.assertEqual(soup.p.contents[0], 'foo')
+
+ def _document_with_doctype(self, doctype_fragment):
+ """Generate and parse a document with the given doctype."""
+ doctype = '<!DOCTYPE %s>' % doctype_fragment
+ markup = doctype + '\n<p>foo</p>'
+ soup = self.soup(markup)
+ return doctype, soup
+
+ def test_normal_doctypes(self):
+ """Make sure normal, everyday HTML doctypes are handled correctly."""
+ self.assertDoctypeHandled("html")
+ self.assertDoctypeHandled(
+ 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"')
+
+ def test_empty_doctype(self):
+ soup = self.soup("<!DOCTYPE>")
+ doctype = soup.contents[0]
+ self.assertEqual("", doctype.strip())
+
+ def test_public_doctype_with_url(self):
+ doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"'
+ self.assertDoctypeHandled(doctype)
+
+ def test_system_doctype(self):
+ self.assertDoctypeHandled('foo SYSTEM "http://www.example.com/"')
+
+ def test_namespaced_system_doctype(self):
+ # We can handle a namespaced doctype with a system ID.
+ self.assertDoctypeHandled('xsl:stylesheet SYSTEM "htmlent.dtd"')
+
+ def test_namespaced_public_doctype(self):
+ # Test a namespaced doctype with a public id.
+ self.assertDoctypeHandled('xsl:stylesheet PUBLIC "htmlent.dtd"')
+
+ def test_real_xhtml_document(self):
+ """A real XHTML document should come out more or less the same as it went in."""
+ markup = b"""<?xml version="1.0" encoding="utf-8"?>
+<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
+<html xmlns="http://www.w3.org/1999/xhtml">
+<head><title>Hello.</title></head>
+<body>Goodbye.</body>
+</html>"""
+ soup = self.soup(markup)
+ self.assertEqual(
+ soup.encode("utf-8").replace(b"\n", b""),
+ markup.replace(b"\n", b""))
+
+ def test_namespaced_html(self):
+ """When a namespaced XML document is parsed as HTML it should
+ be treated as HTML with weird tag names.
+ """
+ markup = b"""<ns1:foo>content</ns1:foo><ns1:foo/><ns2:foo/>"""
+ soup = self.soup(markup)
+ self.assertEqual(2, len(soup.find_all("ns1:foo")))
+
+ def test_processing_instruction(self):
+ # We test both Unicode and bytestring to verify that
+ # process_markup correctly sets processing_instruction_class
+ # even when the markup is already Unicode and there is no
+ # need to process anything.
+ markup = """<?PITarget PIContent?>"""
+ soup = self.soup(markup)
+ self.assertEqual(markup, soup.decode())
+
+ markup = b"""<?PITarget PIContent?>"""
+ soup = self.soup(markup)
+ self.assertEqual(markup, soup.encode("utf8"))
+
+ def test_deepcopy(self):
+ """Make sure you can copy the tree builder.
+
+ This is important because the builder is part of a
+ BeautifulSoup object, and we want to be able to copy that.
+ """
+ copy.deepcopy(self.default_builder)
+
+ def test_p_tag_is_never_empty_element(self):
+ """A <p> tag is never designated as an empty-element tag.
+
+ Even if the markup shows it as an empty-element tag, it
+ shouldn't be presented that way.
+ """
+ soup = self.soup("<p/>")
+ self.assertFalse(soup.p.is_empty_element)
+ self.assertEqual(str(soup.p), "<p></p>")
+
+ def test_unclosed_tags_get_closed(self):
+ """A tag that's not closed by the end of the document should be closed.
+
+ This applies to all tags except empty-element tags.
+ """
+ self.assertSoupEquals("<p>", "<p></p>")
+ self.assertSoupEquals("<b>", "<b></b>")
+
+ self.assertSoupEquals("<br>", "<br/>")
+
+ def test_br_is_always_empty_element_tag(self):
+ """A <br> tag is designated as an empty-element tag.
+
+ Some parsers treat <br></br> as one <br/> tag, some parsers as
+ two tags, but it should always be an empty-element tag.
+ """
+ soup = self.soup("<br></br>")
+ self.assertTrue(soup.br.is_empty_element)
+ self.assertEqual(str(soup.br), "<br/>")
+
+ def test_nested_formatting_elements(self):
+ self.assertSoupEquals("<em><em></em></em>")
+
+ def test_double_head(self):
+ html = '''<!DOCTYPE html>
+<html>
+<head>
+<title>Ordinary HEAD element test</title>
+</head>
+<script type="text/javascript">
+alert("Help!");
+</script>
+<body>
+Hello, world!
+</body>
+</html>
+'''
+ soup = self.soup(html)
+ self.assertEqual("text/javascript", soup.find('script')['type'])
+
+ def test_comment(self):
+ # Comments are represented as Comment objects.
+ markup = "<p>foo<!--foobar-->baz</p>"
+ self.assertSoupEquals(markup)
+
+ soup = self.soup(markup)
+ comment = soup.find(text="foobar")
+ self.assertEqual(comment.__class__, Comment)
+
+ # The comment is properly integrated into the tree.
+ foo = soup.find(text="foo")
+ self.assertEqual(comment, foo.next_element)
+ baz = soup.find(text="baz")
+ self.assertEqual(comment, baz.previous_element)
+
+ def test_preserved_whitespace_in_pre_and_textarea(self):
+ """Whitespace must be preserved in <pre> and <textarea> tags,
+ even if that would mean not prettifying the markup.
+ """
+ pre_markup = "<pre> </pre>"
+ textarea_markup = "<textarea> woo\nwoo </textarea>"
+ self.assertSoupEquals(pre_markup)
+ self.assertSoupEquals(textarea_markup)
+
+ soup = self.soup(pre_markup)
+ self.assertEqual(soup.pre.prettify(), pre_markup)
+
+ soup = self.soup(textarea_markup)
+ self.assertEqual(soup.textarea.prettify(), textarea_markup)
+
+ soup = self.soup("<textarea></textarea>")
+ self.assertEqual(soup.textarea.prettify(), "<textarea></textarea>")
+
+ def test_nested_inline_elements(self):
+ """Inline elements can be nested indefinitely."""
+ b_tag = "<b>Inside a B tag</b>"
+ self.assertSoupEquals(b_tag)
+
+ nested_b_tag = "<p>A <i>nested <b>tag</b></i></p>"
+ self.assertSoupEquals(nested_b_tag)
+
+ double_nested_b_tag = "<p>A <a>doubly <i>nested <b>tag</b></i></a></p>"
+        self.assertSoupEquals(double_nested_b_tag)
+
+ def test_nested_block_level_elements(self):
+ """Block elements can be nested."""
+ soup = self.soup('<blockquote><p><b>Foo</b></p></blockquote>')
+ blockquote = soup.blockquote
+ self.assertEqual(blockquote.p.b.string, 'Foo')
+ self.assertEqual(blockquote.b.string, 'Foo')
+
+ def test_correctly_nested_tables(self):
+ """One table can go inside another one."""
+ markup = ('<table id="1">'
+ '<tr>'
+ "<td>Here's another table:"
+ '<table id="2">'
+ '<tr><td>foo</td></tr>'
+ '</table></td>')
+
+ self.assertSoupEquals(
+ markup,
+ '<table id="1"><tr><td>Here\'s another table:'
+ '<table id="2"><tr><td>foo</td></tr></table>'
+ '</td></tr></table>')
+
+ self.assertSoupEquals(
+ "<table><thead><tr><td>Foo</td></tr></thead>"
+ "<tbody><tr><td>Bar</td></tr></tbody>"
+ "<tfoot><tr><td>Baz</td></tr></tfoot></table>")
+
+ def test_multivalued_attribute_with_whitespace(self):
+ # Whitespace separating the values of a multi-valued attribute
+ # should be ignored.
+
+ markup = '<div class=" foo bar "></a>'
+ soup = self.soup(markup)
+ self.assertEqual(['foo', 'bar'], soup.div['class'])
+
+        # If you search by the literal class string, it's as though the
+        # whitespace weren't there.
+ self.assertEqual(soup.div, soup.find('div', class_="foo bar"))
+
+ def test_deeply_nested_multivalued_attribute(self):
+ # html5lib can set the attributes of the same tag many times
+ # as it rearranges the tree. This has caused problems with
+ # multivalued attributes.
+ markup = '<table><div><div class="css"></div></div></table>'
+ soup = self.soup(markup)
+ self.assertEqual(["css"], soup.div.div['class'])
+
+ def test_multivalued_attribute_on_html(self):
+        # html5lib uses a different API to set the attributes of the
+        # <html> tag. This has caused problems with multivalued
+        # attributes.
+ markup = '<html class="a b"></html>'
+ soup = self.soup(markup)
+ self.assertEqual(["a", "b"], soup.html['class'])
+
+ def test_angle_brackets_in_attribute_values_are_escaped(self):
+ self.assertSoupEquals('<a b="<a>"></a>', '<a b="&lt;a&gt;"></a>')
+
+ def test_strings_resembling_character_entity_references(self):
+ # "&T" and "&p" look like incomplete character entities, but they are
+ # not.
+ self.assertSoupEquals(
+ "<p>&bull; AT&T is in the s&p 500</p>",
+ "<p>\u2022 AT&amp;T is in the s&amp;p 500</p>"
+ )
+
+ def test_apos_entity(self):
+ self.assertSoupEquals(
+ "<p>Bob&apos;s Bar</p>",
+ "<p>Bob's Bar</p>",
+ )
+
+ def test_entities_in_foreign_document_encoding(self):
+ # &#147; and &#148; are invalid numeric entities referencing
+ # Windows-1252 characters. &#45; references a character common
+ # to Windows-1252 and Unicode, and &#9731; references a
+ # character only found in Unicode.
+ #
+ # All of these entities should be converted to Unicode
+ # characters.
+ markup = "<p>&#147;Hello&#148; &#45;&#9731;</p>"
+ soup = self.soup(markup)
+ self.assertEqual("“Hello” -☃", soup.p.string)
+
+ def test_entities_in_attributes_converted_to_unicode(self):
+ expect = '<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>'
+ self.assertSoupEquals('<p id="pi&#241;ata"></p>', expect)
+ self.assertSoupEquals('<p id="pi&#xf1;ata"></p>', expect)
+ self.assertSoupEquals('<p id="pi&#Xf1;ata"></p>', expect)
+ self.assertSoupEquals('<p id="pi&ntilde;ata"></p>', expect)
+
+ def test_entities_in_text_converted_to_unicode(self):
+ expect = '<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>'
+ self.assertSoupEquals("<p>pi&#241;ata</p>", expect)
+ self.assertSoupEquals("<p>pi&#xf1;ata</p>", expect)
+ self.assertSoupEquals("<p>pi&#Xf1;ata</p>", expect)
+ self.assertSoupEquals("<p>pi&ntilde;ata</p>", expect)
+
+ def test_quot_entity_converted_to_quotation_mark(self):
+ self.assertSoupEquals("<p>I said &quot;good day!&quot;</p>",
+ '<p>I said "good day!"</p>')
+
+ def test_out_of_range_entity(self):
+ expect = "\N{REPLACEMENT CHARACTER}"
+ self.assertSoupEquals("&#10000000000000;", expect)
+ self.assertSoupEquals("&#x10000000000000;", expect)
+ self.assertSoupEquals("&#1000000000;", expect)
+
+ def test_multipart_strings(self):
+ "Mostly to prevent a recurrence of a bug in the html5lib treebuilder."
+ soup = self.soup("<html><h2>\nfoo</h2><p></p></html>")
+ self.assertEqual("p", soup.h2.string.next_element.name)
+ self.assertEqual("p", soup.p.name)
+ self.assertConnectedness(soup)
+
+ def test_empty_element_tags(self):
+ """Verify consistent handling of empty-element tags,
+ no matter how they come in through the markup.
+ """
+ self.assertSoupEquals('<br/><br/><br/>', "<br/><br/><br/>")
+ self.assertSoupEquals('<br /><br /><br />', "<br/><br/><br/>")
+
+ def test_head_tag_between_head_and_body(self):
+ "Prevent recurrence of a bug in the html5lib treebuilder."
+ content = """<html><head></head>
+ <link></link>
+ <body>foo</body>
+</html>
+"""
+ soup = self.soup(content)
+ self.assertNotEqual(None, soup.html.body)
+ self.assertConnectedness(soup)
+
+ def test_multiple_copies_of_a_tag(self):
+ "Prevent recurrence of a bug in the html5lib treebuilder."
+ content = """<!DOCTYPE html>
+<html>
+ <body>
+ <article id="a" >
+ <div><a href="1"></div>
+ <footer>
+ <a href="2"></a>
+ </footer>
+ </article>
+ </body>
+</html>
+"""
+ soup = self.soup(content)
+ self.assertConnectedness(soup.article)
+
+ def test_basic_namespaces(self):
+ """Parsers don't need to *understand* namespaces, but at the
+ very least they should not choke on namespaces or lose
+ data."""
+
+ markup = b'<html xmlns="http://www.w3.org/1999/xhtml" xmlns:mathml="http://www.w3.org/1998/Math/MathML" xmlns:svg="http://www.w3.org/2000/svg"><head></head><body><mathml:msqrt>4</mathml:msqrt><b svg:fill="red"></b></body></html>'
+ soup = self.soup(markup)
+ self.assertEqual(markup, soup.encode())
+ html = soup.html
+ self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns'])
+ self.assertEqual(
+ 'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml'])
+ self.assertEqual(
+ 'http://www.w3.org/2000/svg', soup.html['xmlns:svg'])
+
+ def test_multivalued_attribute_value_becomes_list(self):
+ markup = b'<a class="foo bar">'
+ soup = self.soup(markup)
+ self.assertEqual(['foo', 'bar'], soup.a['class'])
+
+ #
+ # Generally speaking, tests below this point are more tests of
+ # Beautiful Soup than tests of the tree builders. But parsers are
+ # weird, so we run these tests separately for every tree builder
+ # to detect any differences between them.
+ #
+
+ def test_can_parse_unicode_document(self):
+ # A seemingly innocuous document... but it's in Unicode! And
+ # it contains characters that can't be represented in the
+ # encoding found in the declaration! The horror!
+ markup = '<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>'
+ soup = self.soup(markup)
+ self.assertEqual('Sacr\xe9 bleu!', soup.body.string)
+
+ def test_soupstrainer(self):
+ """Parsers should be able to work with SoupStrainers."""
+ strainer = SoupStrainer("b")
+ soup = self.soup("A <b>bold</b> <meta/> <i>statement</i>",
+ parse_only=strainer)
+ self.assertEqual(soup.decode(), "<b>bold</b>")
+
+ def test_single_quote_attribute_values_become_double_quotes(self):
+ self.assertSoupEquals("<foo attr='bar'></foo>",
+ '<foo attr="bar"></foo>')
+
+ def test_attribute_values_with_nested_quotes_are_left_alone(self):
+ text = """<foo attr='bar "brawls" happen'>a</foo>"""
+ self.assertSoupEquals(text)
+
+ def test_attribute_values_with_double_nested_quotes_get_quoted(self):
+ text = """<foo attr='bar "brawls" happen'>a</foo>"""
+ soup = self.soup(text)
+ soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"'
+ self.assertSoupEquals(
+ soup.foo.decode(),
+ """<foo attr="Brawls happen at &quot;Bob\'s Bar&quot;">a</foo>""")
+
+ def test_ampersand_in_attribute_value_gets_escaped(self):
+ self.assertSoupEquals('<this is="really messed up & stuff"></this>',
+ '<this is="really messed up &amp; stuff"></this>')
+
+ self.assertSoupEquals(
+ '<a href="http://example.org?a=1&b=2;3">foo</a>',
+ '<a href="http://example.org?a=1&amp;b=2;3">foo</a>')
+
+ def test_escaped_ampersand_in_attribute_value_is_left_alone(self):
+ self.assertSoupEquals('<a href="http://example.org?a=1&amp;b=2;3"></a>')
+
+ def test_entities_in_strings_converted_during_parsing(self):
+ # Both XML and HTML entities are converted to Unicode characters
+ # during parsing.
+ text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
+ expected = "<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>"
+ self.assertSoupEquals(text, expected)
+
+ def test_smart_quotes_converted_on_the_way_in(self):
+ # Microsoft smart quotes are converted to Unicode characters during
+ # parsing.
+ quote = b"<p>\x91Foo\x92</p>"
+ soup = self.soup(quote)
+ self.assertEqual(
+ soup.p.string,
+ "\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}")
+
+ def test_non_breaking_spaces_converted_on_the_way_in(self):
+ soup = self.soup("<a>&nbsp;&nbsp;</a>")
+ self.assertEqual(soup.a.string, "\N{NO-BREAK SPACE}" * 2)
+
+ def test_entities_converted_on_the_way_out(self):
+ text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
+ expected = "<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>".encode("utf-8")
+ soup = self.soup(text)
+ self.assertEqual(soup.p.encode("utf-8"), expected)
+
+ def test_real_iso_latin_document(self):
+ # Smoke test of interrelated functionality, using an
+ # easy-to-understand document.
+
+ # Here it is in Unicode. Note that it claims to be in ISO-Latin-1.
+ unicode_html = '<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
+
+ # That's because we're going to encode it into ISO-Latin-1, and use
+ # that to test.
+ iso_latin_html = unicode_html.encode("iso-8859-1")
+
+ # Parse the ISO-Latin-1 HTML.
+ soup = self.soup(iso_latin_html)
+ # Encode it to UTF-8.
+ result = soup.encode("utf-8")
+
+ # What do we expect the result to look like? Well, it would
+ # look like unicode_html, except that the META tag would say
+ # UTF-8 instead of ISO-Latin-1.
+ expected = unicode_html.replace("ISO-Latin-1", "utf-8")
+
+ # And, of course, it would be in UTF-8, not Unicode.
+ expected = expected.encode("utf-8")
+
+ # Ta-da!
+ self.assertEqual(result, expected)
+
+ def test_real_shift_jis_document(self):
+ # Smoke test to make sure the parser can handle a document in
+ # Shift-JIS encoding, without choking.
+ shift_jis_html = (
+ b'<html><head></head><body><pre>'
+ b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f'
+ b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c'
+ b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B'
+ b'</pre></body></html>')
+ unicode_html = shift_jis_html.decode("shift-jis")
+ soup = self.soup(unicode_html)
+
+ # Make sure the parse tree is correctly encoded to various
+ # encodings.
+ self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8"))
+ self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp"))
+
+ def test_real_hebrew_document(self):
+ # A real-world test to make sure we can convert ISO-8859-8 (a
+ # Hebrew encoding) to UTF-8.
+ hebrew_document = b'<html><head><title>Hebrew (ISO 8859-8) in Visual Directionality</title></head><body><h1>Hebrew (ISO 8859-8) in Visual Directionality</h1>\xed\xe5\xec\xf9</body></html>'
+ soup = self.soup(
+ hebrew_document, from_encoding="iso8859-8")
+ # Some tree builders call it iso8859-8, others call it iso-8859-8.
+ # That's not a difference we really care about.
+ assert soup.original_encoding in ('iso8859-8', 'iso-8859-8')
+ self.assertEqual(
+ soup.encode('utf-8'),
+ hebrew_document.decode("iso8859-8").encode("utf-8"))
+
+ def test_meta_tag_reflects_current_encoding(self):
+ # Here's the <meta> tag saying that a document is
+ # encoded in Shift-JIS.
+ meta_tag = ('<meta content="text/html; charset=x-sjis" '
+ 'http-equiv="Content-type"/>')
+
+ # Here's a document incorporating that meta tag.
+ shift_jis_html = (
+ '<html><head>\n%s\n'
+ '<meta http-equiv="Content-language" content="ja"/>'
+ '</head><body>Shift-JIS markup goes here.') % meta_tag
+ soup = self.soup(shift_jis_html)
+
+ # Parse the document, and the charset is seemingly unaffected.
+ parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'})
+ content = parsed_meta['content']
+ self.assertEqual('text/html; charset=x-sjis', content)
+
+ # But that value is actually a ContentMetaAttributeValue object.
+ self.assertTrue(isinstance(content, ContentMetaAttributeValue))
+
+ # And it will take on a value that reflects its current
+ # encoding.
+ self.assertEqual('text/html; charset=utf8', content.encode("utf8"))
+
+ # For the rest of the story, see TestSubstitutions in
+ # test_tree.py.
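+ #
+ # A quick sketch of that story (illustrative only, not part of the
+ # original suite): encoding the whole document rewrites the attribute
+ # in place, e.g. soup.encode("euc-jp") emits a <meta> tag whose
+ # content reads "text/html; charset=euc-jp".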
+
+ def test_html5_style_meta_tag_reflects_current_encoding(self):
+ # Here's the <meta> tag saying that a document is
+ # encoded in Shift-JIS.
+ meta_tag = ('<meta id="encoding" charset="x-sjis" />')
+
+ # Here's a document incorporating that meta tag.
+ shift_jis_html = (
+ '<html><head>\n%s\n'
+ '<meta http-equiv="Content-language" content="ja"/>'
+ '</head><body>Shift-JIS markup goes here.') % meta_tag
+ soup = self.soup(shift_jis_html)
+
+ # Parse the document, and the charset is seemingly unaffected.
+ parsed_meta = soup.find('meta', id="encoding")
+ charset = parsed_meta['charset']
+ self.assertEqual('x-sjis', charset)
+
+ # But that value is actually a CharsetMetaAttributeValue object.
+ self.assertTrue(isinstance(charset, CharsetMetaAttributeValue))
+
+ # And it will take on a value that reflects its current
+ # encoding.
+ self.assertEqual('utf8', charset.encode("utf8"))
+
+ def test_tag_with_no_attributes_can_have_attributes_added(self):
+ data = self.soup("<a>text</a>")
+ data.a['foo'] = 'bar'
+ self.assertEqual('<a foo="bar">text</a>', data.a.decode())
+
+ def test_worst_case(self):
+ """Test the worst case (currently) for linking issues."""
+
+ soup = self.soup(BAD_DOCUMENT)
+ self.linkage_validator(soup)
+
+
+class XMLTreeBuilderSmokeTest(object):
+
+ def test_pickle_and_unpickle_identity(self):
+ # Pickling a tree, then unpickling it, yields a tree identical
+ # to the original.
+ tree = self.soup("<a><b>foo</a>")
+ dumped = pickle.dumps(tree, 2)
+ loaded = pickle.loads(dumped)
+ self.assertEqual(loaded.__class__, BeautifulSoup)
+ self.assertEqual(loaded.decode(), tree.decode())
+
+ def test_docstring_generated(self):
+ soup = self.soup("<root/>")
+ self.assertEqual(
+ soup.encode(), b'<?xml version="1.0" encoding="utf-8"?>\n<root/>')
+
+ def test_xml_declaration(self):
+ markup = b"""<?xml version="1.0" encoding="utf8"?>\n<foo/>"""
+ soup = self.soup(markup)
+ self.assertEqual(markup, soup.encode("utf8"))
+
+ def test_processing_instruction(self):
+ markup = b"""<?xml version="1.0" encoding="utf8"?>\n<?PITarget PIContent?>"""
+ soup = self.soup(markup)
+ self.assertEqual(markup, soup.encode("utf8"))
+
+ def test_real_xhtml_document(self):
+ """A real XHTML document should come out *exactly* the same as it went in."""
+ markup = b"""<?xml version="1.0" encoding="utf-8"?>
+<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
+<html xmlns="http://www.w3.org/1999/xhtml">
+<head><title>Hello.</title></head>
+<body>Goodbye.</body>
+</html>"""
+ soup = self.soup(markup)
+ self.assertEqual(
+ soup.encode("utf-8"), markup)
+
+ def test_nested_namespaces(self):
+ doc = b"""<?xml version="1.0" encoding="utf-8"?>
+<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">
+<parent xmlns="http://ns1/">
+<child xmlns="http://ns2/" xmlns:ns3="http://ns3/">
+<grandchild ns3:attr="value" xmlns="http://ns4/"/>
+</child>
+</parent>"""
+ soup = self.soup(doc)
+ self.assertEqual(doc, soup.encode())
+
+ def test_formatter_processes_script_tag_for_xml_documents(self):
+ doc = """
+ <script type="text/javascript">
+ </script>
+"""
+ soup = BeautifulSoup(doc, "lxml-xml")
+ # lxml would have stripped this while parsing, but we can add
+ # it later.
+ soup.script.string = 'console.log("< < hey > > ");'
+ encoded = soup.encode()
+ self.assertTrue(b"&lt; &lt; hey &gt; &gt;" in encoded)
+
+ def test_can_parse_unicode_document(self):
+ markup = '<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>'
+ soup = self.soup(markup)
+ self.assertEqual('Sacr\xe9 bleu!', soup.root.string)
+
+ def test_popping_namespaced_tag(self):
+ markup = '<rss xmlns:dc="foo"><dc:creator>b</dc:creator><dc:date>2012-07-02T20:33:42Z</dc:date><dc:rights>c</dc:rights><image>d</image></rss>'
+ soup = self.soup(markup)
+ self.assertEqual(
+ str(soup.rss), markup)
+
+ def test_docstring_includes_correct_encoding(self):
+ soup = self.soup("<root/>")
+ self.assertEqual(
+ soup.encode("latin1"),
+ b'<?xml version="1.0" encoding="latin1"?>\n<root/>')
+
+ def test_large_xml_document(self):
+ """A large XML document should come out the same as it went in."""
+ markup = (b'<?xml version="1.0" encoding="utf-8"?>\n<root>'
+ + b'0' * (2**12)
+ + b'</root>')
+ soup = self.soup(markup)
+ self.assertEqual(soup.encode("utf-8"), markup)
+
+
+ def test_tags_are_empty_element_if_and_only_if_they_are_empty(self):
+ self.assertSoupEquals("<p>", "<p/>")
+ self.assertSoupEquals("<p>foo</p>")
+
+ def test_namespaces_are_preserved(self):
+ markup = '<root xmlns:a="http://example.com/" xmlns:b="http://example.net/"><a:foo>This tag is in the a namespace</a:foo><b:foo>This tag is in the b namespace</b:foo></root>'
+ soup = self.soup(markup)
+ root = soup.root
+ self.assertEqual("http://example.com/", root['xmlns:a'])
+ self.assertEqual("http://example.net/", root['xmlns:b'])
+
+ def test_closing_namespaced_tag(self):
+ markup = '<p xmlns:dc="http://purl.org/dc/elements/1.1/"><dc:date>20010504</dc:date></p>'
+ soup = self.soup(markup)
+ self.assertEqual(str(soup.p), markup)
+
+ def test_namespaced_attributes(self):
+ markup = '<foo xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"><bar xsi:schemaLocation="http://www.example.com"/></foo>'
+ soup = self.soup(markup)
+ self.assertEqual(str(soup.foo), markup)
+
+ def test_namespaced_attributes_xml_namespace(self):
+ markup = '<foo xml:lang="fr">bar</foo>'
+ soup = self.soup(markup)
+ self.assertEqual(str(soup.foo), markup)
+
+ def test_find_by_prefixed_name(self):
+ doc = """<?xml version="1.0" encoding="utf-8"?>
+<Document xmlns="http://example.com/ns0"
+ xmlns:ns1="http://example.com/ns1"
+ xmlns:ns2="http://example.com/ns2">
+ <ns1:tag>foo</ns1:tag>
+ <ns1:tag>bar</ns1:tag>
+ <ns2:tag key="value">baz</ns2:tag>
+</Document>
+"""
+ soup = self.soup(doc)
+
+ # There are three <tag> tags.
+ self.assertEqual(3, len(soup.find_all('tag')))
+
+ # But two of them are ns1:tag and one of them is ns2:tag.
+ self.assertEqual(2, len(soup.find_all('ns1:tag')))
+ self.assertEqual(1, len(soup.find_all('ns2:tag')))
+
+ self.assertEqual(1, len(soup.find_all('ns2:tag', key='value')))
+ self.assertEqual(3, len(soup.find_all(['ns1:tag', 'ns2:tag'])))
+
+ def test_copy_tag_preserves_namespace(self):
+ xml = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<w:document xmlns:w="http://example.com/ns0"/>"""
+
+ soup = self.soup(xml)
+ tag = soup.document
+ duplicate = copy.copy(tag)
+
+ # The two tags have the same namespace prefix.
+ self.assertEqual(tag.prefix, duplicate.prefix)
+
+ def test_worst_case(self):
+ """Test the worst case (currently) for linking issues."""
+
+ soup = self.soup(BAD_DOCUMENT)
+ self.linkage_validator(soup)
+
+
+class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
+ """Smoke test for a tree builder that supports HTML5."""
+
+ def test_real_xhtml_document(self):
+ # Since XHTML is not HTML5, HTML5 parsers are not tested to handle
+ # XHTML documents in any particular way.
+ pass
+
+ def test_html_tags_have_namespace(self):
+ markup = "<a>"
+ soup = self.soup(markup)
+ self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace)
+
+ def test_svg_tags_have_namespace(self):
+ markup = '<svg><circle/></svg>'
+ soup = self.soup(markup)
+ namespace = "http://www.w3.org/2000/svg"
+ self.assertEqual(namespace, soup.svg.namespace)
+ self.assertEqual(namespace, soup.circle.namespace)
+
+
+ def test_mathml_tags_have_namespace(self):
+ markup = '<math><msqrt>5</msqrt></math>'
+ soup = self.soup(markup)
+ namespace = 'http://www.w3.org/1998/Math/MathML'
+ self.assertEqual(namespace, soup.math.namespace)
+ self.assertEqual(namespace, soup.msqrt.namespace)
+
+ def test_xml_declaration_becomes_comment(self):
+ markup = '<?xml version="1.0" encoding="utf-8"?><html></html>'
+ soup = self.soup(markup)
+ self.assertTrue(isinstance(soup.contents[0], Comment))
+ self.assertEqual(soup.contents[0], '?xml version="1.0" encoding="utf-8"?')
+ self.assertEqual("html", soup.contents[0].next_element.name)
+
+def skipIf(condition, reason):
+ """A minimal stand-in for unittest.skipIf: when `condition` is true,
+ the decorated test (or test class) is replaced with a no-op that
+ does nothing; `reason` is kept for documentation only."""
+ def nothing(test, *args, **kwargs):
+ return None
+
+ def decorator(test_item):
+ if condition:
+ return nothing
+ else:
+ return test_item
+
+ return decorator
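+
+# A usage sketch for the helper above (illustrative only, not part of
+# the original suite). Applied to a class, skipIf swaps the class for a
+# do-nothing function, so unittest collects no tests from it:
+#
+#   @skipIf(True, "dependency missing")
+#   class SkippedTests(SoupTest):
+#       def test_never_runs(self):
+#           raise AssertionError("unreachable")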
diff --git a/libs/bs4/tests/__init__.py b/libs/bs4/tests/__init__.py
new file mode 100644
index 000000000..142c8cc3f
--- /dev/null
+++ b/libs/bs4/tests/__init__.py
@@ -0,0 +1 @@
+"The beautifulsoup tests."
diff --git a/libs/bs4/tests/test_builder_registry.py b/libs/bs4/tests/test_builder_registry.py
new file mode 100644
index 000000000..90cad8293
--- /dev/null
+++ b/libs/bs4/tests/test_builder_registry.py
@@ -0,0 +1,147 @@
+"""Tests of the builder registry."""
+
+import unittest
+import warnings
+
+from bs4 import BeautifulSoup
+from bs4.builder import (
+ builder_registry as registry,
+ HTMLParserTreeBuilder,
+ TreeBuilderRegistry,
+)
+
+try:
+ from bs4.builder import HTML5TreeBuilder
+ HTML5LIB_PRESENT = True
+except ImportError:
+ HTML5LIB_PRESENT = False
+
+try:
+ from bs4.builder import (
+ LXMLTreeBuilderForXML,
+ LXMLTreeBuilder,
+ )
+ LXML_PRESENT = True
+except ImportError:
+ LXML_PRESENT = False
+
+
+class BuiltInRegistryTest(unittest.TestCase):
+ """Test the built-in registry with the default builders registered."""
+
+ def test_combination(self):
+ if LXML_PRESENT:
+ self.assertEqual(registry.lookup('fast', 'html'),
+ LXMLTreeBuilder)
+
+ if LXML_PRESENT:
+ self.assertEqual(registry.lookup('permissive', 'xml'),
+ LXMLTreeBuilderForXML)
+ self.assertEqual(registry.lookup('strict', 'html'),
+ HTMLParserTreeBuilder)
+ if HTML5LIB_PRESENT:
+ self.assertEqual(registry.lookup('html5lib', 'html'),
+ HTML5TreeBuilder)
+
+ def test_lookup_by_markup_type(self):
+ if LXML_PRESENT:
+ self.assertEqual(registry.lookup('html'), LXMLTreeBuilder)
+ self.assertEqual(registry.lookup('xml'), LXMLTreeBuilderForXML)
+ else:
+ self.assertEqual(registry.lookup('xml'), None)
+ if HTML5LIB_PRESENT:
+ self.assertEqual(registry.lookup('html'), HTML5TreeBuilder)
+ else:
+ self.assertEqual(registry.lookup('html'), HTMLParserTreeBuilder)
+
+ def test_named_library(self):
+ if LXML_PRESENT:
+ self.assertEqual(registry.lookup('lxml', 'xml'),
+ LXMLTreeBuilderForXML)
+ self.assertEqual(registry.lookup('lxml', 'html'),
+ LXMLTreeBuilder)
+ if HTML5LIB_PRESENT:
+ self.assertEqual(registry.lookup('html5lib'),
+ HTML5TreeBuilder)
+
+ self.assertEqual(registry.lookup('html.parser'),
+ HTMLParserTreeBuilder)
+
+ def test_beautifulsoup_constructor_does_lookup(self):
+
+ with warnings.catch_warnings(record=True) as w:
+ # This will create a warning about not explicitly
+ # specifying a parser, but we'll ignore it.
+
+ # You can pass in a string.
+ BeautifulSoup("", features="html")
+ # Or a list of strings.
+ BeautifulSoup("", features=["html", "fast"])
+
+ # You'll get an exception if BS can't find an appropriate
+ # builder.
+ self.assertRaises(ValueError, BeautifulSoup,
+ "", features="no-such-feature")
+
+class RegistryTest(unittest.TestCase):
+ """Test the TreeBuilderRegistry class in general."""
+
+ def setUp(self):
+ self.registry = TreeBuilderRegistry()
+
+ def builder_for_features(self, *feature_list):
+ cls = type('Builder_' + '_'.join(feature_list),
+ (object,), {'features' : feature_list})
+
+ self.registry.register(cls)
+ return cls
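+
+ # For example, builder_for_features('fast', 'html') registers and
+ # returns a new class named Builder_fast_html whose .features
+ # attribute is ('fast', 'html').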
+
+ def test_register_with_no_features(self):
+ builder = self.builder_for_features()
+
+ # Since the builder advertises no features, you can't find it
+ # by looking up features.
+ self.assertEqual(self.registry.lookup('foo'), None)
+
+ # But you can find it by doing a lookup with no features, if
+ # this happens to be the only registered builder.
+ self.assertEqual(self.registry.lookup(), builder)
+
+ def test_register_with_features_makes_lookup_succeed(self):
+ builder = self.builder_for_features('foo', 'bar')
+ self.assertEqual(self.registry.lookup('foo'), builder)
+ self.assertEqual(self.registry.lookup('bar'), builder)
+
+ def test_lookup_fails_when_no_builder_implements_feature(self):
+ builder = self.builder_for_features('foo', 'bar')
+ self.assertEqual(self.registry.lookup('baz'), None)
+
+ def test_lookup_gets_most_recent_registration_when_no_feature_specified(self):
+ builder1 = self.builder_for_features('foo')
+ builder2 = self.builder_for_features('bar')
+ self.assertEqual(self.registry.lookup(), builder2)
+
+ def test_lookup_fails_when_no_tree_builders_registered(self):
+ self.assertEqual(self.registry.lookup(), None)
+
+ def test_lookup_gets_most_recent_builder_supporting_all_features(self):
+ has_one = self.builder_for_features('foo')
+ has_the_other = self.builder_for_features('bar')
+ has_both_early = self.builder_for_features('foo', 'bar', 'baz')
+ has_both_late = self.builder_for_features('foo', 'bar', 'quux')
+ lacks_one = self.builder_for_features('bar')
+ lacks_the_other = self.builder_for_features('foo')
+
+ # There are two builders featuring 'foo' and 'bar', but
+ # the one that also features 'quux' was registered later.
+ self.assertEqual(self.registry.lookup('foo', 'bar'),
+ has_both_late)
+
+ # There is only one builder featuring 'foo', 'bar', and 'baz'.
+ self.assertEqual(self.registry.lookup('foo', 'bar', 'baz'),
+ has_both_early)
+
+ def test_lookup_fails_when_cannot_reconcile_requested_features(self):
+ builder1 = self.builder_for_features('foo', 'bar')
+ builder2 = self.builder_for_features('foo', 'baz')
+ self.assertEqual(self.registry.lookup('bar', 'baz'), None)
diff --git a/libs/bs4/tests/test_docs.py b/libs/bs4/tests/test_docs.py
new file mode 100644
index 000000000..5b9f67709
--- /dev/null
+++ b/libs/bs4/tests/test_docs.py
@@ -0,0 +1,36 @@
+"Test harness for doctests."
+
+# pylint: disable-msg=E0611,W0142
+
+__metaclass__ = type
+__all__ = [
+ 'additional_tests',
+ ]
+
+import atexit
+import doctest
+import os
+#from pkg_resources import (
+# resource_filename, resource_exists, resource_listdir, cleanup_resources)
+import unittest
+
+DOCTEST_FLAGS = (
+ doctest.ELLIPSIS |
+ doctest.NORMALIZE_WHITESPACE |
+ doctest.REPORT_NDIFF)
+
+
+# def additional_tests():
+# "Run the doc tests (README.txt and docs/*, if any exist)"
+# doctest_files = [
+# os.path.abspath(resource_filename('bs4', 'README.txt'))]
+# if resource_exists('bs4', 'docs'):
+# for name in resource_listdir('bs4', 'docs'):
+# if name.endswith('.txt'):
+# doctest_files.append(
+# os.path.abspath(
+# resource_filename('bs4', 'docs/%s' % name)))
+# kwargs = dict(module_relative=False, optionflags=DOCTEST_FLAGS)
+# atexit.register(cleanup_resources)
+# return unittest.TestSuite((
+# doctest.DocFileSuite(*doctest_files, **kwargs)))
diff --git a/libs/bs4/tests/test_html5lib.py b/libs/bs4/tests/test_html5lib.py
new file mode 100644
index 000000000..96529b0b3
--- /dev/null
+++ b/libs/bs4/tests/test_html5lib.py
@@ -0,0 +1,170 @@
+"""Tests to ensure that the html5lib tree builder generates good trees."""
+
+import warnings
+
+try:
+ from bs4.builder import HTML5TreeBuilder
+ HTML5LIB_PRESENT = True
+except ImportError as e:
+ HTML5LIB_PRESENT = False
+from bs4.element import SoupStrainer
+from bs4.testing import (
+ HTML5TreeBuilderSmokeTest,
+ SoupTest,
+ skipIf,
+)
+
+@skipIf(
+ not HTML5LIB_PRESENT,
+ "html5lib seems not to be present, not testing its tree builder.")
+class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest):
+ """See ``HTML5TreeBuilderSmokeTest``."""
+
+ @property
+ def default_builder(self):
+ return HTML5TreeBuilder
+
+ def test_soupstrainer(self):
+ # The html5lib tree builder does not support SoupStrainers.
+ strainer = SoupStrainer("b")
+ markup = "<p>A <b>bold</b> statement.</p>"
+ with warnings.catch_warnings(record=True) as w:
+ soup = self.soup(markup, parse_only=strainer)
+ self.assertEqual(
+ soup.decode(), self.document_for(markup))
+
+ self.assertTrue(
+ "the html5lib tree builder doesn't support parse_only" in
+ str(w[0].message))
+
+ def test_correctly_nested_tables(self):
+ """html5lib inserts <tbody> tags where other parsers don't."""
+ markup = ('<table id="1">'
+ '<tr>'
+ "<td>Here's another table:"
+ '<table id="2">'
+ '<tr><td>foo</td></tr>'
+ '</table></td>')
+
+ self.assertSoupEquals(
+ markup,
+ '<table id="1"><tbody><tr><td>Here\'s another table:'
+ '<table id="2"><tbody><tr><td>foo</td></tr></tbody></table>'
+ '</td></tr></tbody></table>')
+
+ self.assertSoupEquals(
+ "<table><thead><tr><td>Foo</td></tr></thead>"
+ "<tbody><tr><td>Bar</td></tr></tbody>"
+ "<tfoot><tr><td>Baz</td></tr></tfoot></table>")
+
+ def test_xml_declaration_followed_by_doctype(self):
+ markup = '''<?xml version="1.0" encoding="utf-8"?>
+<!DOCTYPE html>
+<html>
+ <head>
+ </head>
+ <body>
+ <p>foo</p>
+ </body>
+</html>'''
+ soup = self.soup(markup)
+ # Verify that we can reach the <p> tag; this means the tree is connected.
+ self.assertEqual(b"<p>foo</p>", soup.p.encode())
+
+ def test_reparented_markup(self):
+ markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>'
+ soup = self.soup(markup)
+ self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p></body>", soup.body.decode())
+ self.assertEqual(2, len(soup.find_all('p')))
+
+
+ def test_reparented_markup_ends_with_whitespace(self):
+ markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>\n'
+ soup = self.soup(markup)
+ self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p>\n</body>", soup.body.decode())
+ self.assertEqual(2, len(soup.find_all('p')))
+
+ def test_reparented_markup_containing_identical_whitespace_nodes(self):
+ """Verify that we keep the two whitespace nodes in this
+ document distinct when reparenting the adjacent <tbody> tags.
+ """
+ markup = '<table> <tbody><tbody><ims></tbody> </table>'
+ soup = self.soup(markup)
+ space1, space2 = soup.find_all(string=' ')
+ tbody1, tbody2 = soup.find_all('tbody')
+ assert space1.next_element is tbody1
+ assert tbody2.next_element is space2
+
+ def test_reparented_markup_containing_children(self):
+ markup = '<div><a>aftermath<p><noscript>target</noscript>aftermath</a></p></div>'
+ soup = self.soup(markup)
+ noscript = soup.noscript
+ self.assertEqual("target", noscript.next_element)
+ target = soup.find(string='target')
+
+ # The 'aftermath' string was duplicated; we want the second one.
+ final_aftermath = soup.find_all(string='aftermath')[-1]
+
+ # The <noscript> tag was moved beneath a copy of the <a> tag,
+ # but the 'target' string within is still connected to the
+ # (second) 'aftermath' string.
+ self.assertEqual(final_aftermath, target.next_element)
+ self.assertEqual(target, final_aftermath.previous_element)
+
+ def test_processing_instruction(self):
+ """Processing instructions become comments."""
+ markup = b"""<?PITarget PIContent?>"""
+ soup = self.soup(markup)
+ assert str(soup).startswith("<!--?PITarget PIContent?-->")
+
+ def test_cloned_multivalue_node(self):
+ markup = b"""<a class="my_class"><p></a>"""
+ soup = self.soup(markup)
+ a1, a2 = soup.find_all('a')
+ self.assertEqual(a1, a2)
+ assert a1 is not a2
+
+ def test_foster_parenting(self):
+ markup = b"""<table><td></tbody>A"""
+ soup = self.soup(markup)
+ self.assertEqual("<body>A<table><tbody><tr><td></td></tr></tbody></table></body>", soup.body.decode())
+
+ def test_extraction(self):
+ """
+ Test that extraction does not destroy the tree.
+
+ https://bugs.launchpad.net/beautifulsoup/+bug/1782928
+ """
+
+ markup = """
+<html><head></head>
+<style>
+</style><script></script><body><p>hello</p></body></html>
+"""
+ soup = self.soup(markup)
+ [s.extract() for s in soup('script')]
+ [s.extract() for s in soup('style')]
+
+ self.assertEqual(len(soup.find_all("p")), 1)
+
+ def test_empty_comment(self):
+ """
+ Test that empty comment does not break structure.
+
+ https://bugs.launchpad.net/beautifulsoup/+bug/1806598
+ """
+
+ markup = """
+<html>
+<body>
+<form>
+<!----><input type="text">
+</form>
+</body>
+</html>
+"""
+ soup = self.soup(markup)
+ inputs = []
+ for form in soup.find_all('form'):
+ inputs.extend(form.find_all('input'))
+ self.assertEqual(len(inputs), 1)
diff --git a/libs/bs4/tests/test_htmlparser.py b/libs/bs4/tests/test_htmlparser.py
new file mode 100644
index 000000000..790489aa1
--- /dev/null
+++ b/libs/bs4/tests/test_htmlparser.py
@@ -0,0 +1,47 @@
+"""Tests to ensure that the html.parser tree builder generates good
+trees."""
+
+from pdb import set_trace
+import pickle
+from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest
+from bs4.builder import HTMLParserTreeBuilder
+from bs4.builder._htmlparser import BeautifulSoupHTMLParser
+
+class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
+
+ default_builder = HTMLParserTreeBuilder
+
+ def test_namespaced_system_doctype(self):
+ # html.parser can't handle namespaced doctypes, so skip this one.
+ pass
+
+ def test_namespaced_public_doctype(self):
+ # html.parser can't handle namespaced doctypes, so skip this one.
+ pass
+
+ def test_builder_is_pickled(self):
+ """Unlike most tree builders, HTMLParserTreeBuilder and will
+ be restored after pickling.
+ """
+ tree = self.soup("<a><b>foo</a>")
+ dumped = pickle.dumps(tree, 2)
+ loaded = pickle.loads(dumped)
+ self.assertTrue(isinstance(loaded.builder, type(tree.builder)))
+
+ def test_redundant_empty_element_closing_tags(self):
+ self.assertSoupEquals('<br></br><br></br><br></br>', "<br/><br/><br/>")
+ self.assertSoupEquals('</br></br></br>', "")
+
+ def test_empty_element(self):
+ # This verifies that any buffered data present when the parser
+ # finishes working is handled.
+ self.assertSoupEquals("foo &# bar", "foo &amp;# bar")
+
+
+class TestHTMLParserSubclass(SoupTest):
+ def test_error(self):
+ """Verify that our HTMLParser subclass implements error() in a way
+ that doesn't cause a crash.
+ """
+ parser = BeautifulSoupHTMLParser()
+ parser.error("don't crash")
diff --git a/libs/bs4/tests/test_lxml.py b/libs/bs4/tests/test_lxml.py
new file mode 100644
index 000000000..29da71149
--- /dev/null
+++ b/libs/bs4/tests/test_lxml.py
@@ -0,0 +1,100 @@
+"""Tests to ensure that the lxml tree builder generates good trees."""
+
+import re
+import warnings
+
+try:
+ import lxml.etree
+ LXML_PRESENT = True
+ LXML_VERSION = lxml.etree.LXML_VERSION
+except ImportError as e:
+ LXML_PRESENT = False
+ LXML_VERSION = (0,)
+
+if LXML_PRESENT:
+ from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
+
+from bs4 import (
+ BeautifulSoup,
+ BeautifulStoneSoup,
+ )
+from bs4.element import Comment, Doctype, SoupStrainer
+from bs4.testing import skipIf
+from bs4.tests import test_htmlparser
+from bs4.testing import (
+ HTMLTreeBuilderSmokeTest,
+ XMLTreeBuilderSmokeTest,
+ SoupTest,
+ skipIf,
+)
+
+@skipIf(
+ not LXML_PRESENT,
+ "lxml seems not to be present, not testing its tree builder.")
+class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
+ """See ``HTMLTreeBuilderSmokeTest``."""
+
+ @property
+ def default_builder(self):
+ return LXMLTreeBuilder
+
+ def test_out_of_range_entity(self):
+ self.assertSoupEquals(
+ "<p>foo&#10000000000000;bar</p>", "<p>foobar</p>")
+ self.assertSoupEquals(
+ "<p>foo&#x10000000000000;bar</p>", "<p>foobar</p>")
+ self.assertSoupEquals(
+ "<p>foo&#1000000000;bar</p>", "<p>foobar</p>")
+
+ def test_entities_in_foreign_document_encoding(self):
+ # We can't implement this case correctly because by the time we
+ # hear about markup like "&#147;", it's been (incorrectly) converted into
+ # a string like u'\x93'
+ pass
+
+ # In lxml < 2.3.5, an empty doctype causes a segfault. Skip this
+ # test if an old version of lxml is installed.
+
+ @skipIf(
+ not LXML_PRESENT or LXML_VERSION < (2,3,5,0),
+ "Skipping doctype test for old version of lxml to avoid segfault.")
+ def test_empty_doctype(self):
+ soup = self.soup("<!DOCTYPE>")
+ doctype = soup.contents[0]
+ self.assertEqual("", doctype.strip())
+
+ def test_beautifulstonesoup_is_xml_parser(self):
+ # Make sure that the deprecated BSS class uses an xml builder
+ # if one is installed.
+ with warnings.catch_warnings(record=True) as w:
+ soup = BeautifulStoneSoup("<b />")
+ self.assertEqual("<b/>", str(soup.b))
+ self.assertTrue("BeautifulStoneSoup class is deprecated" in str(w[0].message))
+
+@skipIf(
+ not LXML_PRESENT,
+ "lxml seems not to be present, not testing its XML tree builder.")
+class LXMLXMLTreeBuilderSmokeTest(SoupTest, XMLTreeBuilderSmokeTest):
+ """See ``HTMLTreeBuilderSmokeTest``."""
+
+ @property
+ def default_builder(self):
+ return LXMLTreeBuilderForXML
+
+ def test_namespace_indexing(self):
+ # We should not track un-prefixed namespaces: we can only hold one,
+ # and soupsieve would treat it as the default namespace, which may be
+ # confusing in some situations. When no namespace is provided for a
+ # selector, the default namespace (if defined) is assumed.
+
+ soup = self.soup(
+ '<?xml version="1.1"?>\n'
+ '<root>'
+ '<tag xmlns="http://unprefixed-namespace.com">content</tag>'
+ '<prefix:tag xmlns:prefix="http://prefixed-namespace.com">content</tag>'
+ '</root>'
+ )
+ self.assertEqual(
+ soup._namespaces,
+ {'xml': 'http://www.w3.org/XML/1998/namespace', 'prefix': 'http://prefixed-namespace.com'}
+ )
diff --git a/libs/bs4/tests/test_soup.py b/libs/bs4/tests/test_soup.py
new file mode 100644
index 000000000..1eda9484b
--- /dev/null
+++ b/libs/bs4/tests/test_soup.py
@@ -0,0 +1,567 @@
+# -*- coding: utf-8 -*-
+"""Tests of Beautiful Soup as a whole."""
+
+from pdb import set_trace
+import logging
+import unittest
+import sys
+import tempfile
+
+from bs4 import (
+ BeautifulSoup,
+ BeautifulStoneSoup,
+)
+from bs4.element import (
+ CharsetMetaAttributeValue,
+ ContentMetaAttributeValue,
+ SoupStrainer,
+ NamespacedAttribute,
+ )
+import bs4.dammit
+from bs4.dammit import (
+ EntitySubstitution,
+ UnicodeDammit,
+ EncodingDetector,
+)
+from bs4.testing import (
+ default_builder,
+ SoupTest,
+ skipIf,
+)
+import warnings
+
+try:
+ from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
+ LXML_PRESENT = True
+except ImportError as e:
+ LXML_PRESENT = False
+
+PYTHON_3_PRE_3_2 = (sys.version_info[0] == 3 and sys.version_info < (3,2))
+
+class TestConstructor(SoupTest):
+
+ def test_short_unicode_input(self):
+ data = "<h1>éé</h1>"
+ soup = self.soup(data)
+ self.assertEqual("éé", soup.h1.string)
+
+ def test_embedded_null(self):
+ data = "<h1>foo\0bar</h1>"
+ soup = self.soup(data)
+ self.assertEqual("foo\0bar", soup.h1.string)
+
+ def test_exclude_encodings(self):
+ utf8_data = "Räksmörgås".encode("utf-8")
+ soup = self.soup(utf8_data, exclude_encodings=["utf-8"])
+ self.assertEqual("windows-1252", soup.original_encoding)
+
+ def test_custom_builder_class(self):
+ # Verify that you can pass in a custom Builder class and
+ # it'll be instantiated with the appropriate keyword arguments.
+ class Mock(object):
+ def __init__(self, **kwargs):
+ self.called_with = kwargs
+ self.is_xml = True
+ def initialize_soup(self, soup):
+ pass
+ def prepare_markup(self, *args, **kwargs):
+ return ''
+
+ kwargs = dict(
+ var="value",
+ # This is a deprecated BS3-era keyword argument, which
+ # will be stripped out.
+ convertEntities=True,
+ )
+ with warnings.catch_warnings(record=True):
+ soup = BeautifulSoup('', builder=Mock, **kwargs)
+ assert isinstance(soup.builder, Mock)
+ self.assertEqual(dict(var="value"), soup.builder.called_with)
+
+ # You can also instantiate the TreeBuilder yourself. In this
+ # case, that specific object is used and any keyword arguments
+ # to the BeautifulSoup constructor are ignored.
+ builder = Mock(**kwargs)
+ with warnings.catch_warnings(record=True) as w:
+ soup = BeautifulSoup(
+ '', builder=builder, ignored_value=True,
+ )
+ msg = str(w[0].message)
+ assert msg.startswith("Keyword arguments to the BeautifulSoup constructor will be ignored.")
+ self.assertEqual(builder, soup.builder)
+ self.assertEqual(kwargs, builder.called_with)
+
+ def test_cdata_list_attributes(self):
+ # Most attribute values are represented as scalars, but the
+ # HTML standard says that some attributes, like 'class' have
+ # space-separated lists as values.
+ markup = '<a id=" an id " class=" a class "></a>'
+ soup = self.soup(markup)
+
+ # Note that the spaces are stripped for 'class' but not for 'id'.
+ a = soup.a
+ self.assertEqual(" an id ", a['id'])
+ self.assertEqual(["a", "class"], a['class'])
+
+ # TreeBuilder takes an argument called 'multi_valued_attributes' which lets
+ # you customize or disable this. As always, you can customize the TreeBuilder
+ # by passing in a keyword argument to the BeautifulSoup constructor.
+ soup = self.soup(markup, builder=default_builder, multi_valued_attributes=None)
+ self.assertEqual(" a class ", soup.a['class'])
+
+ # Here are two ways of saying that `id` is a multi-valued
+ # attribute in this context, but 'class' is not.
+ for switcheroo in ({'*': 'id'}, {'a': 'id'}):
+ with warnings.catch_warnings(record=True) as w:
+ # This will create a warning about not explicitly
+ # specifying a parser, but we'll ignore it.
+ soup = self.soup(markup, builder=None, multi_valued_attributes=switcheroo)
+ a = soup.a
+ self.assertEqual(["an", "id"], a['id'])
+ self.assertEqual(" a class ", a['class'])
+
+
+class TestWarnings(SoupTest):
+
+ def _assert_no_parser_specified(self, s, is_there=True):
+ v = s.startswith(BeautifulSoup.NO_PARSER_SPECIFIED_WARNING[:80])
+ self.assertTrue(v)
+
+ def test_warning_if_no_parser_specified(self):
+ with warnings.catch_warnings(record=True) as w:
+ soup = self.soup("<a><b></b></a>")
+ msg = str(w[0].message)
+ self._assert_no_parser_specified(msg)
+
+ def test_warning_if_parser_specified_too_vague(self):
+ with warnings.catch_warnings(record=True) as w:
+ soup = self.soup("<a><b></b></a>", "html")
+ msg = str(w[0].message)
+ self._assert_no_parser_specified(msg)
+
+ def test_no_warning_if_explicit_parser_specified(self):
+ with warnings.catch_warnings(record=True) as w:
+ soup = self.soup("<a><b></b></a>", "html.parser")
+ self.assertEqual([], w)
+
+ def test_parseOnlyThese_renamed_to_parse_only(self):
+ with warnings.catch_warnings(record=True) as w:
+ soup = self.soup("<a><b></b></a>", parseOnlyThese=SoupStrainer("b"))
+ msg = str(w[0].message)
+ self.assertTrue("parseOnlyThese" in msg)
+ self.assertTrue("parse_only" in msg)
+ self.assertEqual(b"<b></b>", soup.encode())
+
+ def test_fromEncoding_renamed_to_from_encoding(self):
+ with warnings.catch_warnings(record=True) as w:
+ utf8 = b"\xc3\xa9"
+ soup = self.soup(utf8, fromEncoding="utf8")
+ msg = str(w[0].message)
+ self.assertTrue("fromEncoding" in msg)
+ self.assertTrue("from_encoding" in msg)
+ self.assertEqual("utf8", soup.original_encoding)
+
+ def test_unrecognized_keyword_argument(self):
+ self.assertRaises(
+ TypeError, self.soup, "<a>", no_such_argument=True)
+
+class TestFilenameAndURLWarnings(SoupTest):
+
+ def test_disk_file_warning(self):
+ filehandle = tempfile.NamedTemporaryFile()
+ filename = filehandle.name
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ soup = self.soup(filename)
+ msg = str(w[0].message)
+ self.assertTrue("looks like a filename" in msg)
+ finally:
+ filehandle.close()
+
+ # The file no longer exists, so Beautiful Soup will no longer issue the warning.
+ with warnings.catch_warnings(record=True) as w:
+ soup = self.soup(filename)
+ self.assertEqual(0, len(w))
+
+ def test_url_warning_with_bytes_url(self):
+ with warnings.catch_warnings(record=True) as warning_list:
+ soup = self.soup(b"http://www.crummybytes.com/")
+ # Be aware this isn't the only warning that can be raised during
+ # execution.
+ self.assertTrue(any("looks like a URL" in str(w.message)
+ for w in warning_list))
+
+ def test_url_warning_with_unicode_url(self):
+ with warnings.catch_warnings(record=True) as warning_list:
+ # Note: this URL must differ from the bytes one; otherwise Python's
+ # warnings system swallows the second warning.
+ soup = self.soup("http://www.crummyunicode.com/")
+ self.assertTrue(any("looks like a URL" in str(w.message)
+ for w in warning_list))
+
+ def test_url_warning_with_bytes_and_space(self):
+ with warnings.catch_warnings(record=True) as warning_list:
+ soup = self.soup(b"http://www.crummybytes.com/ is great")
+ self.assertFalse(any("looks like a URL" in str(w.message)
+ for w in warning_list))
+
+ def test_url_warning_with_unicode_and_space(self):
+ with warnings.catch_warnings(record=True) as warning_list:
+ soup = self.soup("http://www.crummyuncode.com/ is great")
+ self.assertFalse(any("looks like a URL" in str(w.message)
+ for w in warning_list))
+
+
+class TestSelectiveParsing(SoupTest):
+
+ def test_parse_with_soupstrainer(self):
+ markup = "No<b>Yes</b><a>No<b>Yes <c>Yes</c></b>"
+ strainer = SoupStrainer("b")
+ soup = self.soup(markup, parse_only=strainer)
+ self.assertEqual(soup.encode(), b"<b>Yes</b><b>Yes <c>Yes</c></b>")
+
+
+class TestEntitySubstitution(unittest.TestCase):
+ """Standalone tests of the EntitySubstitution class."""
+ def setUp(self):
+ self.sub = EntitySubstitution
+
+ def test_simple_html_substitution(self):
+ # Unicode characters corresponding to named HTML entities
+ # are substituted, and no others.
+ s = "foo\u2200\N{SNOWMAN}\u00f5bar"
+ self.assertEqual(self.sub.substitute_html(s),
+ "foo&forall;\N{SNOWMAN}&otilde;bar")
+
+ def test_smart_quote_substitution(self):
+ # MS smart quotes are a common source of frustration, so we
+ # give them a special test.
+ quotes = b"\x91\x92foo\x93\x94"
+ dammit = UnicodeDammit(quotes)
+ self.assertEqual(self.sub.substitute_html(dammit.markup),
+ "&lsquo;&rsquo;foo&ldquo;&rdquo;")
+
+ def test_xml_conversion_includes_no_quotes_if_make_quoted_attribute_is_false(self):
+ s = 'Welcome to "my bar"'
+ self.assertEqual(self.sub.substitute_xml(s, False), s)
+
+ def test_xml_attribute_quoting_normally_uses_double_quotes(self):
+ self.assertEqual(self.sub.substitute_xml("Welcome", True),
+ '"Welcome"')
+ self.assertEqual(self.sub.substitute_xml("Bob's Bar", True),
+ '"Bob\'s Bar"')
+
+ def test_xml_attribute_quoting_uses_single_quotes_when_value_contains_double_quotes(self):
+ s = 'Welcome to "my bar"'
+ self.assertEqual(self.sub.substitute_xml(s, True),
+ "'Welcome to \"my bar\"'")
+
+ def test_xml_attribute_quoting_escapes_single_quotes_when_value_contains_both_single_and_double_quotes(self):
+ s = 'Welcome to "Bob\'s Bar"'
+ self.assertEqual(
+ self.sub.substitute_xml(s, True),
+ '"Welcome to &quot;Bob\'s Bar&quot;"')
+
+ def test_xml_quotes_arent_escaped_when_value_is_not_being_quoted(self):
+ quoted = 'Welcome to "Bob\'s Bar"'
+ self.assertEqual(self.sub.substitute_xml(quoted), quoted)
+
+ def test_xml_quoting_handles_angle_brackets(self):
+ self.assertEqual(
+ self.sub.substitute_xml("foo<bar>"),
+ "foo&lt;bar&gt;")
+
+ def test_xml_quoting_handles_ampersands(self):
+ self.assertEqual(self.sub.substitute_xml("AT&T"), "AT&amp;T")
+
+ def test_xml_quoting_including_ampersands_when_they_are_part_of_an_entity(self):
+ self.assertEqual(
+ self.sub.substitute_xml("&Aacute;T&T"),
+ "&amp;Aacute;T&amp;T")
+
+ def test_xml_quoting_ignoring_ampersands_when_they_are_part_of_an_entity(self):
+ self.assertEqual(
+ self.sub.substitute_xml_containing_entities("&Aacute;T&T"),
+ "&Aacute;T&amp;T")
+
+ def test_quotes_not_html_substituted(self):
+ """There's no need to do this except inside attribute values."""
+ text = 'Bob\'s "bar"'
+ self.assertEqual(self.sub.substitute_html(text), text)
+
+
+class TestEncodingConversion(SoupTest):
+ # Test Beautiful Soup's ability to decode and encode from various
+ # encodings.
+
+ def setUp(self):
+ super(TestEncodingConversion, self).setUp()
+ self.unicode_data = '<html><head><meta charset="utf-8"/></head><body><foo>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</foo></body></html>'
+ self.utf8_data = self.unicode_data.encode("utf-8")
+ # Just so you know what it looks like.
+ self.assertEqual(
+ self.utf8_data,
+ b'<html><head><meta charset="utf-8"/></head><body><foo>Sacr\xc3\xa9 bleu!</foo></body></html>')
+
+ def test_ascii_in_unicode_out(self):
+ # ASCII input is converted to Unicode. The original_encoding
+ # attribute is set to 'utf-8', a superset of ASCII.
+ chardet = bs4.dammit.chardet_dammit
+ logging.disable(logging.WARNING)
+ try:
+ def noop(str):
+ return None
+ # Disable chardet, which will realize that the ASCII is ASCII.
+ bs4.dammit.chardet_dammit = noop
+ ascii = b"<foo>a</foo>"
+ soup_from_ascii = self.soup(ascii)
+ unicode_output = soup_from_ascii.decode()
+ self.assertTrue(isinstance(unicode_output, str))
+ self.assertEqual(unicode_output, self.document_for(ascii.decode()))
+ self.assertEqual(soup_from_ascii.original_encoding.lower(), "utf-8")
+ finally:
+ logging.disable(logging.NOTSET)
+ bs4.dammit.chardet_dammit = chardet
+
+ def test_unicode_in_unicode_out(self):
+ # Unicode input is left alone. The original_encoding attribute
+ # is not set.
+ soup_from_unicode = self.soup(self.unicode_data)
+ self.assertEqual(soup_from_unicode.decode(), self.unicode_data)
+ self.assertEqual(soup_from_unicode.foo.string, 'Sacr\xe9 bleu!')
+ self.assertEqual(soup_from_unicode.original_encoding, None)
+
+ def test_utf8_in_unicode_out(self):
+ # UTF-8 input is converted to Unicode. The original_encoding
+ # attribute is set.
+ soup_from_utf8 = self.soup(self.utf8_data)
+ self.assertEqual(soup_from_utf8.decode(), self.unicode_data)
+ self.assertEqual(soup_from_utf8.foo.string, 'Sacr\xe9 bleu!')
+
+ def test_utf8_out(self):
+ # The internal data structures can be encoded as UTF-8.
+ soup_from_unicode = self.soup(self.unicode_data)
+ self.assertEqual(soup_from_unicode.encode('utf-8'), self.utf8_data)
+
+ @skipIf(
+ PYTHON_3_PRE_3_2,
+ "Bad HTMLParser detected; skipping test of non-ASCII characters in attribute name.")
+ def test_attribute_name_containing_unicode_characters(self):
+ markup = '<div><a \N{SNOWMAN}="snowman"></a></div>'
+ self.assertEqual(self.soup(markup).div.encode("utf8"), markup.encode("utf8"))
+
+class TestUnicodeDammit(unittest.TestCase):
+ """Standalone tests of UnicodeDammit."""
+
+ def test_unicode_input(self):
+ markup = "I'm already Unicode! \N{SNOWMAN}"
+ dammit = UnicodeDammit(markup)
+ self.assertEqual(dammit.unicode_markup, markup)
+
+ def test_smart_quotes_to_unicode(self):
+ markup = b"<foo>\x91\x92\x93\x94</foo>"
+ dammit = UnicodeDammit(markup)
+ self.assertEqual(
+ dammit.unicode_markup, "<foo>\u2018\u2019\u201c\u201d</foo>")
+
+ def test_smart_quotes_to_xml_entities(self):
+ markup = b"<foo>\x91\x92\x93\x94</foo>"
+ dammit = UnicodeDammit(markup, smart_quotes_to="xml")
+ self.assertEqual(
+ dammit.unicode_markup, "<foo>&#x2018;&#x2019;&#x201C;&#x201D;</foo>")
+
+ def test_smart_quotes_to_html_entities(self):
+ markup = b"<foo>\x91\x92\x93\x94</foo>"
+ dammit = UnicodeDammit(markup, smart_quotes_to="html")
+ self.assertEqual(
+ dammit.unicode_markup, "<foo>&lsquo;&rsquo;&ldquo;&rdquo;</foo>")
+
+ def test_smart_quotes_to_ascii(self):
+ markup = b"<foo>\x91\x92\x93\x94</foo>"
+ dammit = UnicodeDammit(markup, smart_quotes_to="ascii")
+ self.assertEqual(
+ dammit.unicode_markup, """<foo>''""</foo>""")
+
+ def test_detect_utf8(self):
+ utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83"
+ dammit = UnicodeDammit(utf8)
+ self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
+ self.assertEqual(dammit.unicode_markup, 'Sacr\xe9 bleu! \N{SNOWMAN}')
+
+
+ def test_convert_hebrew(self):
+ hebrew = b"\xed\xe5\xec\xf9"
+ dammit = UnicodeDammit(hebrew, ["iso-8859-8"])
+ self.assertEqual(dammit.original_encoding.lower(), 'iso-8859-8')
+ self.assertEqual(dammit.unicode_markup, '\u05dd\u05d5\u05dc\u05e9')
+
+ def test_dont_see_smart_quotes_where_there_are_none(self):
+ utf_8 = b"\343\202\261\343\203\274\343\202\277\343\202\244 Watch"
+ dammit = UnicodeDammit(utf_8)
+ self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
+ self.assertEqual(dammit.unicode_markup.encode("utf-8"), utf_8)
+
+ def test_ignore_inappropriate_codecs(self):
+ utf8_data = "Räksmörgås".encode("utf-8")
+ dammit = UnicodeDammit(utf8_data, ["iso-8859-8"])
+ self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
+
+ def test_ignore_invalid_codecs(self):
+ utf8_data = "Räksmörgås".encode("utf-8")
+ for bad_encoding in ['.utf8', '...', 'utF---16.!']:
+ dammit = UnicodeDammit(utf8_data, [bad_encoding])
+ self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
+
+ def test_exclude_encodings(self):
+ # This is UTF-8.
+ utf8_data = "Räksmörgås".encode("utf-8")
+
+ # But if we exclude UTF-8 from consideration, the guess is
+ # Windows-1252.
+ dammit = UnicodeDammit(utf8_data, exclude_encodings=["utf-8"])
+ self.assertEqual(dammit.original_encoding.lower(), 'windows-1252')
+
+ # And if we exclude that, there is no valid guess at all.
+ dammit = UnicodeDammit(
+ utf8_data, exclude_encodings=["utf-8", "windows-1252"])
+ self.assertEqual(dammit.original_encoding, None)
+
+ def test_encoding_detector_replaces_junk_in_encoding_name_with_replacement_character(self):
+ detected = EncodingDetector(
+ b'<?xml version="1.0" encoding="UTF-\xdb" ?>')
+ encodings = list(detected.encodings)
+ assert 'utf-\N{REPLACEMENT CHARACTER}' in encodings
+
+ def test_detect_html5_style_meta_tag(self):
+
+ for data in (
+ b'<html><meta charset="euc-jp" /></html>',
+ b"<html><meta charset='euc-jp' /></html>",
+ b"<html><meta charset=euc-jp /></html>",
+ b"<html><meta charset=euc-jp/></html>"):
+ dammit = UnicodeDammit(data, is_html=True)
+ self.assertEqual(
+ "euc-jp", dammit.original_encoding)
+
+ def test_last_ditch_entity_replacement(self):
+ # This is a UTF-8 document that contains bytestrings
+ # completely incompatible with UTF-8 (ie. encoded with some other
+ # encoding).
+ #
+ # Since there is no consistent encoding for the document,
+ # Unicode, Dammit will eventually encode the document as UTF-8
+ # and encode the incompatible characters as REPLACEMENT
+ # CHARACTER.
+ #
+ # If chardet is installed, it will detect that the document
+ # can be converted into ISO-8859-1 without errors. This happens
+ # to be the wrong encoding, but it is a consistent encoding, so the
+ # code we're testing here won't run.
+ #
+ # So we temporarily disable chardet if it's present.
+ doc = b"""\357\273\277<?xml version="1.0" encoding="UTF-8"?>
+<html><b>\330\250\330\252\330\261</b>
+<i>\310\322\321\220\312\321\355\344</i></html>"""
+ chardet = bs4.dammit.chardet_dammit
+ logging.disable(logging.WARNING)
+ try:
+ def noop(str):
+ return None
+ bs4.dammit.chardet_dammit = noop
+ dammit = UnicodeDammit(doc)
+ self.assertEqual(True, dammit.contains_replacement_characters)
+ self.assertTrue("\ufffd" in dammit.unicode_markup)
+
+ soup = BeautifulSoup(doc, "html.parser")
+ self.assertTrue(soup.contains_replacement_characters)
+ finally:
+ logging.disable(logging.NOTSET)
+ bs4.dammit.chardet_dammit = chardet
+
+ def test_byte_order_mark_removed(self):
+ # A document written in UTF-16LE will have its byte order marker stripped.
+ data = b'\xff\xfe<\x00a\x00>\x00\xe1\x00\xe9\x00<\x00/\x00a\x00>\x00'
+ dammit = UnicodeDammit(data)
+ self.assertEqual("<a>áé</a>", dammit.unicode_markup)
+ self.assertEqual("utf-16le", dammit.original_encoding)
+
+ def test_detwingle(self):
+ # Here's a UTF8 document.
+ utf8 = ("\N{SNOWMAN}" * 3).encode("utf8")
+
+ # Here's a Windows-1252 document.
+ windows_1252 = (
+ "\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!"
+ "\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252")
+
+ # Through some unholy alchemy, they've been stuck together.
+ doc = utf8 + windows_1252 + utf8
+
+ # The document can't be turned into UTF-8:
+ self.assertRaises(UnicodeDecodeError, doc.decode, "utf8")
+
+ # Unicode, Dammit thinks the whole document is Windows-1252,
+ # and decodes it into "☃☃☃“Hi, I like Windows!”☃☃☃"
+
+ # But if we run it through UnicodeDammit.detwingle, it's fixed:
+
+ fixed = UnicodeDammit.detwingle(doc)
+ self.assertEqual(
+ "☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8"))
+
+ def test_detwingle_ignores_multibyte_characters(self):
+ # Each of these characters has a UTF-8 representation ending
+ # in \x93. \x93 is a smart quote if interpreted as
+ # Windows-1252. But our code knows to skip over multibyte
+ # UTF-8 characters, so they'll survive the process unscathed.
+ for tricky_unicode_char in (
+ "\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93'
+ "\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93'
+ "\xf0\x90\x90\x93", # This is a CJK character, not sure which one.
+ ):
+ input = tricky_unicode_char.encode("utf8")
+ self.assertTrue(input.endswith(b'\x93'))
+ output = UnicodeDammit.detwingle(input)
+ self.assertEqual(output, input)
+
+class TestNamespacedAttribute(SoupTest):
+
+ def test_name_may_be_none(self):
+ a = NamespacedAttribute("xmlns", None)
+ self.assertEqual(a, "xmlns")
+
+ def test_attribute_is_equivalent_to_colon_separated_string(self):
+ a = NamespacedAttribute("a", "b")
+ self.assertEqual("a:b", a)
+
+ def test_attributes_are_equivalent_if_prefix_and_name_identical(self):
+ a = NamespacedAttribute("a", "b", "c")
+ b = NamespacedAttribute("a", "b", "c")
+ self.assertEqual(a, b)
+
+ # The actual namespace is not considered.
+ c = NamespacedAttribute("a", "b", None)
+ self.assertEqual(a, c)
+
+ # But name and prefix are important.
+ d = NamespacedAttribute("a", "z", "c")
+ self.assertNotEqual(a, d)
+
+ e = NamespacedAttribute("z", "b", "c")
+ self.assertNotEqual(a, e)
+
+
+class TestAttributeValueWithCharsetSubstitution(unittest.TestCase):
+
+ def test_charset_meta_attribute_value(self):
+ value = CharsetMetaAttributeValue("euc-jp")
+ self.assertEqual("euc-jp", value)
+ self.assertEqual("euc-jp", value.original_value)
+ self.assertEqual("utf8", value.encode("utf8"))
+
+
+ def test_content_meta_attribute_value(self):
+ value = ContentMetaAttributeValue("text/html; charset=euc-jp")
+ self.assertEqual("text/html; charset=euc-jp", value)
+ self.assertEqual("text/html; charset=euc-jp", value.original_value)
+ self.assertEqual("text/html; charset=utf8", value.encode("utf8"))
diff --git a/libs/bs4/tests/test_tree.py b/libs/bs4/tests/test_tree.py
new file mode 100644
index 000000000..3b4beeb8f
--- /dev/null
+++ b/libs/bs4/tests/test_tree.py
@@ -0,0 +1,2205 @@
+# -*- coding: utf-8 -*-
+"""Tests for Beautiful Soup's tree traversal methods.
+
+The tree traversal methods are the main advantage of using Beautiful
+Soup over just using a parser.
+
+Different parsers will build different Beautiful Soup trees given the
+same markup, but all Beautiful Soup trees can be traversed with the
+methods tested here.
+"""
+
+from pdb import set_trace
+import copy
+import pickle
+import re
+import warnings
+from bs4 import BeautifulSoup
+from bs4.builder import (
+ builder_registry,
+ HTMLParserTreeBuilder,
+)
+from bs4.element import (
+ PY3K,
+ CData,
+ Comment,
+ Declaration,
+ Doctype,
+ Formatter,
+ NavigableString,
+ SoupStrainer,
+ Tag,
+)
+from bs4.testing import (
+ SoupTest,
+ skipIf,
+)
+
+XML_BUILDER_PRESENT = (builder_registry.lookup("xml") is not None)
+LXML_PRESENT = (builder_registry.lookup("lxml") is not None)
+
+class TreeTest(SoupTest):
+
+ def assertSelects(self, tags, should_match):
+ """Make sure that the given tags have the correct text.
+
+ This is used in tests that define a bunch of tags, each
+ containing a single string, and then select certain strings by
+ some mechanism.
+ """
+ self.assertEqual([tag.string for tag in tags], should_match)
+
+ def assertSelectsIDs(self, tags, should_match):
+ """Make sure that the given tags have the correct IDs.
+
+        This is used in tests that define a bunch of tags, each with
+        an ID, and then select certain tags by some mechanism. The
+        selection is verified by ID.
+ """
+ self.assertEqual([tag['id'] for tag in tags], should_match)
+
+
+class TestFind(TreeTest):
+ """Basic tests of the find() method.
+
+ find() just calls find_all() with limit=1, so it's not tested all
+    that thoroughly here.
+ """
+
+ def test_find_tag(self):
+ soup = self.soup("<a>1</a><b>2</b><a>3</a><b>4</b>")
+ self.assertEqual(soup.find("b").string, "2")
+
+ def test_unicode_text_find(self):
+ soup = self.soup('<h1>Räksmörgås</h1>')
+ self.assertEqual(soup.find(string='Räksmörgås'), 'Räksmörgås')
+
+ def test_unicode_attribute_find(self):
+ soup = self.soup('<h1 id="Räksmörgås">here it is</h1>')
+ str(soup)
+ self.assertEqual("here it is", soup.find(id='Räksmörgås').text)
+
+
+ def test_find_everything(self):
+ """Test an optimization that finds all tags."""
+ soup = self.soup("<a>foo</a><b>bar</b>")
+ self.assertEqual(2, len(soup.find_all()))
+
+ def test_find_everything_with_name(self):
+ """Test an optimization that finds all tags with a given name."""
+ soup = self.soup("<a>foo</a><b>bar</b><a>baz</a>")
+ self.assertEqual(2, len(soup.find_all('a')))
+
+class TestFindAll(TreeTest):
+ """Basic tests of the find_all() method."""
+
+ def test_find_all_text_nodes(self):
+ """You can search the tree for text nodes."""
+ soup = self.soup("<html>Foo<b>bar</b>\xbb</html>")
+ # Exact match.
+ self.assertEqual(soup.find_all(string="bar"), ["bar"])
+ self.assertEqual(soup.find_all(text="bar"), ["bar"])
+ # Match any of a number of strings.
+ self.assertEqual(
+ soup.find_all(text=["Foo", "bar"]), ["Foo", "bar"])
+ # Match a regular expression.
+ self.assertEqual(soup.find_all(text=re.compile('.*')),
+ ["Foo", "bar", '\xbb'])
+ # Match anything.
+ self.assertEqual(soup.find_all(text=True),
+ ["Foo", "bar", '\xbb'])
+
+ def test_find_all_limit(self):
+ """You can limit the number of items returned by find_all."""
+ soup = self.soup("<a>1</a><a>2</a><a>3</a><a>4</a><a>5</a>")
+ self.assertSelects(soup.find_all('a', limit=3), ["1", "2", "3"])
+ self.assertSelects(soup.find_all('a', limit=1), ["1"])
+ self.assertSelects(
+ soup.find_all('a', limit=10), ["1", "2", "3", "4", "5"])
+
+ # A limit of 0 means no limit.
+ self.assertSelects(
+ soup.find_all('a', limit=0), ["1", "2", "3", "4", "5"])
+
+ def test_calling_a_tag_is_calling_findall(self):
+ soup = self.soup("<a>1</a><b>2<a id='foo'>3</a></b>")
+ self.assertSelects(soup('a', limit=1), ["1"])
+ self.assertSelects(soup.b(id="foo"), ["3"])
+
+ def test_find_all_with_self_referential_data_structure_does_not_cause_infinite_recursion(self):
+ soup = self.soup("<a></a>")
+ # Create a self-referential list.
+ l = []
+ l.append(l)
+
+ # Without special code in _normalize_search_value, this would cause infinite
+ # recursion.
+ self.assertEqual([], soup.find_all(l))
+
+ def test_find_all_resultset(self):
+ """All find_all calls return a ResultSet"""
+ soup = self.soup("<a></a>")
+ result = soup.find_all("a")
+ self.assertTrue(hasattr(result, "source"))
+
+ result = soup.find_all(True)
+ self.assertTrue(hasattr(result, "source"))
+
+ result = soup.find_all(text="foo")
+ self.assertTrue(hasattr(result, "source"))
+
+
+class TestFindAllBasicNamespaces(TreeTest):
+
+ def test_find_by_namespaced_name(self):
+ soup = self.soup('<mathml:msqrt>4</mathml:msqrt><a svg:fill="red">')
+ self.assertEqual("4", soup.find("mathml:msqrt").string)
+ self.assertEqual("a", soup.find(attrs= { "svg:fill" : "red" }).name)
+
+
+class TestFindAllByName(TreeTest):
+ """Test ways of finding tags by tag name."""
+
+ def setUp(self):
+        super(TestFindAllByName, self).setUp()
+ self.tree = self.soup("""<a>First tag.</a>
+ <b>Second tag.</b>
+ <c>Third <a>Nested tag.</a> tag.</c>""")
+
+ def test_find_all_by_tag_name(self):
+ # Find all the <a> tags.
+ self.assertSelects(
+ self.tree.find_all('a'), ['First tag.', 'Nested tag.'])
+
+ def test_find_all_by_name_and_text(self):
+ self.assertSelects(
+ self.tree.find_all('a', text='First tag.'), ['First tag.'])
+
+ self.assertSelects(
+ self.tree.find_all('a', text=True), ['First tag.', 'Nested tag.'])
+
+ self.assertSelects(
+ self.tree.find_all('a', text=re.compile("tag")),
+ ['First tag.', 'Nested tag.'])
+
+
+ def test_find_all_on_non_root_element(self):
+ # You can call find_all on any node, not just the root.
+ self.assertSelects(self.tree.c.find_all('a'), ['Nested tag.'])
+
+ def test_calling_element_invokes_find_all(self):
+ self.assertSelects(self.tree('a'), ['First tag.', 'Nested tag.'])
+
+ def test_find_all_by_tag_strainer(self):
+ self.assertSelects(
+ self.tree.find_all(SoupStrainer('a')),
+ ['First tag.', 'Nested tag.'])
+
+ def test_find_all_by_tag_names(self):
+ self.assertSelects(
+ self.tree.find_all(['a', 'b']),
+ ['First tag.', 'Second tag.', 'Nested tag.'])
+
+ def test_find_all_by_tag_dict(self):
+ self.assertSelects(
+ self.tree.find_all({'a' : True, 'b' : True}),
+ ['First tag.', 'Second tag.', 'Nested tag.'])
+
+ def test_find_all_by_tag_re(self):
+ self.assertSelects(
+ self.tree.find_all(re.compile('^[ab]$')),
+ ['First tag.', 'Second tag.', 'Nested tag.'])
+
+ def test_find_all_with_tags_matching_method(self):
+ # You can define an oracle method that determines whether
+ # a tag matches the search.
+ def id_matches_name(tag):
+ return tag.name == tag.get('id')
+
+ tree = self.soup("""<a id="a">Match 1.</a>
+ <a id="1">Does not match.</a>
+                         <b id="b">Match 2.</b>""")
+
+ self.assertSelects(
+ tree.find_all(id_matches_name), ["Match 1.", "Match 2."])
+
+ def test_find_with_multi_valued_attribute(self):
+ soup = self.soup(
+ "<div class='a b'>1</div><div class='a c'>2</div><div class='a d'>3</div>"
+ )
+        r1 = soup.find('div', 'a d')
+        r2 = soup.find('div', re.compile(r'a d'))
+        r3, r4 = soup.find_all('div', ['a b', 'a d'])
+ self.assertEqual('3', r1.string)
+ self.assertEqual('3', r2.string)
+ self.assertEqual('1', r3.string)
+ self.assertEqual('3', r4.string)
+
+
+class TestFindAllByAttribute(TreeTest):
+
+ def test_find_all_by_attribute_name(self):
+ # You can pass in keyword arguments to find_all to search by
+ # attribute.
+ tree = self.soup("""
+ <a id="first">Matching a.</a>
+ <a id="second">
+ Non-matching <b id="first">Matching b.</b>a.
+ </a>""")
+ self.assertSelects(tree.find_all(id='first'),
+ ["Matching a.", "Matching b."])
+
+ def test_find_all_by_utf8_attribute_value(self):
+ peace = "םולש".encode("utf8")
+ data = '<a title="םולש"></a>'.encode("utf8")
+ soup = self.soup(data)
+ self.assertEqual([soup.a], soup.find_all(title=peace))
+ self.assertEqual([soup.a], soup.find_all(title=peace.decode("utf8")))
+ self.assertEqual([soup.a], soup.find_all(title=[peace, "something else"]))
+
+ def test_find_all_by_attribute_dict(self):
+ # You can pass in a dictionary as the argument 'attrs'. This
+ # lets you search for attributes like 'name' (a fixed argument
+ # to find_all) and 'class' (a reserved word in Python.)
+ tree = self.soup("""
+ <a name="name1" class="class1">Name match.</a>
+ <a name="name2" class="class2">Class match.</a>
+ <a name="name3" class="class3">Non-match.</a>
+ <name1>A tag called 'name1'.</name1>
+ """)
+
+ # This doesn't do what you want.
+ self.assertSelects(tree.find_all(name='name1'),
+ ["A tag called 'name1'."])
+ # This does what you want.
+ self.assertSelects(tree.find_all(attrs={'name' : 'name1'}),
+ ["Name match."])
+
+ self.assertSelects(tree.find_all(attrs={'class' : 'class2'}),
+ ["Class match."])
+
+ def test_find_all_by_class(self):
+ tree = self.soup("""
+ <a class="1">Class 1.</a>
+ <a class="2">Class 2.</a>
+ <b class="1">Class 1.</b>
+ <c class="3 4">Class 3 and 4.</c>
+ """)
+
+ # Passing in the class_ keyword argument will search against
+ # the 'class' attribute.
+ self.assertSelects(tree.find_all('a', class_='1'), ['Class 1.'])
+ self.assertSelects(tree.find_all('c', class_='3'), ['Class 3 and 4.'])
+ self.assertSelects(tree.find_all('c', class_='4'), ['Class 3 and 4.'])
+
+ # Passing in a string to 'attrs' will also search the CSS class.
+ self.assertSelects(tree.find_all('a', '1'), ['Class 1.'])
+ self.assertSelects(tree.find_all(attrs='1'), ['Class 1.', 'Class 1.'])
+ self.assertSelects(tree.find_all('c', '3'), ['Class 3 and 4.'])
+ self.assertSelects(tree.find_all('c', '4'), ['Class 3 and 4.'])
+
+ def test_find_by_class_when_multiple_classes_present(self):
+ tree = self.soup("<gar class='foo bar'>Found it</gar>")
+
+ f = tree.find_all("gar", class_=re.compile("o"))
+ self.assertSelects(f, ["Found it"])
+
+ f = tree.find_all("gar", class_=re.compile("a"))
+ self.assertSelects(f, ["Found it"])
+
+ # If the search fails to match the individual strings "foo" and "bar",
+ # it will be tried against the combined string "foo bar".
+ f = tree.find_all("gar", class_=re.compile("o b"))
+ self.assertSelects(f, ["Found it"])
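+
+    # Illustrative sketch, not part of the upstream test suite, restating
+    # the matching order shown above: each individual class is tried
+    # first, then the combined "class" string.
+    def _demo_class_matching_order(self):
+        tree = self.soup("<p class='foo bar'>x</p>")
+        assert tree.find_all("p", class_="foo")              # single class
+        assert tree.find_all("p", class_=re.compile("o b"))  # combined string
+        assert not tree.find_all("p", class_="oo")           # no substring match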
+
+ def test_find_all_with_non_dictionary_for_attrs_finds_by_class(self):
+ soup = self.soup("<a class='bar'>Found it</a>")
+
+ self.assertSelects(soup.find_all("a", re.compile("ba")), ["Found it"])
+
+ def big_attribute_value(value):
+ return len(value) > 3
+
+ self.assertSelects(soup.find_all("a", big_attribute_value), [])
+
+ def small_attribute_value(value):
+ return len(value) <= 3
+
+ self.assertSelects(
+ soup.find_all("a", small_attribute_value), ["Found it"])
+
+ def test_find_all_with_string_for_attrs_finds_multiple_classes(self):
+ soup = self.soup('<a class="foo bar"></a><a class="foo"></a>')
+ a, a2 = soup.find_all("a")
+ self.assertEqual([a, a2], soup.find_all("a", "foo"))
+ self.assertEqual([a], soup.find_all("a", "bar"))
+
+ # If you specify the class as a string that contains a
+ # space, only that specific value will be found.
+ self.assertEqual([a], soup.find_all("a", class_="foo bar"))
+ self.assertEqual([a], soup.find_all("a", "foo bar"))
+ self.assertEqual([], soup.find_all("a", "bar foo"))
+
+ def test_find_all_by_attribute_soupstrainer(self):
+ tree = self.soup("""
+ <a id="first">Match.</a>
+ <a id="second">Non-match.</a>""")
+
+ strainer = SoupStrainer(attrs={'id' : 'first'})
+ self.assertSelects(tree.find_all(strainer), ['Match.'])
+
+ def test_find_all_with_missing_attribute(self):
+ # You can pass in None as the value of an attribute to find_all.
+ # This will match tags that do not have that attribute set.
+ tree = self.soup("""<a id="1">ID present.</a>
+ <a>No ID present.</a>
+ <a id="">ID is empty.</a>""")
+ self.assertSelects(tree.find_all('a', id=None), ["No ID present."])
+
+ def test_find_all_with_defined_attribute(self):
+ # You can pass in None as the value of an attribute to find_all.
+ # This will match tags that have that attribute set to any value.
+ tree = self.soup("""<a id="1">ID present.</a>
+ <a>No ID present.</a>
+ <a id="">ID is empty.</a>""")
+ self.assertSelects(
+ tree.find_all(id=True), ["ID present.", "ID is empty."])
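+
+    # Illustrative sketch, not part of the upstream test suite, summarizing
+    # the two special attribute values exercised above:
+    #   id=None -> tags *without* the attribute
+    #   id=True -> tags *with* the attribute, whatever its value
+    def _demo_none_versus_true(self):
+        tree = self.soup('<a id="x">has</a><a>lacks</a>')
+        assert [t.string for t in tree.find_all("a", id=True)] == ["has"]
+        assert [t.string for t in tree.find_all("a", id=None)] == ["lacks"]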
+
+ def test_find_all_with_numeric_attribute(self):
+ # If you search for a number, it's treated as a string.
+ tree = self.soup("""<a id=1>Unquoted attribute.</a>
+ <a id="1">Quoted attribute.</a>""")
+
+ expected = ["Unquoted attribute.", "Quoted attribute."]
+ self.assertSelects(tree.find_all(id=1), expected)
+ self.assertSelects(tree.find_all(id="1"), expected)
+
+ def test_find_all_with_list_attribute_values(self):
+ # You can pass a list of attribute values instead of just one,
+ # and you'll get tags that match any of the values.
+ tree = self.soup("""<a id="1">1</a>
+ <a id="2">2</a>
+ <a id="3">3</a>
+ <a>No ID.</a>""")
+ self.assertSelects(tree.find_all(id=["1", "3", "4"]),
+ ["1", "3"])
+
+ def test_find_all_with_regular_expression_attribute_value(self):
+ # You can pass a regular expression as an attribute value, and
+ # you'll get tags whose values for that attribute match the
+ # regular expression.
+ tree = self.soup("""<a id="a">One a.</a>
+ <a id="aa">Two as.</a>
+ <a id="ab">Mixed as and bs.</a>
+ <a id="b">One b.</a>
+ <a>No ID.</a>""")
+
+ self.assertSelects(tree.find_all(id=re.compile("^a+$")),
+ ["One a.", "Two as."])
+
+ def test_find_by_name_and_containing_string(self):
+ soup = self.soup("<b>foo</b><b>bar</b><a>foo</a>")
+ a = soup.a
+
+ self.assertEqual([a], soup.find_all("a", text="foo"))
+ self.assertEqual([], soup.find_all("a", text="bar"))
+ self.assertEqual([], soup.find_all("a", text="bar"))
+
+ def test_find_by_name_and_containing_string_when_string_is_buried(self):
+ soup = self.soup("<a>foo</a><a><b><c>foo</c></b></a>")
+ self.assertEqual(soup.find_all("a"), soup.find_all("a", text="foo"))
+
+ def test_find_by_attribute_and_containing_string(self):
+ soup = self.soup('<b id="1">foo</b><a id="2">foo</a>')
+ a = soup.a
+
+ self.assertEqual([a], soup.find_all(id=2, text="foo"))
+ self.assertEqual([], soup.find_all(id=1, text="bar"))
+
+
+class TestSmooth(TreeTest):
+ """Test Tag.smooth."""
+
+ def test_smooth(self):
+ soup = self.soup("<div>a</div>")
+ div = soup.div
+ div.append("b")
+ div.append("c")
+ div.append(Comment("Comment 1"))
+ div.append(Comment("Comment 2"))
+ div.append("d")
+ builder = self.default_builder()
+ span = Tag(soup, builder, 'span')
+ span.append('1')
+ span.append('2')
+ div.append(span)
+
+ # At this point the tree has a bunch of adjacent
+ # NavigableStrings. This is normal, but it has no meaning in
+ # terms of HTML, so we may want to smooth things out for
+ # output.
+
+ # Since the <span> tag has two children, its .string is None.
+ self.assertEqual(None, div.span.string)
+
+ self.assertEqual(7, len(div.contents))
+ div.smooth()
+ self.assertEqual(5, len(div.contents))
+
+        # The three strings at the beginning of div.contents have been
+        # merged into one string.
+ self.assertEqual('abc', div.contents[0])
+
+ # The call is recursive -- the <span> tag was also smoothed.
+ self.assertEqual('12', div.span.string)
+
+ # The two comments have _not_ been merged, even though
+ # comments are strings. Merging comments would change the
+ # meaning of the HTML.
+ self.assertEqual('Comment 1', div.contents[1])
+ self.assertEqual('Comment 2', div.contents[2])
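+
+    # Illustrative sketch, not part of the upstream test suite: smooth()
+    # is typically called once, just before serializing a heavily edited
+    # tree, so that adjacent strings collapse in the output.
+    def _demo_smooth_before_output(self):
+        soup = self.soup("<p>a</p>")
+        soup.p.append("b")
+        soup.p.smooth()                   # merge "a" and "b" in place
+        assert soup.p.contents == ["ab"]  # a single NavigableString now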
+
+
+class TestIndex(TreeTest):
+ """Test Tag.index"""
+ def test_index(self):
+ tree = self.soup("""<div>
+ <a>Identical</a>
+ <b>Not identical</b>
+ <a>Identical</a>
+
+ <c><d>Identical with child</d></c>
+ <b>Also not identical</b>
+ <c><d>Identical with child</d></c>
+ </div>""")
+ div = tree.div
+ for i, element in enumerate(div.contents):
+ self.assertEqual(i, div.index(element))
+ self.assertRaises(ValueError, tree.index, 1)
+
+
+class TestParentOperations(TreeTest):
+ """Test navigation and searching through an element's parents."""
+
+ def setUp(self):
+ super(TestParentOperations, self).setUp()
+ self.tree = self.soup('''<ul id="empty"></ul>
+ <ul id="top">
+ <ul id="middle">
+ <ul id="bottom">
+ <b>Start here</b>
+ </ul>
+ </ul>''')
+ self.start = self.tree.b
+
+
+ def test_parent(self):
+ self.assertEqual(self.start.parent['id'], 'bottom')
+ self.assertEqual(self.start.parent.parent['id'], 'middle')
+ self.assertEqual(self.start.parent.parent.parent['id'], 'top')
+
+ def test_parent_of_top_tag_is_soup_object(self):
+ top_tag = self.tree.contents[0]
+ self.assertEqual(top_tag.parent, self.tree)
+
+ def test_soup_object_has_no_parent(self):
+ self.assertEqual(None, self.tree.parent)
+
+ def test_find_parents(self):
+ self.assertSelectsIDs(
+ self.start.find_parents('ul'), ['bottom', 'middle', 'top'])
+ self.assertSelectsIDs(
+ self.start.find_parents('ul', id="middle"), ['middle'])
+
+ def test_find_parent(self):
+ self.assertEqual(self.start.find_parent('ul')['id'], 'bottom')
+ self.assertEqual(self.start.find_parent('ul', id='top')['id'], 'top')
+
+ def test_parent_of_text_element(self):
+ text = self.tree.find(text="Start here")
+ self.assertEqual(text.parent.name, 'b')
+
+ def test_text_element_find_parent(self):
+ text = self.tree.find(text="Start here")
+ self.assertEqual(text.find_parent('ul')['id'], 'bottom')
+
+ def test_parent_generator(self):
+ parents = [parent['id'] for parent in self.start.parents
+ if parent is not None and 'id' in parent.attrs]
+ self.assertEqual(parents, ['bottom', 'middle', 'top'])
+
+
+class ProximityTest(TreeTest):
+
+ def setUp(self):
+        super(ProximityTest, self).setUp()
+ self.tree = self.soup(
+ '<html id="start"><head></head><body><b id="1">One</b><b id="2">Two</b><b id="3">Three</b></body></html>')
+
+
+class TestNextOperations(ProximityTest):
+
+ def setUp(self):
+ super(TestNextOperations, self).setUp()
+ self.start = self.tree.b
+
+ def test_next(self):
+ self.assertEqual(self.start.next_element, "One")
+ self.assertEqual(self.start.next_element.next_element['id'], "2")
+
+ def test_next_of_last_item_is_none(self):
+ last = self.tree.find(text="Three")
+ self.assertEqual(last.next_element, None)
+
+ def test_next_of_root_is_none(self):
+ # The document root is outside the next/previous chain.
+ self.assertEqual(self.tree.next_element, None)
+
+ def test_find_all_next(self):
+ self.assertSelects(self.start.find_all_next('b'), ["Two", "Three"])
+ self.assertSelects(self.start.find_all_next(id=3), ["Three"])
+
+ def test_find_next(self):
+ self.assertEqual(self.start.find_next('b')['id'], '2')
+ self.assertEqual(self.start.find_next(text="Three"), "Three")
+
+ def test_find_next_for_text_element(self):
+ text = self.tree.find(text="One")
+ self.assertEqual(text.find_next("b").string, "Two")
+ self.assertSelects(text.find_all_next("b"), ["Two", "Three"])
+
+ def test_next_generator(self):
+ start = self.tree.find(text="Two")
+ successors = [node for node in start.next_elements]
+ # There are two successors: the final <b> tag and its text contents.
+ tag, contents = successors
+ self.assertEqual(tag['id'], '3')
+ self.assertEqual(contents, "Three")
+
+class TestPreviousOperations(ProximityTest):
+
+ def setUp(self):
+ super(TestPreviousOperations, self).setUp()
+ self.end = self.tree.find(text="Three")
+
+ def test_previous(self):
+ self.assertEqual(self.end.previous_element['id'], "3")
+ self.assertEqual(self.end.previous_element.previous_element, "Two")
+
+ def test_previous_of_first_item_is_none(self):
+ first = self.tree.find('html')
+ self.assertEqual(first.previous_element, None)
+
+ def test_previous_of_root_is_none(self):
+ # The document root is outside the next/previous chain.
+ # XXX This is broken!
+ #self.assertEqual(self.tree.previous_element, None)
+ pass
+
+ def test_find_all_previous(self):
+ # The <b> tag containing the "Three" node is the predecessor
+ # of the "Three" node itself, which is why "Three" shows up
+ # here.
+ self.assertSelects(
+ self.end.find_all_previous('b'), ["Three", "Two", "One"])
+ self.assertSelects(self.end.find_all_previous(id=1), ["One"])
+
+ def test_find_previous(self):
+ self.assertEqual(self.end.find_previous('b')['id'], '3')
+ self.assertEqual(self.end.find_previous(text="One"), "One")
+
+ def test_find_previous_for_text_element(self):
+ text = self.tree.find(text="Three")
+ self.assertEqual(text.find_previous("b").string, "Three")
+ self.assertSelects(
+ text.find_all_previous("b"), ["Three", "Two", "One"])
+
+ def test_previous_generator(self):
+ start = self.tree.find(text="One")
+ predecessors = [node for node in start.previous_elements]
+
+ # There are four predecessors: the <b> tag containing "One"
+ # the <body> tag, the <head> tag, and the <html> tag.
+ b, body, head, html = predecessors
+ self.assertEqual(b['id'], '1')
+ self.assertEqual(body.name, "body")
+ self.assertEqual(head.name, "head")
+ self.assertEqual(html.name, "html")
+
+
+class SiblingTest(TreeTest):
+
+ def setUp(self):
+ super(SiblingTest, self).setUp()
+ markup = '''<html>
+ <span id="1">
+ <span id="1.1"></span>
+ </span>
+ <span id="2">
+ <span id="2.1"></span>
+ </span>
+ <span id="3">
+ <span id="3.1"></span>
+ </span>
+ <span id="4"></span>
+ </html>'''
+ # All that whitespace looks good but makes the tests more
+ # difficult. Get rid of it.
+ markup = re.compile(r"\n\s*").sub("", markup)
+ self.tree = self.soup(markup)
+
+
+class TestNextSibling(SiblingTest):
+
+ def setUp(self):
+ super(TestNextSibling, self).setUp()
+ self.start = self.tree.find(id="1")
+
+ def test_next_sibling_of_root_is_none(self):
+ self.assertEqual(self.tree.next_sibling, None)
+
+ def test_next_sibling(self):
+ self.assertEqual(self.start.next_sibling['id'], '2')
+ self.assertEqual(self.start.next_sibling.next_sibling['id'], '3')
+
+ # Note the difference between next_sibling and next_element.
+ self.assertEqual(self.start.next_element['id'], '1.1')
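+
+    # Illustrative sketch, not part of the upstream test suite:
+    # next_element follows document order, descending into children,
+    # while next_sibling stays at the same level of the tree.
+    def _demo_element_versus_sibling(self):
+        soup = self.soup("<a><b>1</b></a><c>2</c>")
+        assert soup.a.next_element is soup.b  # first node parsed after <a>
+        assert soup.a.next_sibling is soup.c  # next node at the same level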
+
+ def test_next_sibling_may_not_exist(self):
+ self.assertEqual(self.tree.html.next_sibling, None)
+
+ nested_span = self.tree.find(id="1.1")
+ self.assertEqual(nested_span.next_sibling, None)
+
+ last_span = self.tree.find(id="4")
+ self.assertEqual(last_span.next_sibling, None)
+
+ def test_find_next_sibling(self):
+ self.assertEqual(self.start.find_next_sibling('span')['id'], '2')
+
+ def test_next_siblings(self):
+ self.assertSelectsIDs(self.start.find_next_siblings("span"),
+ ['2', '3', '4'])
+
+ self.assertSelectsIDs(self.start.find_next_siblings(id='3'), ['3'])
+
+ def test_next_sibling_for_text_element(self):
+ soup = self.soup("Foo<b>bar</b>baz")
+ start = soup.find(text="Foo")
+ self.assertEqual(start.next_sibling.name, 'b')
+ self.assertEqual(start.next_sibling.next_sibling, 'baz')
+
+ self.assertSelects(start.find_next_siblings('b'), ['bar'])
+ self.assertEqual(start.find_next_sibling(text="baz"), "baz")
+ self.assertEqual(start.find_next_sibling(text="nonesuch"), None)
+
+
+class TestPreviousSibling(SiblingTest):
+
+ def setUp(self):
+ super(TestPreviousSibling, self).setUp()
+ self.end = self.tree.find(id="4")
+
+ def test_previous_sibling_of_root_is_none(self):
+ self.assertEqual(self.tree.previous_sibling, None)
+
+ def test_previous_sibling(self):
+ self.assertEqual(self.end.previous_sibling['id'], '3')
+ self.assertEqual(self.end.previous_sibling.previous_sibling['id'], '2')
+
+ # Note the difference between previous_sibling and previous_element.
+ self.assertEqual(self.end.previous_element['id'], '3.1')
+
+ def test_previous_sibling_may_not_exist(self):
+ self.assertEqual(self.tree.html.previous_sibling, None)
+
+ nested_span = self.tree.find(id="1.1")
+ self.assertEqual(nested_span.previous_sibling, None)
+
+ first_span = self.tree.find(id="1")
+ self.assertEqual(first_span.previous_sibling, None)
+
+ def test_find_previous_sibling(self):
+ self.assertEqual(self.end.find_previous_sibling('span')['id'], '3')
+
+ def test_previous_siblings(self):
+ self.assertSelectsIDs(self.end.find_previous_siblings("span"),
+ ['3', '2', '1'])
+
+ self.assertSelectsIDs(self.end.find_previous_siblings(id='1'), ['1'])
+
+ def test_previous_sibling_for_text_element(self):
+ soup = self.soup("Foo<b>bar</b>baz")
+ start = soup.find(text="baz")
+ self.assertEqual(start.previous_sibling.name, 'b')
+ self.assertEqual(start.previous_sibling.previous_sibling, 'Foo')
+
+ self.assertSelects(start.find_previous_siblings('b'), ['bar'])
+ self.assertEqual(start.find_previous_sibling(text="Foo"), "Foo")
+ self.assertEqual(start.find_previous_sibling(text="nonesuch"), None)
+
+
+class TestTagCreation(SoupTest):
+ """Test the ability to create new tags."""
+ def test_new_tag(self):
+ soup = self.soup("")
+ new_tag = soup.new_tag("foo", bar="baz", attrs={"name": "a name"})
+ self.assertTrue(isinstance(new_tag, Tag))
+ self.assertEqual("foo", new_tag.name)
+ self.assertEqual(dict(bar="baz", name="a name"), new_tag.attrs)
+ self.assertEqual(None, new_tag.parent)
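+
+    # Illustrative sketch, not part of the upstream test suite: new_tag()
+    # only creates the tag; it must be attached with append() or insert()
+    # before it shows up anywhere in the tree.
+    def _demo_new_tag_must_be_attached(self):
+        soup = self.soup("<body></body>")
+        tag = soup.new_tag("a", href="http://example.com/")
+        assert tag.parent is None  # created, but still an orphan
+        soup.body.append(tag)      # now part of the tree
+        assert tag.parent is soup.body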
+
+ def test_tag_inherits_self_closing_rules_from_builder(self):
+ if XML_BUILDER_PRESENT:
+ xml_soup = BeautifulSoup("", "lxml-xml")
+ xml_br = xml_soup.new_tag("br")
+ xml_p = xml_soup.new_tag("p")
+
+            # Both the <br> and <p> tags are empty-element tags, just
+            # because they have no contents.
+ self.assertEqual(b"<br/>", xml_br.encode())
+ self.assertEqual(b"<p/>", xml_p.encode())
+
+ html_soup = BeautifulSoup("", "html.parser")
+ html_br = html_soup.new_tag("br")
+ html_p = html_soup.new_tag("p")
+
+        # The HTML builder uses HTML's rules about which tags are
+ # empty-element tags, and the new tags reflect these rules.
+ self.assertEqual(b"<br/>", html_br.encode())
+ self.assertEqual(b"<p></p>", html_p.encode())
+
+ def test_new_string_creates_navigablestring(self):
+ soup = self.soup("")
+ s = soup.new_string("foo")
+ self.assertEqual("foo", s)
+ self.assertTrue(isinstance(s, NavigableString))
+
+ def test_new_string_can_create_navigablestring_subclass(self):
+ soup = self.soup("")
+ s = soup.new_string("foo", Comment)
+ self.assertEqual("foo", s)
+ self.assertTrue(isinstance(s, Comment))
+
+class TestTreeModification(SoupTest):
+
+ def test_attribute_modification(self):
+ soup = self.soup('<a id="1"></a>')
+ soup.a['id'] = 2
+ self.assertEqual(soup.decode(), self.document_for('<a id="2"></a>'))
+ del(soup.a['id'])
+ self.assertEqual(soup.decode(), self.document_for('<a></a>'))
+ soup.a['id2'] = 'foo'
+ self.assertEqual(soup.decode(), self.document_for('<a id2="foo"></a>'))
+
+ def test_new_tag_creation(self):
+ builder = builder_registry.lookup('html')()
+ soup = self.soup("<body></body>", builder=builder)
+ a = Tag(soup, builder, 'a')
+ ol = Tag(soup, builder, 'ol')
+ a['href'] = 'http://foo.com/'
+ soup.body.insert(0, a)
+ soup.body.insert(1, ol)
+ self.assertEqual(
+ soup.body.encode(),
+ b'<body><a href="http://foo.com/"></a><ol></ol></body>')
+
+ def test_append_to_contents_moves_tag(self):
+ doc = """<p id="1">Don't leave me <b>here</b>.</p>
+ <p id="2">Don\'t leave!</p>"""
+ soup = self.soup(doc)
+ second_para = soup.find(id='2')
+ bold = soup.b
+
+ # Move the <b> tag to the end of the second paragraph.
+ soup.find(id='2').append(soup.b)
+
+ # The <b> tag is now a child of the second paragraph.
+ self.assertEqual(bold.parent, second_para)
+
+ self.assertEqual(
+ soup.decode(), self.document_for(
+ '<p id="1">Don\'t leave me .</p>\n'
+ '<p id="2">Don\'t leave!<b>here</b></p>'))
+
+ def test_replace_with_returns_thing_that_was_replaced(self):
+ text = "<a></a><b><c></c></b>"
+ soup = self.soup(text)
+ a = soup.a
+ new_a = a.replace_with(soup.c)
+ self.assertEqual(a, new_a)
+
+ def test_unwrap_returns_thing_that_was_replaced(self):
+ text = "<a><b></b><c></c></a>"
+ soup = self.soup(text)
+ a = soup.a
+ new_a = a.unwrap()
+ self.assertEqual(a, new_a)
+
+ def test_replace_with_and_unwrap_give_useful_exception_when_tag_has_no_parent(self):
+ soup = self.soup("<a><b>Foo</b></a><c>Bar</c>")
+ a = soup.a
+ a.extract()
+ self.assertEqual(None, a.parent)
+ self.assertRaises(ValueError, a.unwrap)
+ self.assertRaises(ValueError, a.replace_with, soup.c)
+
+ def test_replace_tag_with_itself(self):
+ text = "<a><b></b><c>Foo<d></d></c></a><a><e></e></a>"
+ soup = self.soup(text)
+ c = soup.c
+ soup.c.replace_with(c)
+ self.assertEqual(soup.decode(), self.document_for(text))
+
+ def test_replace_tag_with_its_parent_raises_exception(self):
+ text = "<a><b></b></a>"
+ soup = self.soup(text)
+ self.assertRaises(ValueError, soup.b.replace_with, soup.a)
+
+ def test_insert_tag_into_itself_raises_exception(self):
+ text = "<a><b></b></a>"
+ soup = self.soup(text)
+ self.assertRaises(ValueError, soup.a.insert, 0, soup.a)
+
+ def test_insert_beautifulsoup_object_inserts_children(self):
+ """Inserting one BeautifulSoup object into another actually inserts all
+ of its children -- you'll never combine BeautifulSoup objects.
+ """
+ soup = self.soup("<p>And now, a word:</p><p>And we're back.</p>")
+
+ text = "<p>p2</p><p>p3</p>"
+ to_insert = self.soup(text)
+ soup.insert(1, to_insert)
+
+ for i in soup.descendants:
+ assert not isinstance(i, BeautifulSoup)
+
+ p1, p2, p3, p4 = list(soup.children)
+ self.assertEqual("And now, a word:", p1.string)
+ self.assertEqual("p2", p2.string)
+ self.assertEqual("p3", p3.string)
+ self.assertEqual("And we're back.", p4.string)
+
+
+ def test_replace_with_maintains_next_element_throughout(self):
+ soup = self.soup('<p><a>one</a><b>three</b></p>')
+ a = soup.a
+ b = a.contents[0]
+ # Make it so the <a> tag has two text children.
+ a.insert(1, "two")
+
+ # Now replace each one with the empty string.
+ left, right = a.contents
+        left.replace_with('')
+        right.replace_with('')
+
+ # The <b> tag is still connected to the tree.
+ self.assertEqual("three", soup.b.string)
+
+ def test_replace_final_node(self):
+ soup = self.soup("<b>Argh!</b>")
+ soup.find(text="Argh!").replace_with("Hooray!")
+ new_text = soup.find(text="Hooray!")
+ b = soup.b
+ self.assertEqual(new_text.previous_element, b)
+ self.assertEqual(new_text.parent, b)
+ self.assertEqual(new_text.previous_element.next_element, new_text)
+ self.assertEqual(new_text.next_element, None)
+
+ def test_consecutive_text_nodes(self):
+ # A builder should never create two consecutive text nodes,
+ # but if you insert one next to another, Beautiful Soup will
+ # handle it correctly.
+ soup = self.soup("<a><b>Argh!</b><c></c></a>")
+ soup.b.insert(1, "Hooray!")
+
+ self.assertEqual(
+ soup.decode(), self.document_for(
+ "<a><b>Argh!Hooray!</b><c></c></a>"))
+
+ new_text = soup.find(text="Hooray!")
+ self.assertEqual(new_text.previous_element, "Argh!")
+ self.assertEqual(new_text.previous_element.next_element, new_text)
+
+ self.assertEqual(new_text.previous_sibling, "Argh!")
+ self.assertEqual(new_text.previous_sibling.next_sibling, new_text)
+
+ self.assertEqual(new_text.next_sibling, None)
+ self.assertEqual(new_text.next_element, soup.c)
+
+ def test_insert_string(self):
+ soup = self.soup("<a></a>")
+ soup.a.insert(0, "bar")
+ soup.a.insert(0, "foo")
+        # The strings were added to the tag.
+ self.assertEqual(["foo", "bar"], soup.a.contents)
+ # And they were converted to NavigableStrings.
+ self.assertEqual(soup.a.contents[0].next_element, "bar")
+
+ def test_insert_tag(self):
+ builder = self.default_builder()
+ soup = self.soup(
+ "<a><b>Find</b><c>lady!</c><d></d></a>", builder=builder)
+ magic_tag = Tag(soup, builder, 'magictag')
+ magic_tag.insert(0, "the")
+ soup.a.insert(1, magic_tag)
+
+ self.assertEqual(
+ soup.decode(), self.document_for(
+ "<a><b>Find</b><magictag>the</magictag><c>lady!</c><d></d></a>"))
+
+ # Make sure all the relationships are hooked up correctly.
+ b_tag = soup.b
+ self.assertEqual(b_tag.next_sibling, magic_tag)
+ self.assertEqual(magic_tag.previous_sibling, b_tag)
+
+ find = b_tag.find(text="Find")
+ self.assertEqual(find.next_element, magic_tag)
+ self.assertEqual(magic_tag.previous_element, find)
+
+ c_tag = soup.c
+ self.assertEqual(magic_tag.next_sibling, c_tag)
+ self.assertEqual(c_tag.previous_sibling, magic_tag)
+
+ the = magic_tag.find(text="the")
+ self.assertEqual(the.parent, magic_tag)
+ self.assertEqual(the.next_element, c_tag)
+ self.assertEqual(c_tag.previous_element, the)
+
+ def test_append_child_thats_already_at_the_end(self):
+ data = "<a><b></b></a>"
+ soup = self.soup(data)
+ soup.a.append(soup.b)
+ self.assertEqual(data, soup.decode())
+
+ def test_extend(self):
+ data = "<a><b><c><d><e><f><g></g></f></e></d></c></b></a>"
+ soup = self.soup(data)
+ l = [soup.g, soup.f, soup.e, soup.d, soup.c, soup.b]
+ soup.a.extend(l)
+ self.assertEqual("<a><g></g><f></f><e></e><d></d><c></c><b></b></a>", soup.decode())
+
+ def test_move_tag_to_beginning_of_parent(self):
+ data = "<a><b></b><c></c><d></d></a>"
+ soup = self.soup(data)
+ soup.a.insert(0, soup.d)
+ self.assertEqual("<a><d></d><b></b><c></c></a>", soup.decode())
+
+ def test_insert_works_on_empty_element_tag(self):
+ # This is a little strange, since most HTML parsers don't allow
+ # markup like this to come through. But in general, we don't
+ # know what the parser would or wouldn't have allowed, so
+ # I'm letting this succeed for now.
+ soup = self.soup("<br/>")
+ soup.br.insert(1, "Contents")
+ self.assertEqual(str(soup.br), "<br>Contents</br>")
+
+ def test_insert_before(self):
+ soup = self.soup("<a>foo</a><b>bar</b>")
+ soup.b.insert_before("BAZ")
+ soup.a.insert_before("QUUX")
+ self.assertEqual(
+ soup.decode(), self.document_for("QUUX<a>foo</a>BAZ<b>bar</b>"))
+
+ soup.a.insert_before(soup.b)
+ self.assertEqual(
+ soup.decode(), self.document_for("QUUX<b>bar</b><a>foo</a>BAZ"))
+
+ # Can't insert an element before itself.
+ b = soup.b
+ self.assertRaises(ValueError, b.insert_before, b)
+
+ # Can't insert before if an element has no parent.
+ b.extract()
+ self.assertRaises(ValueError, b.insert_before, "nope")
+
+ # Can insert an identical element
+ soup = self.soup("<a>")
+ soup.a.insert_before(soup.new_tag("a"))
+
+ def test_insert_multiple_before(self):
+ soup = self.soup("<a>foo</a><b>bar</b>")
+ soup.b.insert_before("BAZ", " ", "QUUX")
+ soup.a.insert_before("QUUX", " ", "BAZ")
+ self.assertEqual(
+ soup.decode(), self.document_for("QUUX BAZ<a>foo</a>BAZ QUUX<b>bar</b>"))
+
+ soup.a.insert_before(soup.b, "FOO")
+ self.assertEqual(
+ soup.decode(), self.document_for("QUUX BAZ<b>bar</b>FOO<a>foo</a>BAZ QUUX"))
+
+ def test_insert_after(self):
+ soup = self.soup("<a>foo</a><b>bar</b>")
+ soup.b.insert_after("BAZ")
+ soup.a.insert_after("QUUX")
+ self.assertEqual(
+ soup.decode(), self.document_for("<a>foo</a>QUUX<b>bar</b>BAZ"))
+ soup.b.insert_after(soup.a)
+ self.assertEqual(
+ soup.decode(), self.document_for("QUUX<b>bar</b><a>foo</a>BAZ"))
+
+ # Can't insert an element after itself.
+ b = soup.b
+ self.assertRaises(ValueError, b.insert_after, b)
+
+ # Can't insert after if an element has no parent.
+ b.extract()
+ self.assertRaises(ValueError, b.insert_after, "nope")
+
+ # Can insert an identical element
+ soup = self.soup("<a>")
+        soup.a.insert_after(soup.new_tag("a"))
+
+ def test_insert_multiple_after(self):
+ soup = self.soup("<a>foo</a><b>bar</b>")
+ soup.b.insert_after("BAZ", " ", "QUUX")
+ soup.a.insert_after("QUUX", " ", "BAZ")
+ self.assertEqual(
+ soup.decode(), self.document_for("<a>foo</a>QUUX BAZ<b>bar</b>BAZ QUUX"))
+ soup.b.insert_after(soup.a, "FOO ")
+ self.assertEqual(
+ soup.decode(), self.document_for("QUUX BAZ<b>bar</b><a>foo</a>FOO BAZ QUUX"))
+
+ def test_insert_after_raises_exception_if_after_has_no_meaning(self):
+ soup = self.soup("")
+ tag = soup.new_tag("a")
+ string = soup.new_string("")
+ self.assertRaises(ValueError, string.insert_after, tag)
+ self.assertRaises(NotImplementedError, soup.insert_after, tag)
+ self.assertRaises(ValueError, tag.insert_after, tag)
+
+ def test_insert_before_raises_notimplementederror_if_before_has_no_meaning(self):
+ soup = self.soup("")
+ tag = soup.new_tag("a")
+ string = soup.new_string("")
+ self.assertRaises(ValueError, string.insert_before, tag)
+ self.assertRaises(NotImplementedError, soup.insert_before, tag)
+ self.assertRaises(ValueError, tag.insert_before, tag)
+
+ def test_replace_with(self):
+ soup = self.soup(
+ "<p>There's <b>no</b> business like <b>show</b> business</p>")
+ no, show = soup.find_all('b')
+ show.replace_with(no)
+ self.assertEqual(
+ soup.decode(),
+ self.document_for(
+ "<p>There's business like <b>no</b> business</p>"))
+
+ self.assertEqual(show.parent, None)
+ self.assertEqual(no.parent, soup.p)
+ self.assertEqual(no.next_element, "no")
+ self.assertEqual(no.next_sibling, " business")
+
+ def test_replace_first_child(self):
+ data = "<a><b></b><c></c></a>"
+ soup = self.soup(data)
+ soup.b.replace_with(soup.c)
+ self.assertEqual("<a><c></c></a>", soup.decode())
+
+ def test_replace_last_child(self):
+ data = "<a><b></b><c></c></a>"
+ soup = self.soup(data)
+ soup.c.replace_with(soup.b)
+ self.assertEqual("<a><b></b></a>", soup.decode())
+
+ def test_nested_tag_replace_with(self):
+ soup = self.soup(
+ """<a>We<b>reserve<c>the</c><d>right</d></b></a><e>to<f>refuse</f><g>service</g></e>""")
+
+ # Replace the entire <b> tag and its contents ("reserve the
+ # right") with the <f> tag ("refuse").
+ remove_tag = soup.b
+ move_tag = soup.f
+ remove_tag.replace_with(move_tag)
+
+ self.assertEqual(
+ soup.decode(), self.document_for(
+ "<a>We<f>refuse</f></a><e>to<g>service</g></e>"))
+
+ # The <b> tag is now an orphan.
+ self.assertEqual(remove_tag.parent, None)
+ self.assertEqual(remove_tag.find(text="right").next_element, None)
+ self.assertEqual(remove_tag.previous_element, None)
+ self.assertEqual(remove_tag.next_sibling, None)
+ self.assertEqual(remove_tag.previous_sibling, None)
+
+ # The <f> tag is now connected to the <a> tag.
+ self.assertEqual(move_tag.parent, soup.a)
+ self.assertEqual(move_tag.previous_element, "We")
+ self.assertEqual(move_tag.next_element.next_element, soup.e)
+ self.assertEqual(move_tag.next_sibling, None)
+
+ # The gap where the <f> tag used to be has been mended, and
+ # the word "to" is now connected to the <g> tag.
+ to_text = soup.find(text="to")
+ g_tag = soup.g
+ self.assertEqual(to_text.next_element, g_tag)
+ self.assertEqual(to_text.next_sibling, g_tag)
+ self.assertEqual(g_tag.previous_element, to_text)
+ self.assertEqual(g_tag.previous_sibling, to_text)
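+
+    # Illustrative sketch, not part of the upstream test suite: a
+    # practical consequence of the pointer mending above is that the
+    # replaced subtree stays usable and can be re-inserted elsewhere.
+    def _demo_replaced_tag_is_reusable(self):
+        soup = self.soup("<a><b>old</b></a><c></c>")
+        old = soup.b.replace_with(soup.new_string("new"))
+        soup.c.append(old)  # the orphaned <b> tag finds a new home
+        self.assertEqual(
+            soup.decode(),
+            self.document_for("<a>new</a><c><b>old</b></c>"))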
+
+ def test_unwrap(self):
+ tree = self.soup("""
+ <p>Unneeded <em>formatting</em> is unneeded</p>
+ """)
+ tree.em.unwrap()
+ self.assertEqual(tree.em, None)
+ self.assertEqual(tree.p.text, "Unneeded formatting is unneeded")
+
+ def test_wrap(self):
+ soup = self.soup("I wish I was bold.")
+ value = soup.string.wrap(soup.new_tag("b"))
+ self.assertEqual(value.decode(), "<b>I wish I was bold.</b>")
+ self.assertEqual(
+ soup.decode(), self.document_for("<b>I wish I was bold.</b>"))
+
+ def test_wrap_extracts_tag_from_elsewhere(self):
+ soup = self.soup("<b></b>I wish I was bold.")
+ soup.b.next_sibling.wrap(soup.b)
+ self.assertEqual(
+ soup.decode(), self.document_for("<b>I wish I was bold.</b>"))
+
+ def test_wrap_puts_new_contents_at_the_end(self):
+ soup = self.soup("<b>I like being bold.</b>I wish I was bold.")
+ soup.b.next_sibling.wrap(soup.b)
+ self.assertEqual(2, len(soup.b.contents))
+ self.assertEqual(
+ soup.decode(), self.document_for(
+ "<b>I like being bold.I wish I was bold.</b>"))
+
+ def test_extract(self):
+ soup = self.soup(
+ '<html><body>Some content. <div id="nav">Nav crap</div> More content.</body></html>')
+
+ self.assertEqual(len(soup.body.contents), 3)
+ extracted = soup.find(id="nav").extract()
+
+ self.assertEqual(
+ soup.decode(), "<html><body>Some content. More content.</body></html>")
+ self.assertEqual(extracted.decode(), '<div id="nav">Nav crap</div>')
+
+ # The extracted tag is now an orphan.
+ self.assertEqual(len(soup.body.contents), 2)
+ self.assertEqual(extracted.parent, None)
+ self.assertEqual(extracted.previous_element, None)
+ self.assertEqual(extracted.next_element.next_element, None)
+
+ # The gap where the extracted tag used to be has been mended.
+ content_1 = soup.find(text="Some content. ")
+ content_2 = soup.find(text=" More content.")
+ self.assertEqual(content_1.next_element, content_2)
+ self.assertEqual(content_1.next_sibling, content_2)
+ self.assertEqual(content_2.previous_element, content_1)
+ self.assertEqual(content_2.previous_sibling, content_1)
+
+ def test_extract_distinguishes_between_identical_strings(self):
+ soup = self.soup("<a>foo</a><b>bar</b>")
+ foo_1 = soup.a.string
+ bar_1 = soup.b.string
+ foo_2 = soup.new_string("foo")
+ bar_2 = soup.new_string("bar")
+ soup.a.append(foo_2)
+ soup.b.append(bar_2)
+
+ # Now there are two identical strings in the <a> tag, and two
+ # in the <b> tag. Let's remove the first "foo" and the second
+ # "bar".
+ foo_1.extract()
+ bar_2.extract()
+ self.assertEqual(foo_2, soup.a.string)
+ self.assertEqual(bar_2, soup.b.string)
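+
+    # Illustrative sketch, not part of the upstream test suite: extract()
+    # removes the specific node object it's called on, so what matters is
+    # holding a reference to the right NavigableString, not its text.
+    def _demo_extract_uses_identity(self):
+        soup = self.soup("<a>dup</a><b>dup</b>")
+        soup.b.string.extract()  # only the copy inside <b> is removed
+        self.assertEqual(
+            soup.decode(), self.document_for("<a>dup</a><b></b>"))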
+
+ def test_extract_multiples_of_same_tag(self):
+ soup = self.soup("""
+<html>
+<head>
+<script>foo</script>
+</head>
+<body>
+ <script>bar</script>
+ <a></a>
+</body>
+<script>baz</script>
+</html>""")
+ [soup.script.extract() for i in soup.find_all("script")]
+ self.assertEqual("<body>\n\n<a></a>\n</body>", str(soup.body))
+
+
+ def test_extract_works_when_element_is_surrounded_by_identical_strings(self):
+ soup = self.soup(
+ '<html>\n'
+ '<body>hi</body>\n'
+ '</html>')
+ soup.find('body').extract()
+ self.assertEqual(None, soup.find('body'))
+
+
+ def test_clear(self):
+ """Tag.clear()"""
+ soup = self.soup("<p><a>String <em>Italicized</em></a> and another</p>")
+ # clear using extract()
+ a = soup.a
+ soup.p.clear()
+ self.assertEqual(len(soup.p.contents), 0)
+ self.assertTrue(hasattr(a, "contents"))
+
+ # clear using decompose()
+ em = a.em
+ a.clear(decompose=True)
+ self.assertEqual(0, len(em.contents))
+
+ def test_string_set(self):
+ """Tag.string = 'string'"""
+ soup = self.soup("<a></a> <b><c></c></b>")
+ soup.a.string = "foo"
+ self.assertEqual(soup.a.contents, ["foo"])
+ soup.b.string = "bar"
+ self.assertEqual(soup.b.contents, ["bar"])
+
+ def test_string_set_does_not_affect_original_string(self):
+ soup = self.soup("<a><b>foo</b><c>bar</c>")
+ soup.b.string = soup.c.string
+ self.assertEqual(soup.a.encode(), b"<a><b>bar</b><c>bar</c></a>")
+
+ def test_set_string_preserves_class_of_string(self):
+ soup = self.soup("<a></a>")
+ cdata = CData("foo")
+ soup.a.string = cdata
+ self.assertTrue(isinstance(soup.a.string, CData))
+
+class TestElementObjects(SoupTest):
+ """Test various features of element objects."""
+
+ def test_len(self):
+ """The length of an element is its number of children."""
+ soup = self.soup("<top>1<b>2</b>3</top>")
+
+ # The BeautifulSoup object itself contains one element: the
+ # <top> tag.
+ self.assertEqual(len(soup.contents), 1)
+ self.assertEqual(len(soup), 1)
+
+ # The <top> tag contains three elements: the text node "1", the
+ # <b> tag, and the text node "3".
+ self.assertEqual(len(soup.top), 3)
+ self.assertEqual(len(soup.top.contents), 3)
+
+ def test_member_access_invokes_find(self):
+ """Accessing a Python member .foo invokes find('foo')"""
+ soup = self.soup('<b><i></i></b>')
+ self.assertEqual(soup.b, soup.find('b'))
+ self.assertEqual(soup.b.i, soup.find('b').find('i'))
+ self.assertEqual(soup.a, None)
+
+ def test_deprecated_member_access(self):
+ soup = self.soup('<b><i></i></b>')
+ with warnings.catch_warnings(record=True) as w:
+ tag = soup.bTag
+ self.assertEqual(soup.b, tag)
+ self.assertEqual(
+ '.bTag is deprecated, use .find("b") instead. If you really were looking for a tag called bTag, use .find("bTag")',
+ str(w[0].message))
+
+ def test_has_attr(self):
+ """has_attr() checks for the presence of an attribute.
+
+        Please note: has_attr() is different from the `in` operator.
+        has_attr() checks the tag's attributes; `in` checks the tag's
+        children.
+ """
+ soup = self.soup("<foo attr='bar'>")
+ self.assertTrue(soup.foo.has_attr('attr'))
+ self.assertFalse(soup.foo.has_attr('attr2'))
+
+
+ def test_attributes_come_out_in_alphabetical_order(self):
+ markup = '<b a="1" z="5" m="3" f="2" y="4"></b>'
+ self.assertSoupEquals(markup, '<b a="1" f="2" m="3" y="4" z="5"></b>')
+
+ def test_string(self):
+ # A tag that contains only a text node makes that node
+ # available as .string.
+ soup = self.soup("<b>foo</b>")
+ self.assertEqual(soup.b.string, 'foo')
+
+ def test_empty_tag_has_no_string(self):
+        # A tag with no children has no .string.
+ soup = self.soup("<b></b>")
+ self.assertEqual(soup.b.string, None)
+
+ def test_tag_with_multiple_children_has_no_string(self):
+        # A tag with multiple children has no .string.
+ soup = self.soup("<a>foo<b></b><b></b></b>")
+ self.assertEqual(soup.b.string, None)
+
+ soup = self.soup("<a>foo<b></b>bar</b>")
+ self.assertEqual(soup.b.string, None)
+
+        # Even if all the children are strings (inserted here by
+        # trickery after parsing), .string still won't work--though
+        # recognizing this case would be a good optimization.
+ soup = self.soup("<a>foo</b>")
+ soup.a.insert(1, "bar")
+ self.assertEqual(soup.a.string, None)
+
+ def test_tag_with_recursive_string_has_string(self):
+ # A tag with a single child which has a .string inherits that
+ # .string.
+ soup = self.soup("<a><b>foo</b></a>")
+ self.assertEqual(soup.a.string, "foo")
+ self.assertEqual(soup.string, "foo")
+
+ def test_lack_of_string(self):
+ """Only a tag containing a single text node has a .string."""
+ soup = self.soup("<b>f<i>e</i>o</b>")
+ self.assertFalse(soup.b.string)
+
+ soup = self.soup("<b></b>")
+ self.assertFalse(soup.b.string)
+
+ def test_all_text(self):
+ """Tag.text and Tag.get_text(sep=u"") -> all child text, concatenated"""
+ soup = self.soup("<a>a<b>r</b> <r> t </r></a>")
+ self.assertEqual(soup.a.text, "ar t ")
+ self.assertEqual(soup.a.get_text(strip=True), "art")
+ self.assertEqual(soup.a.get_text(","), "a,r, , t ")
+ self.assertEqual(soup.a.get_text(",", strip=True), "a,r,t")
+
+ def test_get_text_ignores_comments(self):
+ soup = self.soup("foo<!--IGNORE-->bar")
+ self.assertEqual(soup.get_text(), "foobar")
+
+ self.assertEqual(
+ soup.get_text(types=(NavigableString, Comment)), "fooIGNOREbar")
+ self.assertEqual(
+ soup.get_text(types=None), "fooIGNOREbar")
+
+ def test_all_strings_ignores_comments(self):
+ soup = self.soup("foo<!--IGNORE-->bar")
+ self.assertEqual(['foo', 'bar'], list(soup.strings))
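+
+    # Illustrative sketch, not part of the upstream test suite:
+    # .stripped_strings is the usual companion to .strings when
+    # whitespace-only nodes are unwanted.
+    def _demo_stripped_strings(self):
+        soup = self.soup("<p> foo </p><p>\n</p><p>bar</p>")
+        self.assertEqual(["foo", "bar"], list(soup.stripped_strings))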
+
+class TestCDataListAttributes(SoupTest):
+    """Testing cdata-list attributes like 'class'."""
+ def test_single_value_becomes_list(self):
+ soup = self.soup("<a class='foo'>")
+        self.assertEqual(["foo"], soup.a['class'])
+
+ def test_multiple_values_becomes_list(self):
+ soup = self.soup("<a class='foo bar'>")
+ self.assertEqual(["foo", "bar"], soup.a['class'])
+
+ def test_multiple_values_separated_by_weird_whitespace(self):
+ soup = self.soup("<a class='foo\tbar\nbaz'>")
+        self.assertEqual(["foo", "bar", "baz"], soup.a['class'])
+
+ def test_attributes_joined_into_string_on_output(self):
+ soup = self.soup("<a class='foo\tbar'>")
+ self.assertEqual(b'<a class="foo bar"></a>', soup.a.encode())
+
+ def test_get_attribute_list(self):
+ soup = self.soup("<a id='abc def'>")
+ self.assertEqual(['abc def'], soup.a.get_attribute_list('id'))
+
+ def test_accept_charset(self):
+ soup = self.soup('<form accept-charset="ISO-8859-1 UTF-8">')
+ self.assertEqual(['ISO-8859-1', 'UTF-8'], soup.form['accept-charset'])
+
+ def test_cdata_attribute_applying_only_to_one_tag(self):
+ data = '<a accept-charset="ISO-8859-1 UTF-8"></a>'
+ soup = self.soup(data)
+ # We saw in another test that accept-charset is a cdata-list
+ # attribute for the <form> tag. But it's not a cdata-list
+ # attribute for any other tag.
+ self.assertEqual('ISO-8859-1 UTF-8', soup.a['accept-charset'])
+
+ def test_string_has_immutable_name_property(self):
+ string = self.soup("s").string
+ self.assertEqual(None, string.name)
+ def t():
+ string.name = 'foo'
+ self.assertRaises(AttributeError, t)
+
+class TestPersistence(SoupTest):
+ "Testing features like pickle and deepcopy."
+
+ def setUp(self):
+ super(TestPersistence, self).setUp()
+ self.page = """<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0 Transitional//EN"
+"http://www.w3.org/TR/REC-html40/transitional.dtd">
+<html>
+<head>
+<meta http-equiv="Content-Type" content="text/html; charset=utf-8">
+<title>Beautiful Soup: We called him Tortoise because he taught us.</title>
+<link rev="made" href="mailto:[email protected]">
+<meta name="Description" content="Beautiful Soup: an HTML parser optimized for screen-scraping.">
+<meta name="generator" content="Markov Approximation 1.4 (module: leonardr)">
+<meta name="author" content="Leonard Richardson">
+</head>
+<body>
+<a href="foo">foo</a>
+<a href="foo"><b>bar</b></a>
+</body>
+</html>"""
+ self.tree = self.soup(self.page)
+
+ def test_pickle_and_unpickle_identity(self):
+ # Pickling a tree, then unpickling it, yields a tree identical
+ # to the original.
+ dumped = pickle.dumps(self.tree, 2)
+ loaded = pickle.loads(dumped)
+ self.assertEqual(loaded.__class__, BeautifulSoup)
+ self.assertEqual(loaded.decode(), self.tree.decode())
+
+ def test_deepcopy_identity(self):
+ # Making a deepcopy of a tree yields an identical tree.
+ copied = copy.deepcopy(self.tree)
+ self.assertEqual(copied.decode(), self.tree.decode())
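+
+    # Illustrative sketch, not part of the upstream test suite: the deep
+    # copy is fully independent, so edits to it leave the original tree
+    # untouched.
+    def _demo_deepcopy_is_independent(self):
+        copied = copy.deepcopy(self.tree)
+        copied.find("title").string.replace_with("Changed")
+        self.assertFalse("Changed" in self.tree.decode())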
+
+ def test_copy_preserves_encoding(self):
+ soup = BeautifulSoup(b'<p>&nbsp;</p>', 'html.parser')
+ encoding = soup.original_encoding
+ copy = soup.__copy__()
+ self.assertEqual("<p> </p>", str(copy))
+ self.assertEqual(encoding, copy.original_encoding)
+
+ def test_unicode_pickle(self):
+ # A tree containing Unicode characters can be pickled.
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ dumped = pickle.dumps(soup, pickle.HIGHEST_PROTOCOL)
+ loaded = pickle.loads(dumped)
+ self.assertEqual(loaded.decode(), soup.decode())
+
+ def test_copy_navigablestring_is_not_attached_to_tree(self):
+ html = "<b>Foo<a></a></b><b>Bar</b>"
+ soup = self.soup(html)
+ s1 = soup.find(string="Foo")
+ s2 = copy.copy(s1)
+ self.assertEqual(s1, s2)
+ self.assertEqual(None, s2.parent)
+ self.assertEqual(None, s2.next_element)
+ self.assertNotEqual(None, s1.next_sibling)
+ self.assertEqual(None, s2.next_sibling)
+ self.assertEqual(None, s2.previous_element)
+
+ def test_copy_navigablestring_subclass_has_same_type(self):
+ html = "<b><!--Foo--></b>"
+ soup = self.soup(html)
+ s1 = soup.string
+ s2 = copy.copy(s1)
+ self.assertEqual(s1, s2)
+ self.assertTrue(isinstance(s2, Comment))
+
+ def test_copy_entire_soup(self):
+ html = "<div><b>Foo<a></a></b><b>Bar</b></div>end"
+ soup = self.soup(html)
+ soup_copy = copy.copy(soup)
+ self.assertEqual(soup, soup_copy)
+
+ def test_copy_tag_copies_contents(self):
+ html = "<div><b>Foo<a></a></b><b>Bar</b></div>end"
+ soup = self.soup(html)
+ div = soup.div
+ div_copy = copy.copy(div)
+
+ # The two tags look the same, and evaluate to equal.
+ self.assertEqual(str(div), str(div_copy))
+ self.assertEqual(div, div_copy)
+
+ # But they're not the same object.
+ self.assertFalse(div is div_copy)
+
+ # And they don't have the same relation to the parse tree. The
+ # copy is not associated with a parse tree at all.
+ self.assertEqual(None, div_copy.parent)
+ self.assertEqual(None, div_copy.previous_element)
+ self.assertEqual(None, div_copy.find(string='Bar').next_element)
+ self.assertNotEqual(None, div.find(string='Bar').next_element)
+
+class TestSubstitutions(SoupTest):
+
+ def test_default_formatter_is_minimal(self):
+ markup = "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
+ soup = self.soup(markup)
+ decoded = soup.decode(formatter="minimal")
+ # The < is converted back into &lt; but the e-with-acute is left alone.
+ self.assertEqual(
+ decoded,
+ self.document_for(
+ "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"))
+
+ def test_formatter_html(self):
+ markup = "<br><b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
+ soup = self.soup(markup)
+ decoded = soup.decode(formatter="html")
+ self.assertEqual(
+ decoded,
+ self.document_for("<br/><b>&lt;&lt;Sacr&eacute; bleu!&gt;&gt;</b>"))
+
+ def test_formatter_html5(self):
+ markup = "<br><b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
+ soup = self.soup(markup)
+ decoded = soup.decode(formatter="html5")
+ self.assertEqual(
+ decoded,
+ self.document_for("<br><b>&lt;&lt;Sacr&eacute; bleu!&gt;&gt;</b>"))
+
+ def test_formatter_minimal(self):
+ markup = "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
+ soup = self.soup(markup)
+ decoded = soup.decode(formatter="minimal")
+ # The < is converted back into &lt; but the e-with-acute is left alone.
+ self.assertEqual(
+ decoded,
+ self.document_for(
+ "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"))
+
+ def test_formatter_null(self):
+ markup = "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
+ soup = self.soup(markup)
+ decoded = soup.decode(formatter=None)
+ # Neither the angle brackets nor the e-with-acute are converted.
+ # This is not valid HTML, but it's what the user wanted.
+ self.assertEqual(decoded,
+ self.document_for("<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>"))
+
+ def test_formatter_custom(self):
+ markup = "<b>&lt;foo&gt;</b><b>bar</b><br/>"
+ soup = self.soup(markup)
+ decoded = soup.decode(formatter = lambda x: x.upper())
+ # Instead of normal entity conversion code, the custom
+ # callable is called on every string.
+ self.assertEqual(
+ decoded,
+ self.document_for("<b><FOO></b><b>BAR</b><br/>"))
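+
+    # Illustrative sketch, not part of the upstream test suite: formatters
+    # only affect serialization; the parsed tree itself is never mutated.
+    def _demo_formatter_does_not_mutate_tree(self):
+        soup = self.soup("<b>bar</b>")
+        soup.decode(formatter=lambda x: x.upper())  # output-only transform
+        self.assertEqual("bar", soup.b.string)      # tree is unchanged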
+
+ def test_formatter_is_run_on_attribute_values(self):
+ markup = '<a href="http://a.com?a=b&c=é">e</a>'
+ soup = self.soup(markup)
+ a = soup.a
+
+ expect_minimal = '<a href="http://a.com?a=b&amp;c=é">e</a>'
+
+ self.assertEqual(expect_minimal, a.decode())
+ self.assertEqual(expect_minimal, a.decode(formatter="minimal"))
+
+ expect_html = '<a href="http://a.com?a=b&amp;c=&eacute;">e</a>'
+ self.assertEqual(expect_html, a.decode(formatter="html"))
+
+ self.assertEqual(markup, a.decode(formatter=None))
+ expect_upper = '<a href="HTTP://A.COM?A=B&C=É">E</a>'
+ self.assertEqual(expect_upper, a.decode(formatter=lambda x: x.upper()))
+
+ def test_formatter_skips_script_tag_for_html_documents(self):
+ doc = """
+ <script type="text/javascript">
+ console.log("< < hey > > ");
+ </script>
+"""
+ encoded = BeautifulSoup(doc, 'html.parser').encode()
+ self.assertTrue(b"< < hey > >" in encoded)
+
+ def test_formatter_skips_style_tag_for_html_documents(self):
+ doc = """
+ <style type="text/css">
+ console.log("< < hey > > ");
+ </style>
+"""
+ encoded = BeautifulSoup(doc, 'html.parser').encode()
+ self.assertTrue(b"< < hey > >" in encoded)
+
+ def test_prettify_leaves_preformatted_text_alone(self):
+ soup = self.soup("<div> foo <pre> \tbar\n \n </pre> baz <textarea> eee\nfff\t</textarea></div>")
+ # Everything outside the <pre> tag is reformatted, but everything
+ # inside is left alone.
+ self.assertEqual(
+ '<div>\n foo\n <pre> \tbar\n \n </pre>\n baz\n <textarea> eee\nfff\t</textarea>\n</div>',
+ soup.div.prettify())
+
+ def test_prettify_accepts_formatter_function(self):
+ soup = BeautifulSoup("<html><body>foo</body></html>", 'html.parser')
+ pretty = soup.prettify(formatter = lambda x: x.upper())
+ self.assertTrue("FOO" in pretty)
+
+ def test_prettify_outputs_unicode_by_default(self):
+ soup = self.soup("<a></a>")
+ self.assertEqual(str, type(soup.prettify()))
+
+ def test_prettify_can_encode_data(self):
+ soup = self.soup("<a></a>")
+ self.assertEqual(bytes, type(soup.prettify("utf-8")))
+
+ def test_html_entity_substitution_off_by_default(self):
+ markup = "<b>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</b>"
+ soup = self.soup(markup)
+ encoded = soup.b.encode("utf-8")
+ self.assertEqual(encoded, markup.encode('utf-8'))
+
+ def test_encoding_substitution(self):
+ # Here's the <meta> tag saying that a document is
+ # encoded in Shift-JIS.
+ meta_tag = ('<meta content="text/html; charset=x-sjis" '
+ 'http-equiv="Content-type"/>')
+ soup = self.soup(meta_tag)
+
+        # Parse the document, and the charset appears unchanged.
+ self.assertEqual(soup.meta['content'], 'text/html; charset=x-sjis')
+
+ # Encode the document into some encoding, and the encoding is
+ # substituted into the meta tag.
+ utf_8 = soup.encode("utf-8")
+ self.assertTrue(b"charset=utf-8" in utf_8)
+
+ euc_jp = soup.encode("euc_jp")
+ self.assertTrue(b"charset=euc_jp" in euc_jp)
+
+ shift_jis = soup.encode("shift-jis")
+ self.assertTrue(b"charset=shift-jis" in shift_jis)
+
+ utf_16_u = soup.encode("utf-16").decode("utf-16")
+ self.assertTrue("charset=utf-16" in utf_16_u)
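+
+    # Illustrative sketch, not part of the upstream test suite: the same
+    # substitution covers the HTML5-style <meta charset="..."> form, not
+    # just the http-equiv form used above.
+    def _demo_html5_meta_charset_substitution(self):
+        soup = self.soup('<meta charset="x-sjis"/>')
+        self.assertTrue(b'charset="utf-8"' in soup.encode("utf-8"))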
+
+ def test_encoding_substitution_doesnt_happen_if_tag_is_strained(self):
+ markup = ('<head><meta content="text/html; charset=x-sjis" '
+ 'http-equiv="Content-type"/></head><pre>foo</pre>')
+
+ # Beautiful Soup used to try to rewrite the meta tag even if the
+ # meta tag got filtered out by the strainer. This test makes
+ # sure that doesn't happen.
+ strainer = SoupStrainer('pre')
+ soup = self.soup(markup, parse_only=strainer)
+ self.assertEqual(soup.contents[0].name, 'pre')
+
+class TestEncoding(SoupTest):
+ """Test the ability to encode objects into strings."""
+
+ def test_unicode_string_can_be_encoded(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ self.assertEqual(soup.b.string.encode("utf-8"),
+ "\N{SNOWMAN}".encode("utf-8"))
+
+ def test_tag_containing_unicode_string_can_be_encoded(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ self.assertEqual(
+ soup.b.encode("utf-8"), html.encode("utf-8"))
+
+ def test_encoding_substitutes_unrecognized_characters_by_default(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ self.assertEqual(soup.b.encode("ascii"), b"<b>&#9731;</b>")
+
+ def test_encoding_can_be_made_strict(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ self.assertRaises(
+ UnicodeEncodeError, soup.encode, "ascii", errors="strict")
+
+ def test_decode_contents(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ self.assertEqual("\N{SNOWMAN}", soup.b.decode_contents())
+
+ def test_encode_contents(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ self.assertEqual(
+ "\N{SNOWMAN}".encode("utf8"), soup.b.encode_contents(
+ encoding="utf8"))
+
+ def test_deprecated_renderContents(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ self.assertEqual(
+ "\N{SNOWMAN}".encode("utf8"), soup.b.renderContents())
+
+ def test_repr(self):
+ html = "<b>\N{SNOWMAN}</b>"
+ soup = self.soup(html)
+ if PY3K:
+ self.assertEqual(html, repr(soup))
+ else:
+ self.assertEqual(b'<b>\\u2603</b>', repr(soup))
+
+class TestFormatter(SoupTest):
+
+ def test_sort_attributes(self):
+ # Test the ability to override Formatter.attributes() to,
+ # e.g., disable the normal sorting of attributes.
+ class UnsortedFormatter(Formatter):
+ def attributes(self, tag):
+ self.called_with = tag
+ for k, v in sorted(tag.attrs.items()):
+ if k == 'ignore':
+ continue
+ yield k, v
+
+ soup = self.soup('<p cval="1" aval="2" ignore="ignored"></p>')
+ formatter = UnsortedFormatter()
+ decoded = soup.decode(formatter=formatter)
+
+ # attributes() was called on the <p> tag. It filtered out one
+ # attribute and sorted the other two.
+ self.assertEqual(formatter.called_with, soup.p)
+ self.assertEqual('<p aval="2" cval="1"></p>', decoded)
+
+
+class TestNavigableStringSubclasses(SoupTest):
+
+ def test_cdata(self):
+ # None of the current builders turn CDATA sections into CData
+ # objects, but you can create them manually.
+ soup = self.soup("")
+ cdata = CData("foo")
+ soup.insert(1, cdata)
+ self.assertEqual(str(soup), "<![CDATA[foo]]>")
+ self.assertEqual(soup.find(text="foo"), "foo")
+ self.assertEqual(soup.contents[0], "foo")
+
+ def test_cdata_is_never_formatted(self):
+ """Text inside a CData object is passed into the formatter.
+
+ But the return value is ignored.
+ """
+
+ self.count = 0
+ def increment(*args):
+ self.count += 1
+ return "BITTER FAILURE"
+
+ soup = self.soup("")
+ cdata = CData("<><><>")
+ soup.insert(1, cdata)
+ self.assertEqual(
+ b"<![CDATA[<><><>]]>", soup.encode(formatter=increment))
+ self.assertEqual(1, self.count)
+
+ def test_doctype_ends_in_newline(self):
+ # Unlike other NavigableString subclasses, a DOCTYPE always ends
+ # in a newline.
+ doctype = Doctype("foo")
+ soup = self.soup("")
+ soup.insert(1, doctype)
+ self.assertEqual(soup.encode(), b"<!DOCTYPE foo>\n")
+
+ def test_declaration(self):
+ d = Declaration("foo")
+ self.assertEqual("<?foo?>", d.output_ready())
+
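+# A minimal sketch of building these subclasses by hand, as the tests above
+# do (assumes CData and Doctype are importable from bs4, as in this module):
+#
+#   soup = BeautifulSoup('', 'html.parser')
+#   soup.insert(1, Doctype('html'))
+#   soup.insert(2, CData('x < y'))
+#   print(soup)  # <!DOCTYPE html>\n<![CDATA[x < y]]>
+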
+class TestSoupSelector(TreeTest):
+
+ HTML = """
+<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01//EN"
+"http://www.w3.org/TR/html4/strict.dtd">
+<html>
+<head>
+<title>The title</title>
+<link rel="stylesheet" href="blah.css" type="text/css" id="l1">
+</head>
+<body>
+<custom-dashed-tag class="dashed" id="dash1">Hello there.</custom-dashed-tag>
+<div id="main" class="fancy">
+<div id="inner">
+<h1 id="header1">An H1</h1>
+<p>Some text</p>
+<p class="onep" id="p1">Some more text</p>
+<h2 id="header2">An H2</h2>
+<p class="class1 class2 class3" id="pmulti">Another</p>
+<a href="http://bob.example.org/" rel="friend met" id="bob">Bob</a>
+<h2 id="header3">Another H2</h2>
+<a id="me" href="http://simonwillison.net/" rel="me">me</a>
+<span class="s1">
+<a href="#" id="s1a1">span1a1</a>
+<a href="#" id="s1a2">span1a2 <span id="s1a2s1">test</span></a>
+<span class="span2">
+<a href="#" id="s2a1">span2a1</a>
+</span>
+<span class="span3"></span>
+<custom-dashed-tag class="dashed" id="dash2"/>
+<div data-tag="dashedvalue" id="data1"/>
+</span>
+</div>
+<x id="xid">
+<z id="zida"/>
+<z id="zidab"/>
+<z id="zidac"/>
+</x>
+<y id="yid">
+<z id="zidb"/>
+</y>
+<p lang="en" id="lang-en">English</p>
+<p lang="en-gb" id="lang-en-gb">English UK</p>
+<p lang="en-us" id="lang-en-us">English US</p>
+<p lang="fr" id="lang-fr">French</p>
+</div>
+
+<div id="footer">
+</div>
+"""
+
+ def setUp(self):
+ self.soup = BeautifulSoup(self.HTML, 'html.parser')
+
+ def assertSelects(self, selector, expected_ids, **kwargs):
+ el_ids = [el['id'] for el in self.soup.select(selector, **kwargs)]
+ el_ids.sort()
+ expected_ids.sort()
+ self.assertEqual(expected_ids, el_ids,
+ "Selector %s, expected [%s], got [%s]" % (
+ selector, ', '.join(expected_ids), ', '.join(el_ids)
+ )
+ )
+
+ assertSelect = assertSelects
+
+ def assertSelectMultiple(self, *tests):
+ for selector, expected_ids in tests:
+ self.assertSelect(selector, expected_ids)
+
+ def test_one_tag_one(self):
+ els = self.soup.select('title')
+ self.assertEqual(len(els), 1)
+ self.assertEqual(els[0].name, 'title')
+ self.assertEqual(els[0].contents, ['The title'])
+
+ def test_one_tag_many(self):
+ els = self.soup.select('div')
+ self.assertEqual(len(els), 4)
+ for div in els:
+ self.assertEqual(div.name, 'div')
+
+ el = self.soup.select_one('div')
+ self.assertEqual('main', el['id'])
+
+ def test_select_one_returns_none_if_no_match(self):
+ match = self.soup.select_one('nonexistenttag')
+ self.assertEqual(None, match)
+
+
+ def test_tag_in_tag_one(self):
+ els = self.soup.select('div div')
+ self.assertSelects('div div', ['inner', 'data1'])
+
+ def test_tag_in_tag_many(self):
+ for selector in ('html div', 'html body div', 'body div'):
+ self.assertSelects(selector, ['data1', 'main', 'inner', 'footer'])
+
+
+ def test_limit(self):
+ self.assertSelects('html div', ['main'], limit=1)
+ self.assertSelects('html body div', ['inner', 'main'], limit=2)
+ self.assertSelects('body div', ['data1', 'main', 'inner', 'footer'],
+ limit=10)
+
+ def test_tag_no_match(self):
+ self.assertEqual(len(self.soup.select('del')), 0)
+
+ def test_invalid_tag(self):
+ self.assertRaises(SyntaxError, self.soup.select, 'tag%t')
+
+ def test_select_dashed_tag_ids(self):
+ self.assertSelects('custom-dashed-tag', ['dash1', 'dash2'])
+
+ def test_select_dashed_by_id(self):
+ dashed = self.soup.select('custom-dashed-tag[id=\"dash2\"]')
+ self.assertEqual(dashed[0].name, 'custom-dashed-tag')
+ self.assertEqual(dashed[0]['id'], 'dash2')
+
+ def test_dashed_tag_text(self):
+ self.assertEqual(self.soup.select('body > custom-dashed-tag')[0].text, 'Hello there.')
+
+ def test_select_dashed_matches_find_all(self):
+ self.assertEqual(self.soup.select('custom-dashed-tag'), self.soup.find_all('custom-dashed-tag'))
+
+ def test_header_tags(self):
+ self.assertSelectMultiple(
+ ('h1', ['header1']),
+ ('h2', ['header2', 'header3']),
+ )
+
+ def test_class_one(self):
+ for selector in ('.onep', 'p.onep', 'html p.onep'):
+ els = self.soup.select(selector)
+ self.assertEqual(len(els), 1)
+ self.assertEqual(els[0].name, 'p')
+ self.assertEqual(els[0]['class'], ['onep'])
+
+ def test_class_mismatched_tag(self):
+ els = self.soup.select('div.onep')
+ self.assertEqual(len(els), 0)
+
+ def test_one_id(self):
+ for selector in ('div#inner', '#inner', 'div div#inner'):
+ self.assertSelects(selector, ['inner'])
+
+ def test_bad_id(self):
+ els = self.soup.select('#doesnotexist')
+ self.assertEqual(len(els), 0)
+
+ def test_items_in_id(self):
+ els = self.soup.select('div#inner p')
+ self.assertEqual(len(els), 3)
+ for el in els:
+ self.assertEqual(el.name, 'p')
+ self.assertEqual(els[1]['class'], ['onep'])
+ self.assertFalse(els[0].has_attr('class'))
+
+ def test_a_bunch_of_emptys(self):
+ for selector in ('div#main del', 'div#main div.oops', 'div div#main'):
+ self.assertEqual(len(self.soup.select(selector)), 0)
+
+ def test_multi_class_support(self):
+ for selector in ('.class1', 'p.class1', '.class2', 'p.class2',
+ '.class3', 'p.class3', 'html p.class2', 'div#inner .class2'):
+ self.assertSelects(selector, ['pmulti'])
+
+ def test_multi_class_selection(self):
+ for selector in ('.class1.class3', '.class3.class2',
+ '.class1.class2.class3'):
+ self.assertSelects(selector, ['pmulti'])
+
+ def test_child_selector(self):
+ self.assertSelects('.s1 > a', ['s1a1', 's1a2'])
+ self.assertSelects('.s1 > a span', ['s1a2s1'])
+
+ def test_child_selector_id(self):
+ self.assertSelects('.s1 > a#s1a2 span', ['s1a2s1'])
+
+ def test_attribute_equals(self):
+ self.assertSelectMultiple(
+ ('p[class="onep"]', ['p1']),
+ ('p[id="p1"]', ['p1']),
+ ('[class="onep"]', ['p1']),
+ ('[id="p1"]', ['p1']),
+ ('link[rel="stylesheet"]', ['l1']),
+ ('link[type="text/css"]', ['l1']),
+ ('link[href="blah.css"]', ['l1']),
+ ('link[href="no-blah.css"]', []),
+ ('[rel="stylesheet"]', ['l1']),
+ ('[type="text/css"]', ['l1']),
+ ('[href="blah.css"]', ['l1']),
+ ('[href="no-blah.css"]', []),
+ ('p[href="no-blah.css"]', []),
+ ('[href="no-blah.css"]', []),
+ )
+
+ def test_attribute_tilde(self):
+ self.assertSelectMultiple(
+ ('p[class~="class1"]', ['pmulti']),
+ ('p[class~="class2"]', ['pmulti']),
+ ('p[class~="class3"]', ['pmulti']),
+ ('[class~="class1"]', ['pmulti']),
+ ('[class~="class2"]', ['pmulti']),
+ ('[class~="class3"]', ['pmulti']),
+ ('a[rel~="friend"]', ['bob']),
+ ('a[rel~="met"]', ['bob']),
+ ('[rel~="friend"]', ['bob']),
+ ('[rel~="met"]', ['bob']),
+ )
+
+ def test_attribute_startswith(self):
+ self.assertSelectMultiple(
+ ('[rel^="style"]', ['l1']),
+ ('link[rel^="style"]', ['l1']),
+ ('notlink[rel^="notstyle"]', []),
+ ('[rel^="notstyle"]', []),
+ ('link[rel^="notstyle"]', []),
+ ('link[href^="bla"]', ['l1']),
+ ('a[href^="http://"]', ['bob', 'me']),
+ ('[href^="http://"]', ['bob', 'me']),
+ ('[id^="p"]', ['pmulti', 'p1']),
+ ('[id^="m"]', ['me', 'main']),
+ ('div[id^="m"]', ['main']),
+ ('a[id^="m"]', ['me']),
+ ('div[data-tag^="dashed"]', ['data1'])
+ )
+
+ def test_attribute_endswith(self):
+ self.assertSelectMultiple(
+ ('[href$=".css"]', ['l1']),
+ ('link[href$=".css"]', ['l1']),
+ ('link[id$="1"]', ['l1']),
+ ('[id$="1"]', ['data1', 'l1', 'p1', 'header1', 's1a1', 's2a1', 's1a2s1', 'dash1']),
+ ('div[id$="1"]', ['data1']),
+ ('[id$="noending"]', []),
+ )
+
+ def test_attribute_contains(self):
+ self.assertSelectMultiple(
+ # From test_attribute_startswith
+ ('[rel*="style"]', ['l1']),
+ ('link[rel*="style"]', ['l1']),
+ ('notlink[rel*="notstyle"]', []),
+ ('[rel*="notstyle"]', []),
+ ('link[rel*="notstyle"]', []),
+ ('link[href*="bla"]', ['l1']),
+ ('[href*="http://"]', ['bob', 'me']),
+ ('[id*="p"]', ['pmulti', 'p1']),
+ ('div[id*="m"]', ['main']),
+ ('a[id*="m"]', ['me']),
+ # From test_attribute_endswith
+ ('[href*=".css"]', ['l1']),
+ ('link[href*=".css"]', ['l1']),
+ ('link[id*="1"]', ['l1']),
+ ('[id*="1"]', ['data1', 'l1', 'p1', 'header1', 's1a1', 's1a2', 's2a1', 's1a2s1', 'dash1']),
+ ('div[id*="1"]', ['data1']),
+ ('[id*="noending"]', []),
+ # New for this test
+ ('[href*="."]', ['bob', 'me', 'l1']),
+ ('a[href*="."]', ['bob', 'me']),
+ ('link[href*="."]', ['l1']),
+ ('div[id*="n"]', ['main', 'inner']),
+ ('div[id*="nn"]', ['inner']),
+ ('div[data-tag*="edval"]', ['data1'])
+ )
+
+ def test_attribute_exact_or_hypen(self):
+ self.assertSelectMultiple(
+ ('p[lang|="en"]', ['lang-en', 'lang-en-gb', 'lang-en-us']),
+ ('[lang|="en"]', ['lang-en', 'lang-en-gb', 'lang-en-us']),
+ ('p[lang|="fr"]', ['lang-fr']),
+ ('p[lang|="gb"]', []),
+ )
+
+ def test_attribute_exists(self):
+ self.assertSelectMultiple(
+ ('[rel]', ['l1', 'bob', 'me']),
+ ('link[rel]', ['l1']),
+ ('a[rel]', ['bob', 'me']),
+ ('[lang]', ['lang-en', 'lang-en-gb', 'lang-en-us', 'lang-fr']),
+ ('p[class]', ['p1', 'pmulti']),
+ ('[blah]', []),
+ ('p[blah]', []),
+ ('div[data-tag]', ['data1'])
+ )
+
+ def test_quoted_space_in_selector_name(self):
+ html = """<div style="display: wrong">nope</div>
+ <div style="display: right">yes</div>
+ """
+ soup = BeautifulSoup(html, 'html.parser')
+ [chosen] = soup.select('div[style="display: right"]')
+ self.assertEqual("yes", chosen.string)
+
+ def test_unsupported_pseudoclass(self):
+ self.assertRaises(
+ NotImplementedError, self.soup.select, "a:no-such-pseudoclass")
+
+ self.assertRaises(
+ SyntaxError, self.soup.select, "a:nth-of-type(a)")
+
+ def test_nth_of_type(self):
+ # Try to select first paragraph
+ els = self.soup.select('div#inner p:nth-of-type(1)')
+ self.assertEqual(len(els), 1)
+ self.assertEqual(els[0].string, 'Some text')
+
+ # Try to select third paragraph
+ els = self.soup.select('div#inner p:nth-of-type(3)')
+ self.assertEqual(len(els), 1)
+ self.assertEqual(els[0].string, 'Another')
+
+ # Try to select (non-existent!) fourth paragraph
+ els = self.soup.select('div#inner p:nth-of-type(4)')
+ self.assertEqual(len(els), 0)
+
+ # Zero will select no tags.
+ els = self.soup.select('div p:nth-of-type(0)')
+ self.assertEqual(len(els), 0)
+
+ def test_nth_of_type_direct_descendant(self):
+ els = self.soup.select('div#inner > p:nth-of-type(1)')
+ self.assertEqual(len(els), 1)
+ self.assertEqual(els[0].string, 'Some text')
+
+ def test_id_child_selector_nth_of_type(self):
+ self.assertSelects('#inner > p:nth-of-type(2)', ['p1'])
+
+ def test_select_on_element(self):
+ # Other tests operate on the tree; this operates on an element
+ # within the tree.
+ inner = self.soup.find("div", id="main")
+ selected = inner.select("div")
+ # The <div id="inner"> tag was selected. The <div id="footer">
+ # tag was not.
+ self.assertSelectsIDs(selected, ['inner', 'data1'])
+
+ def test_overspecified_child_id(self):
+ self.assertSelects(".fancy #inner", ['inner'])
+ self.assertSelects(".normal #inner", [])
+
+ def test_adjacent_sibling_selector(self):
+ self.assertSelects('#p1 + h2', ['header2'])
+ self.assertSelects('#p1 + h2 + p', ['pmulti'])
+ self.assertSelects('#p1 + #header2 + .class1', ['pmulti'])
+ self.assertEqual([], self.soup.select('#p1 + p'))
+
+ def test_general_sibling_selector(self):
+ self.assertSelects('#p1 ~ h2', ['header2', 'header3'])
+ self.assertSelects('#p1 ~ #header2', ['header2'])
+ self.assertSelects('#p1 ~ h2 + a', ['me'])
+ self.assertSelects('#p1 ~ h2 + [rel="me"]', ['me'])
+ self.assertEqual([], self.soup.select('#inner ~ h2'))
+
+ def test_dangling_combinator(self):
+ self.assertRaises(SyntaxError, self.soup.select, 'h1 >')
+
+ def test_sibling_combinator_wont_select_same_tag_twice(self):
+ self.assertSelects('p[lang] ~ p', ['lang-en-gb', 'lang-en-us', 'lang-fr'])
+
+ # Test the selector grouping operator (the comma)
+ def test_multiple_select(self):
+ self.assertSelects('x, y', ['xid', 'yid'])
+
+ def test_multiple_select_with_no_space(self):
+ self.assertSelects('x,y', ['xid', 'yid'])
+
+ def test_multiple_select_with_more_space(self):
+ self.assertSelects('x, y', ['xid', 'yid'])
+
+ def test_multiple_select_duplicated(self):
+ self.assertSelects('x, x', ['xid'])
+
+ def test_multiple_select_sibling(self):
+ self.assertSelects('x, y ~ p[lang=fr]', ['xid', 'lang-fr'])
+
+ def test_multiple_select_tag_and_direct_descendant(self):
+ self.assertSelects('x, y > z', ['xid', 'zidb'])
+
+ def test_multiple_select_direct_descendant_and_tags(self):
+ self.assertSelects('div > x, y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac'])
+
+ def test_multiple_select_indirect_descendant(self):
+ self.assertSelects('div x,y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac'])
+
+ def test_invalid_multiple_select(self):
+ self.assertRaises(SyntaxError, self.soup.select, ',x, y')
+ self.assertRaises(SyntaxError, self.soup.select, 'x,,y')
+
+ def test_multiple_select_attrs(self):
+ self.assertSelects('p[lang=en], p[lang=en-gb]', ['lang-en', 'lang-en-gb'])
+
+ def test_multiple_select_ids(self):
+ self.assertSelects('x, y > z[id=zida], z[id=zidab], z[id=zidb]', ['xid', 'zidb', 'zidab'])
+
+ def test_multiple_select_nested(self):
+ self.assertSelects('body > div > x, y > z', ['xid', 'zidb'])
+
+ def test_select_duplicate_elements(self):
+ # When markup contains duplicate elements, a multiple select
+ # will find all of them.
+ markup = '<div class="c1"/><div class="c2"/><div class="c1"/>'
+ soup = BeautifulSoup(markup, 'html.parser')
+ selected = soup.select(".c1, .c2")
+ self.assertEqual(3, len(selected))
+
+ # Verify that find_all finds the same elements, though because
+ # of an implementation detail it finds them in a different
+ # order.
+ for element in soup.find_all(class_=['c1', 'c2']):
+ assert element in selected
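
The selector tests above reduce to a small API surface. A minimal sketch,
with markup and ids mirroring the fixtures used by the tests (nothing else
is assumed):

    from bs4 import BeautifulSoup

    soup = BeautifulSoup(
        '<div id="main"><p class="onep" id="p1">Some more text</p></div>',
        'html.parser')

    assert soup.select_one('p.onep')['id'] == 'p1'     # first match, or None
    assert [el['id'] for el in soup.select('div p')] == ['p1']
    assert soup.select('p', limit=1)[0]['id'] == 'p1'  # limit caps results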
diff --git a/libs/engineio/__init__.py b/libs/engineio/__init__.py
new file mode 100644
index 000000000..f2c5b774c
--- /dev/null
+++ b/libs/engineio/__init__.py
@@ -0,0 +1,25 @@
+import sys
+
+from .client import Client
+from .middleware import WSGIApp, Middleware
+from .server import Server
+if sys.version_info >= (3, 5): # pragma: no cover
+ from .asyncio_server import AsyncServer
+ from .asyncio_client import AsyncClient
+ from .async_drivers.asgi import ASGIApp
+ try:
+ from .async_drivers.tornado import get_tornado_handler
+ except ImportError:
+ get_tornado_handler = None
+else: # pragma: no cover
+ AsyncServer = None
+ AsyncClient = None
+ get_tornado_handler = None
+ ASGIApp = None
+
+__version__ = '3.11.2'
+
+__all__ = ['__version__', 'Server', 'WSGIApp', 'Middleware', 'Client']
+if AsyncServer is not None: # pragma: no cover
+ __all__ += ['AsyncServer', 'ASGIApp', 'get_tornado_handler',
+ 'AsyncClient']
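
A minimal sketch of how the conditional exports above behave at import time
(assumes only that the package is importable):

    import engineio

    server = engineio.Server()            # always available
    if engineio.AsyncServer is not None:  # bound to None before Python 3.5
        async_server = engineio.AsyncServer()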
diff --git a/libs/engineio/async_drivers/__init__.py b/libs/engineio/async_drivers/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/libs/engineio/async_drivers/__init__.py
diff --git a/libs/engineio/async_drivers/aiohttp.py b/libs/engineio/async_drivers/aiohttp.py
new file mode 100644
index 000000000..ad6987649
--- /dev/null
+++ b/libs/engineio/async_drivers/aiohttp.py
@@ -0,0 +1,128 @@
+import asyncio
+import sys
+from urllib.parse import urlsplit
+
+from aiohttp.web import Response, WebSocketResponse
+import six
+
+
+def create_route(app, engineio_server, engineio_endpoint):
+ """This function sets up the engine.io endpoint as a route for the
+ application.
+
+ Note that both GET and POST requests must be hooked up on the engine.io
+ endpoint.
+ """
+ app.router.add_get(engineio_endpoint, engineio_server.handle_request)
+ app.router.add_post(engineio_endpoint, engineio_server.handle_request)
+ app.router.add_route('OPTIONS', engineio_endpoint,
+ engineio_server.handle_request)
+
+
+def translate_request(request):
+ """This function takes the arguments passed to the request handler and
+ uses them to generate a WSGI compatible environ dictionary.
+ """
+ message = request._message
+ payload = request._payload
+
+ uri_parts = urlsplit(message.path)
+ environ = {
+ 'wsgi.input': payload,
+ 'wsgi.errors': sys.stderr,
+ 'wsgi.version': (1, 0),
+ 'wsgi.async': True,
+ 'wsgi.multithread': False,
+ 'wsgi.multiprocess': False,
+ 'wsgi.run_once': False,
+ 'SERVER_SOFTWARE': 'aiohttp',
+ 'REQUEST_METHOD': message.method,
+ 'QUERY_STRING': uri_parts.query or '',
+ 'RAW_URI': message.path,
+ 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version,
+ 'REMOTE_ADDR': '127.0.0.1',
+ 'REMOTE_PORT': '0',
+ 'SERVER_NAME': 'aiohttp',
+ 'SERVER_PORT': '0',
+ 'aiohttp.request': request
+ }
+
+ for hdr_name, hdr_value in message.headers.items():
+ hdr_name = hdr_name.upper()
+ if hdr_name == 'CONTENT-TYPE':
+ environ['CONTENT_TYPE'] = hdr_value
+ continue
+ elif hdr_name == 'CONTENT-LENGTH':
+ environ['CONTENT_LENGTH'] = hdr_value
+ continue
+
+ key = 'HTTP_%s' % hdr_name.replace('-', '_')
+ if key in environ:
+ hdr_value = '%s,%s' % (environ[key], hdr_value)
+
+ environ[key] = hdr_value
+
+ environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
+
+ path_info = uri_parts.path
+
+ environ['PATH_INFO'] = path_info
+ environ['SCRIPT_NAME'] = ''
+
+ return environ
+
+
+def make_response(status, headers, payload, environ):
+ """This function generates an appropriate response object for this async
+ mode.
+ """
+ return Response(body=payload, status=int(status.split()[0]),
+ headers=headers)
+
+
+class WebSocket(object): # pragma: no cover
+ """
+ This wrapper class provides an aiohttp WebSocket interface that is
+ somewhat compatible with eventlet's implementation.
+ """
+ def __init__(self, handler):
+ self.handler = handler
+ self._sock = None
+
+ async def __call__(self, environ):
+ request = environ['aiohttp.request']
+ self._sock = WebSocketResponse()
+ await self._sock.prepare(request)
+
+ self.environ = environ
+ await self.handler(self)
+ return self._sock
+
+ async def close(self):
+ await self._sock.close()
+
+ async def send(self, message):
+ if isinstance(message, bytes):
+ f = self._sock.send_bytes
+ else:
+ f = self._sock.send_str
+ if asyncio.iscoroutinefunction(f):
+ await f(message)
+ else:
+ f(message)
+
+ async def wait(self):
+ msg = await self._sock.receive()
+ if not isinstance(msg.data, six.binary_type) and \
+ not isinstance(msg.data, six.text_type):
+ raise IOError()
+ return msg.data
+
+
+_async = {
+ 'asyncio': True,
+ 'create_route': create_route,
+ 'translate_request': translate_request,
+ 'make_response': make_response,
+ 'websocket': WebSocket,
+}
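
A hedged wiring sketch for this driver: attaching an AsyncServer to an
aiohttp application, which registers the GET/POST/OPTIONS routes through
create_route() above (the attach() helper and the port are assumptions
based on the public engineio API):

    from aiohttp import web
    import engineio

    eio = engineio.AsyncServer(async_mode='aiohttp')
    app = web.Application()
    eio.attach(app)  # sets up the engine.io endpoint via create_route()
    web.run_app(app, port=5000)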
diff --git a/libs/engineio/async_drivers/asgi.py b/libs/engineio/async_drivers/asgi.py
new file mode 100644
index 000000000..9f14ef05f
--- /dev/null
+++ b/libs/engineio/async_drivers/asgi.py
@@ -0,0 +1,214 @@
+import os
+import sys
+
+from engineio.static_files import get_static_file
+
+
+class ASGIApp:
+ """ASGI application middleware for Engine.IO.
+
+ This middleware dispatches traffic to an Engine.IO application. It can
+ also serve a list of static files to the client, or forward unrelated
+ HTTP traffic to another ASGI application.
+
+ :param engineio_server: The Engine.IO server. Must be an instance of the
+ ``engineio.AsyncServer`` class.
+ :param static_files: A dictionary with static file mapping rules. See the
+ documentation for details on this argument.
+ :param other_asgi_app: A separate ASGI app that receives all other traffic.
+ :param engineio_path: The endpoint where the Engine.IO application should
+ be installed. The default value is appropriate for
+ most cases.
+
+ Example usage::
+
+ import engineio
+ import uvicorn
+
+ eio = engineio.AsyncServer()
+ app = engineio.ASGIApp(eio, static_files={
+ '/': {'content_type': 'text/html', 'filename': 'index.html'},
+ '/index.html': {'content_type': 'text/html',
+ 'filename': 'index.html'},
+ })
+ uvicorn.run(app, '127.0.0.1', 5000)
+ """
+ def __init__(self, engineio_server, other_asgi_app=None,
+ static_files=None, engineio_path='engine.io'):
+ self.engineio_server = engineio_server
+ self.other_asgi_app = other_asgi_app
+ self.engineio_path = engineio_path.strip('/')
+ self.static_files = static_files or {}
+
+ async def __call__(self, scope, receive, send):
+ if scope['type'] in ['http', 'websocket'] and \
+ scope['path'].startswith('/{0}/'.format(self.engineio_path)):
+ await self.engineio_server.handle_request(scope, receive, send)
+ else:
+ static_file = get_static_file(scope['path'], self.static_files) \
+ if scope['type'] == 'http' and self.static_files else None
+ if static_file:
+ await self.serve_static_file(static_file, receive, send)
+ elif self.other_asgi_app is not None:
+ await self.other_asgi_app(scope, receive, send)
+ elif scope['type'] == 'lifespan':
+ await self.lifespan(receive, send)
+ else:
+ await self.not_found(receive, send)
+
+ async def serve_static_file(self, static_file, receive,
+ send): # pragma: no cover
+ event = await receive()
+ if event['type'] == 'http.request':
+ if os.path.exists(static_file['filename']):
+ with open(static_file['filename'], 'rb') as f:
+ payload = f.read()
+ await send({'type': 'http.response.start',
+ 'status': 200,
+ 'headers': [(b'Content-Type', static_file[
+ 'content_type'].encode('utf-8'))]})
+ await send({'type': 'http.response.body',
+ 'body': payload})
+ else:
+ await self.not_found(receive, send)
+
+ async def lifespan(self, receive, send):
+ event = await receive()
+ if event['type'] == 'lifespan.startup':
+ await send({'type': 'lifespan.startup.complete'})
+ elif event['type'] == 'lifespan.shutdown':
+ await send({'type': 'lifespan.shutdown.complete'})
+
+ async def not_found(self, receive, send):
+ """Return a 404 Not Found error to the client."""
+ await send({'type': 'http.response.start',
+ 'status': 404,
+ 'headers': [(b'Content-Type', b'text/plain')]})
+ await send({'type': 'http.response.body',
+ 'body': b'Not Found'})
+
+
+async def translate_request(scope, receive, send):
+ class AwaitablePayload(object): # pragma: no cover
+ def __init__(self, payload):
+ self.payload = payload or b''
+
+ async def read(self, length=None):
+ if length is None:
+ r = self.payload
+ self.payload = b''
+ else:
+ r = self.payload[:length]
+ self.payload = self.payload[length:]
+ return r
+
+ event = await receive()
+ payload = b''
+ if event['type'] == 'http.request':
+ payload += event.get('body') or b''
+ while event.get('more_body'):
+ event = await receive()
+ if event['type'] == 'http.request':
+ payload += event.get('body') or b''
+ elif event['type'] == 'websocket.connect':
+ await send({'type': 'websocket.accept'})
+ else:
+ return {}
+
+ raw_uri = scope['path'].encode('utf-8')
+ if 'query_string' in scope and scope['query_string']:
+ raw_uri += b'?' + scope['query_string']
+ environ = {
+ 'wsgi.input': AwaitablePayload(payload),
+ 'wsgi.errors': sys.stderr,
+ 'wsgi.version': (1, 0),
+ 'wsgi.async': True,
+ 'wsgi.multithread': False,
+ 'wsgi.multiprocess': False,
+ 'wsgi.run_once': False,
+ 'SERVER_SOFTWARE': 'asgi',
+ 'REQUEST_METHOD': scope.get('method', 'GET'),
+ 'PATH_INFO': scope['path'],
+ 'QUERY_STRING': scope.get('query_string', b'').decode('utf-8'),
+ 'RAW_URI': raw_uri.decode('utf-8'),
+ 'SCRIPT_NAME': '',
+ 'SERVER_PROTOCOL': 'HTTP/1.1',
+ 'REMOTE_ADDR': '127.0.0.1',
+ 'REMOTE_PORT': '0',
+ 'SERVER_NAME': 'asgi',
+ 'SERVER_PORT': '0',
+ 'asgi.receive': receive,
+ 'asgi.send': send,
+ }
+
+ for hdr_name, hdr_value in scope['headers']:
+ hdr_name = hdr_name.upper().decode('utf-8')
+ hdr_value = hdr_value.decode('utf-8')
+ if hdr_name == 'CONTENT-TYPE':
+ environ['CONTENT_TYPE'] = hdr_value
+ continue
+ elif hdr_name == 'CONTENT-LENGTH':
+ environ['CONTENT_LENGTH'] = hdr_value
+ continue
+
+ key = 'HTTP_%s' % hdr_name.replace('-', '_')
+ if key in environ:
+ hdr_value = '%s,%s' % (environ[key], hdr_value)
+
+ environ[key] = hdr_value
+
+ environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
+ return environ
+
+
+async def make_response(status, headers, payload, environ):
+ headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers]
+ await environ['asgi.send']({'type': 'http.response.start',
+ 'status': int(status.split(' ')[0]),
+ 'headers': headers})
+ await environ['asgi.send']({'type': 'http.response.body',
+ 'body': payload})
+
+
+class WebSocket(object): # pragma: no cover
+ """
+ This wrapper class provides an asgi WebSocket interface that is
+ somewhat compatible with eventlet's implementation.
+ """
+ def __init__(self, handler):
+ self.handler = handler
+ self.asgi_receive = None
+ self.asgi_send = None
+
+ async def __call__(self, environ):
+ self.asgi_receive = environ['asgi.receive']
+ self.asgi_send = environ['asgi.send']
+ await self.handler(self)
+
+ async def close(self):
+ await self.asgi_send({'type': 'websocket.close'})
+
+ async def send(self, message):
+ msg_bytes = None
+ msg_text = None
+ if isinstance(message, bytes):
+ msg_bytes = message
+ else:
+ msg_text = message
+ await self.asgi_send({'type': 'websocket.send',
+ 'bytes': msg_bytes,
+ 'text': msg_text})
+
+ async def wait(self):
+ event = await self.asgi_receive()
+ if event['type'] != 'websocket.receive':
+ raise IOError()
+ return event.get('bytes') or event.get('text')
+
+
+_async = {
+ 'asyncio': True,
+ 'translate_request': translate_request,
+ 'make_response': make_response,
+ 'websocket': WebSocket,
+}
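
Both translate_request() implementations above fold repeated headers into a
single HTTP_* environ key. A self-contained sketch of that rule (header
names are illustrative):

    environ = {}
    for name, value in [('X-Tag', 'a'), ('X-Tag', 'b'), ('Host', 'example')]:
        key = 'HTTP_%s' % name.upper().replace('-', '_')
        if key in environ:
            value = '%s,%s' % (environ[key], value)  # join duplicates
        environ[key] = value

    assert environ['HTTP_X_TAG'] == 'a,b'
    assert environ['HTTP_HOST'] == 'example'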
diff --git a/libs/engineio/async_drivers/eventlet.py b/libs/engineio/async_drivers/eventlet.py
new file mode 100644
index 000000000..9be3797cd
--- /dev/null
+++ b/libs/engineio/async_drivers/eventlet.py
@@ -0,0 +1,30 @@
+from __future__ import absolute_import
+
+from eventlet.green.threading import Thread, Event
+from eventlet import queue
+from eventlet import sleep
+from eventlet.websocket import WebSocketWSGI as _WebSocketWSGI
+
+
+class WebSocketWSGI(_WebSocketWSGI):
+ def __init__(self, *args, **kwargs):
+ super(WebSocketWSGI, self).__init__(*args, **kwargs)
+ self._sock = None
+
+ def __call__(self, environ, start_response):
+ if 'eventlet.input' not in environ:
+ raise RuntimeError('You need to use the eventlet server. '
+ 'See the Deployment section of the '
+ 'documentation for more information.')
+ self._sock = environ['eventlet.input'].get_socket()
+ return super(WebSocketWSGI, self).__call__(environ, start_response)
+
+
+_async = {
+ 'thread': Thread,
+ 'queue': queue.Queue,
+ 'queue_empty': queue.Empty,
+ 'event': Event,
+ 'websocket': WebSocketWSGI,
+ 'sleep': sleep,
+}
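
A hedged deployment sketch for the eventlet driver (server setup is an
assumption based on the standard eventlet pattern; the port is
illustrative):

    import eventlet
    import eventlet.wsgi  # explicit import, in case it is not re-exported
    import engineio

    eio = engineio.Server(async_mode='eventlet')
    app = engineio.WSGIApp(eio)
    eventlet.wsgi.server(eventlet.listen(('', 5000)), app)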
diff --git a/libs/engineio/async_drivers/gevent.py b/libs/engineio/async_drivers/gevent.py
new file mode 100644
index 000000000..024dd0aad
--- /dev/null
+++ b/libs/engineio/async_drivers/gevent.py
@@ -0,0 +1,63 @@
+from __future__ import absolute_import
+
+import gevent
+from gevent import queue
+from gevent.event import Event
+try:
+ import geventwebsocket # noqa
+ _websocket_available = True
+except ImportError:
+ _websocket_available = False
+
+
+class Thread(gevent.Greenlet): # pragma: no cover
+ """
+ This wrapper class provides a gevent Greenlet interface that is compatible
+ with the standard library's Thread class.
+ """
+ def __init__(self, target, args=[], kwargs={}):
+ super(Thread, self).__init__(target, *args, **kwargs)
+
+ def _run(self):
+ return self.run()
+
+
+class WebSocketWSGI(object): # pragma: no cover
+ """
+ This wrapper class provides a gevent WebSocket interface that is
+ compatible with eventlet's implementation.
+ """
+ def __init__(self, app):
+ self.app = app
+
+ def __call__(self, environ, start_response):
+ if 'wsgi.websocket' not in environ:
+ raise RuntimeError('You need to use the gevent-websocket server. '
+ 'See the Deployment section of the '
+ 'documentation for more information.')
+ self._sock = environ['wsgi.websocket']
+ self.environ = environ
+ self.version = self._sock.version
+ self.path = self._sock.path
+ self.origin = self._sock.origin
+ self.protocol = self._sock.protocol
+ return self.app(self)
+
+ def close(self):
+ return self._sock.close()
+
+ def send(self, message):
+ return self._sock.send(message)
+
+ def wait(self):
+ return self._sock.receive()
+
+
+_async = {
+ 'thread': Thread,
+ 'queue': queue.JoinableQueue,
+ 'queue_empty': queue.Empty,
+ 'event': Event,
+ 'websocket': WebSocketWSGI if _websocket_available else None,
+ 'sleep': gevent.sleep,
+}
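
A hedged deployment sketch for the gevent driver: the WSGI app must run
under gevent's pywsgi server, with the gevent-websocket handler when
WebSocket support is wanted, as the RuntimeError above hints (port
illustrative):

    from gevent import pywsgi
    from geventwebsocket.handler import WebSocketHandler
    import engineio

    eio = engineio.Server(async_mode='gevent')
    app = engineio.WSGIApp(eio)
    pywsgi.WSGIServer(('', 5000), app,
                      handler_class=WebSocketHandler).serve_forever()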
diff --git a/libs/engineio/async_drivers/gevent_uwsgi.py b/libs/engineio/async_drivers/gevent_uwsgi.py
new file mode 100644
index 000000000..07fa2a79d
--- /dev/null
+++ b/libs/engineio/async_drivers/gevent_uwsgi.py
@@ -0,0 +1,156 @@
+from __future__ import absolute_import
+
+import six
+
+import gevent
+from gevent import queue
+from gevent.event import Event
+import uwsgi
+_websocket_available = hasattr(uwsgi, 'websocket_handshake')
+
+
+class Thread(gevent.Greenlet): # pragma: no cover
+ """
+ This wrapper class provides a gevent Greenlet interface that is compatible
+ with the standard library's Thread class.
+ """
+ def __init__(self, target, args=[], kwargs={}):
+ super(Thread, self).__init__(target, *args, **kwargs)
+
+ def _run(self):
+ return self.run()
+
+
+class uWSGIWebSocket(object): # pragma: no cover
+ """
+ This wrapper class provides a uWSGI WebSocket interface that is
+ compatible with eventlet's implementation.
+ """
+ def __init__(self, app):
+ self.app = app
+ self._sock = None
+
+ def __call__(self, environ, start_response):
+ self._sock = uwsgi.connection_fd()
+ self.environ = environ
+
+ uwsgi.websocket_handshake()
+
+ self._req_ctx = None
+ if hasattr(uwsgi, 'request_context'):
+ # uWSGI >= 2.1.x with support for API access across greenlets
+ self._req_ctx = uwsgi.request_context()
+ else:
+ # use event and queue for sending messages
+ from gevent.event import Event
+ from gevent.queue import Queue
+ from gevent.select import select
+ self._event = Event()
+ self._send_queue = Queue()
+
+ # spawn a select greenlet
+ def select_greenlet_runner(fd, event):
+ """Sets event when data becomes available to read on fd."""
+ while True:
+ event.set()
+ try:
+ select([fd], [], [])[0]
+ except ValueError:
+ break
+ self._select_greenlet = gevent.spawn(
+ select_greenlet_runner,
+ self._sock,
+ self._event)
+
+ self.app(self)
+
+ def close(self):
+ """Disconnects uWSGI from the client."""
+ uwsgi.disconnect()
+ if self._req_ctx is None:
+ # better kill it here in case wait() is not called again
+ self._select_greenlet.kill()
+ self._event.set()
+
+ def _send(self, msg):
+ """Transmits message either in binary or UTF-8 text mode,
+ depending on its type."""
+ if isinstance(msg, six.binary_type):
+ method = uwsgi.websocket_send_binary
+ else:
+ method = uwsgi.websocket_send
+ if self._req_ctx is not None:
+ method(msg, request_context=self._req_ctx)
+ else:
+ method(msg)
+
+ def _decode_received(self, msg):
+ """Returns either bytes or str, depending on message type."""
+ if not isinstance(msg, six.binary_type):
+ # already decoded - do nothing
+ return msg
+ # only decode from utf-8 if message is not binary data
+ type = six.byte2int(msg[0:1])
+ if type >= 48: # no binary
+ return msg.decode('utf-8')
+ # binary message, don't try to decode
+ return msg
+
+ def send(self, msg):
+ """Queues a message for sending. Real transmission is done in
+ wait method.
+ Sends directly if uWSGI version is new enough."""
+ if self._req_ctx is not None:
+ self._send(msg)
+ else:
+ self._send_queue.put(msg)
+ self._event.set()
+
+ def wait(self):
+ """Waits and returns received messages.
+ If running in compatibility mode for older uWSGI versions,
+ it also sends messages that have been queued by send().
+ A return value of None means that connection was closed.
+ This must be called repeatedly. For uWSGI < 2.1.x it must
+ be called from the main greenlet."""
+ while True:
+ if self._req_ctx is not None:
+ try:
+ msg = uwsgi.websocket_recv(request_context=self._req_ctx)
+ except IOError: # connection closed
+ return None
+ return self._decode_received(msg)
+ else:
+ # we wake up at least every 3 seconds to let uWSGI
+ # do its ping/ponging
+ event_set = self._event.wait(timeout=3)
+ if event_set:
+ self._event.clear()
+ # maybe there is something to send
+ msgs = []
+ while True:
+ try:
+ msgs.append(self._send_queue.get(block=False))
+ except gevent.queue.Empty:
+ break
+ for msg in msgs:
+ self._send(msg)
+ # maybe there is something to receive, if not, at least
+ # ensure uWSGI does its ping/ponging
+ try:
+ msg = uwsgi.websocket_recv_nb()
+ except IOError: # connection closed
+ self._select_greenlet.kill()
+ return None
+ if msg: # message available
+ return self._decode_received(msg)
+
+
+_async = {
+ 'thread': Thread,
+ 'queue': queue.JoinableQueue,
+ 'queue_empty': queue.Empty,
+ 'event': Event,
+ 'websocket': uWSGIWebSocket if _websocket_available else None,
+ 'sleep': gevent.sleep,
+}
diff --git a/libs/engineio/async_drivers/sanic.py b/libs/engineio/async_drivers/sanic.py
new file mode 100644
index 000000000..6929654b9
--- /dev/null
+++ b/libs/engineio/async_drivers/sanic.py
@@ -0,0 +1,144 @@
+import sys
+from urllib.parse import urlsplit
+
+from sanic.response import HTTPResponse
+try:
+ from sanic.websocket import WebSocketProtocol
+except ImportError:
+ # the installed version of sanic does not have websocket support
+ WebSocketProtocol = None
+import six
+
+
+def create_route(app, engineio_server, engineio_endpoint):
+ """This function sets up the engine.io endpoint as a route for the
+ application.
+
+ Note that both GET and POST requests must be hooked up on the engine.io
+ endpoint.
+ """
+ app.add_route(engineio_server.handle_request, engineio_endpoint,
+ methods=['GET', 'POST', 'OPTIONS'])
+ try:
+ app.enable_websocket()
+ except AttributeError:
+ # ignore, this version does not support websocket
+ pass
+
+
+def translate_request(request):
+ """This function takes the arguments passed to the request handler and
+ uses them to generate a WSGI compatible environ dictionary.
+ """
+ class AwaitablePayload(object):
+ def __init__(self, payload):
+ self.payload = payload or b''
+
+ async def read(self, length=None):
+ if length is None:
+ r = self.payload
+ self.payload = b''
+ else:
+ r = self.payload[:length]
+ self.payload = self.payload[length:]
+ return r
+
+ uri_parts = urlsplit(request.url)
+ environ = {
+ 'wsgi.input': AwaitablePayload(request.body),
+ 'wsgi.errors': sys.stderr,
+ 'wsgi.version': (1, 0),
+ 'wsgi.async': True,
+ 'wsgi.multithread': False,
+ 'wsgi.multiprocess': False,
+ 'wsgi.run_once': False,
+ 'SERVER_SOFTWARE': 'sanic',
+ 'REQUEST_METHOD': request.method,
+ 'QUERY_STRING': uri_parts.query or '',
+ 'RAW_URI': request.url,
+ 'SERVER_PROTOCOL': 'HTTP/' + request.version,
+ 'REMOTE_ADDR': '127.0.0.1',
+ 'REMOTE_PORT': '0',
+ 'SERVER_NAME': 'sanic',
+ 'SERVER_PORT': '0',
+ 'sanic.request': request
+ }
+
+ for hdr_name, hdr_value in request.headers.items():
+ hdr_name = hdr_name.upper()
+ if hdr_name == 'CONTENT-TYPE':
+ environ['CONTENT_TYPE'] = hdr_value
+ continue
+ elif hdr_name == 'CONTENT-LENGTH':
+ environ['CONTENT_LENGTH'] = hdr_value
+ continue
+
+ key = 'HTTP_%s' % hdr_name.replace('-', '_')
+ if key in environ:
+ hdr_value = '%s,%s' % (environ[key], hdr_value)
+
+ environ[key] = hdr_value
+
+ environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
+
+ path_info = uri_parts.path
+
+ environ['PATH_INFO'] = path_info
+ environ['SCRIPT_NAME'] = ''
+
+ return environ
+
+
+def make_response(status, headers, payload, environ):
+ """This function generates an appropriate response object for this async
+ mode.
+ """
+ headers_dict = {}
+ content_type = None
+ for h in headers:
+ if h[0].lower() == 'content-type':
+ content_type = h[1]
+ else:
+ headers_dict[h[0]] = h[1]
+ return HTTPResponse(body_bytes=payload, content_type=content_type,
+ status=int(status.split()[0]), headers=headers_dict)
+
+
+class WebSocket(object): # pragma: no cover
+ """
+ This wrapper class provides a sanic WebSocket interface that is
+ somewhat compatible with eventlet's implementation.
+ """
+ def __init__(self, handler):
+ self.handler = handler
+ self._sock = None
+
+ async def __call__(self, environ):
+ request = environ['sanic.request']
+ protocol = request.transport.get_protocol()
+ self._sock = await protocol.websocket_handshake(request)
+
+ self.environ = environ
+ await self.handler(self)
+
+ async def close(self):
+ await self._sock.close()
+
+ async def send(self, message):
+ await self._sock.send(message)
+
+ async def wait(self):
+ data = await self._sock.recv()
+ if not isinstance(data, six.binary_type) and \
+ not isinstance(data, six.text_type):
+ raise IOError()
+ return data
+
+
+_async = {
+ 'asyncio': True,
+ 'create_route': create_route,
+ 'translate_request': translate_request,
+ 'make_response': make_response,
+ 'websocket': WebSocket if WebSocketProtocol else None,
+}
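
A hedged wiring sketch for the sanic driver (attach() and the app name are
assumptions; the routes come from create_route() above):

    from sanic import Sanic
    import engineio

    eio = engineio.AsyncServer(async_mode='sanic')
    app = Sanic('engineio_app')
    eio.attach(app)
    app.run(port=5000)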
diff --git a/libs/engineio/async_drivers/threading.py b/libs/engineio/async_drivers/threading.py
new file mode 100644
index 000000000..9b5375668
--- /dev/null
+++ b/libs/engineio/async_drivers/threading.py
@@ -0,0 +1,17 @@
+from __future__ import absolute_import
+import threading
+import time
+
+try:
+ import queue
+except ImportError: # pragma: no cover
+ import Queue as queue
+
+_async = {
+ 'thread': threading.Thread,
+ 'queue': queue.Queue,
+ 'queue_empty': queue.Empty,
+ 'event': threading.Event,
+ 'websocket': None,
+ 'sleep': time.sleep,
+}
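
Every driver in this package exports the same _async mapping, which is how
the server core stays framework-agnostic. A minimal sketch using the
threading driver above (nothing beyond this module is assumed):

    from engineio.async_drivers.threading import _async

    queue = _async['queue']()             # queue.Queue
    event = _async['event']()             # threading.Event
    thread = _async['thread'](target=event.set)
    thread.start()
    event.wait()
    _async['sleep'](0)                    # time.sleep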
diff --git a/libs/engineio/async_drivers/tornado.py b/libs/engineio/async_drivers/tornado.py
new file mode 100644
index 000000000..adfe18f5a
--- /dev/null
+++ b/libs/engineio/async_drivers/tornado.py
@@ -0,0 +1,184 @@
+import asyncio
+import sys
+from urllib.parse import urlsplit
+from .. import exceptions
+
+import tornado.web
+import tornado.websocket
+import six
+
+
+def get_tornado_handler(engineio_server):
+ class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if isinstance(engineio_server.cors_allowed_origins,
+ six.string_types):
+ if engineio_server.cors_allowed_origins == '*':
+ self.allowed_origins = None
+ else:
+ self.allowed_origins = [
+ engineio_server.cors_allowed_origins]
+ else:
+ self.allowed_origins = engineio_server.cors_allowed_origins
+ self.receive_queue = asyncio.Queue()
+
+ async def get(self, *args, **kwargs):
+ if self.request.headers.get('Upgrade', '').lower() == 'websocket':
+ ret = super().get(*args, **kwargs)
+ if asyncio.iscoroutine(ret):
+ await ret
+ else:
+ await engineio_server.handle_request(self)
+
+ async def open(self, *args, **kwargs):
+ # this is the handler for the websocket request
+ asyncio.ensure_future(engineio_server.handle_request(self))
+
+ async def post(self, *args, **kwargs):
+ await engineio_server.handle_request(self)
+
+ async def options(self, *args, **kwargs):
+ await engineio_server.handle_request(self)
+
+ async def on_message(self, message):
+ await self.receive_queue.put(message)
+
+ async def get_next_message(self):
+ return await self.receive_queue.get()
+
+ def on_close(self):
+ self.receive_queue.put_nowait(None)
+
+ def check_origin(self, origin):
+ if self.allowed_origins is None or origin in self.allowed_origins:
+ return True
+ return super().check_origin(origin)
+
+ def get_compression_options(self):
+ # enable compression
+ return {}
+
+ return Handler
+
+
+def translate_request(handler):
+ """This function takes the arguments passed to the request handler and
+ uses them to generate a WSGI compatible environ dictionary.
+ """
+ class AwaitablePayload(object):
+ def __init__(self, payload):
+ self.payload = payload or b''
+
+ async def read(self, length=None):
+ if length is None:
+ r = self.payload
+ self.payload = b''
+ else:
+ r = self.payload[:length]
+ self.payload = self.payload[length:]
+ return r
+
+ payload = handler.request.body
+
+ uri_parts = urlsplit(handler.request.path)
+ full_uri = handler.request.path
+ if handler.request.query: # pragma: no cover
+ full_uri += '?' + handler.request.query
+ environ = {
+ 'wsgi.input': AwaitablePayload(payload),
+ 'wsgi.errors': sys.stderr,
+ 'wsgi.version': (1, 0),
+ 'wsgi.async': True,
+ 'wsgi.multithread': False,
+ 'wsgi.multiprocess': False,
+ 'wsgi.run_once': False,
+ 'SERVER_SOFTWARE': 'tornado',
+ 'REQUEST_METHOD': handler.request.method,
+ 'QUERY_STRING': handler.request.query or '',
+ 'RAW_URI': full_uri,
+ 'SERVER_PROTOCOL': 'HTTP/%s' % handler.request.version,
+ 'REMOTE_ADDR': '127.0.0.1',
+ 'REMOTE_PORT': '0',
+ 'SERVER_NAME': 'tornado',
+ 'SERVER_PORT': '0',
+ 'tornado.handler': handler
+ }
+
+ for hdr_name, hdr_value in handler.request.headers.items():
+ hdr_name = hdr_name.upper()
+ if hdr_name == 'CONTENT-TYPE':
+ environ['CONTENT_TYPE'] = hdr_value
+ continue
+ elif hdr_name == 'CONTENT-LENGTH':
+ environ['CONTENT_LENGTH'] = hdr_value
+ continue
+
+ key = 'HTTP_%s' % hdr_name.replace('-', '_')
+ environ[key] = hdr_value
+
+ environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
+
+ path_info = uri_parts.path
+
+ environ['PATH_INFO'] = path_info
+ environ['SCRIPT_NAME'] = ''
+
+ return environ
+
+
+def make_response(status, headers, payload, environ):
+ """This function generates an appropriate response object for this async
+ mode.
+ """
+ tornado_handler = environ['tornado.handler']
+ try:
+ tornado_handler.set_status(int(status.split()[0]))
+ except RuntimeError: # pragma: no cover
+ # for websocket connections Tornado does not accept a response, since
+ # it already emitted the 101 status code
+ return
+ for header, value in headers:
+ tornado_handler.set_header(header, value)
+ tornado_handler.write(payload)
+ tornado_handler.finish()
+
+
+class WebSocket(object): # pragma: no cover
+ """
+ This wrapper class provides a tornado WebSocket interface that is
+ somewhat compatible with eventlet's implementation.
+ """
+ def __init__(self, handler):
+ self.handler = handler
+ self.tornado_handler = None
+
+ async def __call__(self, environ):
+ self.tornado_handler = environ['tornado.handler']
+ self.environ = environ
+ await self.handler(self)
+
+ async def close(self):
+ self.tornado_handler.close()
+
+ async def send(self, message):
+ try:
+ self.tornado_handler.write_message(
+ message, binary=isinstance(message, bytes))
+ except tornado.websocket.WebSocketClosedError:
+ raise exceptions.EngineIOError()
+
+ async def wait(self):
+ msg = await self.tornado_handler.get_next_message()
+ if not isinstance(msg, six.binary_type) and \
+ not isinstance(msg, six.text_type):
+ raise IOError()
+ return msg
+
+
+_async = {
+ 'asyncio': True,
+ 'translate_request': translate_request,
+ 'make_response': make_response,
+ 'websocket': WebSocket,
+}
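
A hedged wiring sketch for the tornado driver: get_tornado_handler()
produces a RequestHandler class mounted on the engine.io route (route and
port are illustrative):

    import tornado.ioloop
    import tornado.web
    import engineio

    eio = engineio.AsyncServer(async_mode='tornado')
    app = tornado.web.Application([
        (r'/engine.io/', engineio.get_tornado_handler(eio)),
    ])
    app.listen(5000)
    tornado.ioloop.IOLoop.current().start()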
diff --git a/libs/engineio/asyncio_client.py b/libs/engineio/asyncio_client.py
new file mode 100644
index 000000000..049b4bd95
--- /dev/null
+++ b/libs/engineio/asyncio_client.py
@@ -0,0 +1,585 @@
+import asyncio
+import ssl
+
+try:
+ import aiohttp
+except ImportError: # pragma: no cover
+ aiohttp = None
+import six
+
+from . import client
+from . import exceptions
+from . import packet
+from . import payload
+
+
+class AsyncClient(client.Client):
+ """An Engine.IO client for asyncio.
+
+ This class implements a fully compliant Engine.IO web client with support
+ for websocket and long-polling transports, compatible with the asyncio
+ framework on Python 3.5 or newer.
+
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``. The default is
+ ``False``.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+ :param request_timeout: A timeout in seconds for requests. The default is
+ 5 seconds.
+ :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
+ skip SSL certificate verification, allowing
+ connections to servers with self signed certificates.
+ The default is ``True``.
+ """
+ def is_asyncio_based(self):
+ return True
+
+ async def connect(self, url, headers={}, transports=None,
+ engineio_path='engine.io'):
+ """Connect to an Engine.IO server.
+
+ :param url: The URL of the Engine.IO server. It can include custom
+ query string parameters if required by the server.
+ :param headers: A dictionary with custom headers to send with the
+ connection request.
+ :param transports: The list of allowed transports. Valid transports
+ are ``'polling'`` and ``'websocket'``. If not
+ given, the polling transport is connected first,
+ then an upgrade to websocket is attempted.
+ :param engineio_path: The endpoint where the Engine.IO server is
+ installed. The default value is appropriate for
+ most cases.
+
+ Note: this method is a coroutine.
+
+ Example usage::
+
+ eio = engineio.AsyncClient()
+ await eio.connect('http://localhost:5000')
+ """
+ if self.state != 'disconnected':
+ raise ValueError('Client is not in a disconnected state')
+ valid_transports = ['polling', 'websocket']
+ if transports is not None:
+ if isinstance(transports, six.text_type):
+ transports = [transports]
+ transports = [transport for transport in transports
+ if transport in valid_transports]
+ if not transports:
+ raise ValueError('No valid transports provided')
+ self.transports = transports or valid_transports
+ self.queue = self.create_queue()
+ return await getattr(self, '_connect_' + self.transports[0])(
+ url, headers, engineio_path)
+
+ async def wait(self):
+ """Wait until the connection with the server ends.
+
+ Client applications can use this function to block the main thread
+ during the life of the connection.
+
+ Note: this method is a coroutine.
+ """
+ if self.read_loop_task:
+ await self.read_loop_task
+
+ async def send(self, data, binary=None):
+ """Send a message to a client.
+
+ :param data: The data to send to the client. Data can be of type
+ ``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
+ or ``dict``, the data will be serialized as JSON.
+ :param binary: ``True`` to send packet as binary, ``False`` to send
+ as text. If not given, unicode (Python 2) and str
+ (Python 3) are sent as text, and str (Python 2) and
+ bytes (Python 3) are sent as binary.
+
+ Note: this method is a coroutine.
+ """
+ await self._send_packet(packet.Packet(packet.MESSAGE, data=data,
+ binary=binary))
+
+ async def disconnect(self, abort=False):
+ """Disconnect from the server.
+
+ :param abort: If set to ``True``, do not wait for background tasks
+ associated with the connection to end.
+
+ Note: this method is a coroutine.
+ """
+ if self.state == 'connected':
+ await self._send_packet(packet.Packet(packet.CLOSE))
+ await self.queue.put(None)
+ self.state = 'disconnecting'
+ await self._trigger_event('disconnect', run_async=False)
+ if self.current_transport == 'websocket':
+ await self.ws.close()
+ if not abort:
+ await self.read_loop_task
+ self.state = 'disconnected'
+ try:
+ client.connected_clients.remove(self)
+ except ValueError: # pragma: no cover
+ pass
+ self._reset()
+
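+ # A hedged usage sketch (the on() decorator is inherited from the base
+ # Client class; URL and handler names are illustrative):
+ #
+ #   eio = engineio.AsyncClient()
+ #
+ #   @eio.on('message')
+ #   async def on_message(data):
+ #       print('received', data)
+ #
+ #   await eio.connect('http://localhost:5000')
+ #   await eio.wait()
+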
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task.
+
+ This is a utility function that applications can use to start a
+ background task.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ This function returns an ``asyncio.Task`` object. The task is already
+ scheduled, so there is no need to call ``start()`` on it.
+
+ Note: unlike most methods in this class, this one is not a
+ coroutine.
+ """
+ return asyncio.ensure_future(target(*args, **kwargs))
+
+ async def sleep(self, seconds=0):
+ """Sleep for the requested amount of time.
+
+ Note: this method is a coroutine.
+ """
+ return await asyncio.sleep(seconds)
+
+ def create_queue(self):
+ """Create a queue object."""
+ q = asyncio.Queue()
+ q.Empty = asyncio.QueueEmpty
+ return q
+
+ def create_event(self):
+ """Create an event object."""
+ return asyncio.Event()
+
+ def _reset(self):
+ if self.http: # pragma: no cover
+ asyncio.ensure_future(self.http.close())
+ super()._reset()
+
+ async def _connect_polling(self, url, headers, engineio_path):
+ """Establish a long-polling connection to the Engine.IO server."""
+ if aiohttp is None: # pragma: no cover
+ self.logger.error('aiohttp not installed -- cannot make HTTP '
+ 'requests!')
+ return
+ self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
+ self.logger.info('Attempting polling connection to ' + self.base_url)
+ r = await self._send_request(
+ 'GET', self.base_url + self._get_url_timestamp(), headers=headers,
+ timeout=self.request_timeout)
+ if r is None:
+ self._reset()
+ raise exceptions.ConnectionError(
+ 'Connection refused by the server')
+ if r.status < 200 or r.status >= 300:
+ raise exceptions.ConnectionError(
+ 'Unexpected status code {} in server response'.format(
+ r.status))
+ try:
+ p = payload.Payload(encoded_payload=await r.read())
+ except ValueError:
+ six.raise_from(exceptions.ConnectionError(
+ 'Unexpected response from server'), None)
+ open_packet = p.packets[0]
+ if open_packet.packet_type != packet.OPEN:
+ raise exceptions.ConnectionError(
+ 'OPEN packet not returned by server')
+ self.logger.info(
+ 'Polling connection accepted with ' + str(open_packet.data))
+ self.sid = open_packet.data['sid']
+ self.upgrades = open_packet.data['upgrades']
+ self.ping_interval = open_packet.data['pingInterval'] / 1000.0
+ self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
+ self.current_transport = 'polling'
+ self.base_url += '&sid=' + self.sid
+
+ self.state = 'connected'
+ client.connected_clients.append(self)
+ await self._trigger_event('connect', run_async=False)
+
+ for pkt in p.packets[1:]:
+ await self._receive_packet(pkt)
+
+ if 'websocket' in self.upgrades and 'websocket' in self.transports:
+ # attempt to upgrade to websocket
+ if await self._connect_websocket(url, headers, engineio_path):
+ # upgrade to websocket succeeded, we're done here
+ return
+
+ self.ping_loop_task = self.start_background_task(self._ping_loop)
+ self.write_loop_task = self.start_background_task(self._write_loop)
+ self.read_loop_task = self.start_background_task(
+ self._read_loop_polling)
+
+ async def _connect_websocket(self, url, headers, engineio_path):
+ """Establish or upgrade to a WebSocket connection with the server."""
+ if aiohttp is None: # pragma: no cover
+ self.logger.error('aiohttp package not installed')
+ return False
+ websocket_url = self._get_engineio_url(url, engineio_path,
+ 'websocket')
+ if self.sid:
+ self.logger.info(
+ 'Attempting WebSocket upgrade to ' + websocket_url)
+ upgrade = True
+ websocket_url += '&sid=' + self.sid
+ else:
+ upgrade = False
+ self.base_url = websocket_url
+ self.logger.info(
+ 'Attempting WebSocket connection to ' + websocket_url)
+
+ if self.http is None or self.http.closed: # pragma: no cover
+ self.http = aiohttp.ClientSession()
+
+ try:
+ if not self.ssl_verify:
+ ssl_context = ssl.create_default_context()
+ ssl_context.check_hostname = False
+ ssl_context.verify_mode = ssl.CERT_NONE
+ ws = await self.http.ws_connect(
+ websocket_url + self._get_url_timestamp(),
+ headers=headers, ssl=ssl_context)
+ else:
+ ws = await self.http.ws_connect(
+ websocket_url + self._get_url_timestamp(),
+ headers=headers)
+ except (aiohttp.client_exceptions.WSServerHandshakeError,
+ aiohttp.client_exceptions.ServerConnectionError):
+ if upgrade:
+ self.logger.warning(
+ 'WebSocket upgrade failed: connection error')
+ return False
+ else:
+ raise exceptions.ConnectionError('Connection error')
+ if upgrade:
+ p = packet.Packet(packet.PING, data='probe').encode(
+ always_bytes=False)
+ try:
+ await ws.send_str(p)
+ except Exception as e: # pragma: no cover
+ self.logger.warning(
+ 'WebSocket upgrade failed: unexpected send exception: %s',
+ str(e))
+ return False
+ try:
+ p = (await ws.receive()).data
+ except Exception as e: # pragma: no cover
+ self.logger.warning(
+ 'WebSocket upgrade failed: unexpected recv exception: %s',
+ str(e))
+ return False
+ pkt = packet.Packet(encoded_packet=p)
+ if pkt.packet_type != packet.PONG or pkt.data != 'probe':
+ self.logger.warning(
+ 'WebSocket upgrade failed: no PONG packet')
+ return False
+ p = packet.Packet(packet.UPGRADE).encode(always_bytes=False)
+ try:
+ await ws.send_str(p)
+ except Exception as e: # pragma: no cover
+ self.logger.warning(
+ 'WebSocket upgrade failed: unexpected send exception: %s',
+ str(e))
+ return False
+ self.current_transport = 'websocket'
+ self.logger.info('WebSocket upgrade was successful')
+ else:
+ try:
+ p = (await ws.receive()).data
+ except Exception as e: # pragma: no cover
+ raise exceptions.ConnectionError(
+ 'Unexpected recv exception: ' + str(e))
+ open_packet = packet.Packet(encoded_packet=p)
+ if open_packet.packet_type != packet.OPEN:
+ raise exceptions.ConnectionError('no OPEN packet')
+ self.logger.info(
+ 'WebSocket connection accepted with ' + str(open_packet.data))
+ self.sid = open_packet.data['sid']
+ self.upgrades = open_packet.data['upgrades']
+ self.ping_interval = open_packet.data['pingInterval'] / 1000.0
+ self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
+ self.current_transport = 'websocket'
+
+ self.state = 'connected'
+ client.connected_clients.append(self)
+ await self._trigger_event('connect', run_async=False)
+
+ self.ws = ws
+ self.ping_loop_task = self.start_background_task(self._ping_loop)
+ self.write_loop_task = self.start_background_task(self._write_loop)
+ self.read_loop_task = self.start_background_task(
+ self._read_loop_websocket)
+ return True
+
+ async def _receive_packet(self, pkt):
+ """Handle incoming packets from the server."""
+ packet_name = packet.packet_names[pkt.packet_type] \
+ if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
+ self.logger.info(
+ 'Received packet %s data %s', packet_name,
+ pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
+ if pkt.packet_type == packet.MESSAGE:
+ await self._trigger_event('message', pkt.data, run_async=True)
+ elif pkt.packet_type == packet.PONG:
+ self.pong_received = True
+ elif pkt.packet_type == packet.CLOSE:
+ await self.disconnect(abort=True)
+ elif pkt.packet_type == packet.NOOP:
+ pass
+ else:
+ self.logger.error('Received unexpected packet of type %s',
+ pkt.packet_type)
+
+ async def _send_packet(self, pkt):
+ """Queue a packet to be sent to the server."""
+ if self.state != 'connected':
+ return
+ await self.queue.put(pkt)
+ self.logger.info(
+ 'Sending packet %s data %s',
+ packet.packet_names[pkt.packet_type],
+ pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
+
+ async def _send_request(
+ self, method, url, headers=None, body=None,
+ timeout=None): # pragma: no cover
+ if self.http is None or self.http.closed:
+ self.http = aiohttp.ClientSession()
+ http_method = getattr(self.http, method.lower())
+
+ try:
+ if not self.ssl_verify:
+ return await http_method(
+ url, headers=headers, data=body,
+ timeout=aiohttp.ClientTimeout(total=timeout), ssl=False)
+ else:
+ return await http_method(
+ url, headers=headers, data=body,
+ timeout=aiohttp.ClientTimeout(total=timeout))
+
+ except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
+ self.logger.info('HTTP %s request to %s failed with error %s.',
+ method, url, exc)
+
+ async def _trigger_event(self, event, *args, **kwargs):
+ """Invoke an event handler."""
+ run_async = kwargs.pop('run_async', False)
+ ret = None
+ if event in self.handlers:
+ if asyncio.iscoroutinefunction(self.handlers[event]) is True:
+ if run_async:
+ return self.start_background_task(self.handlers[event],
+ *args)
+ else:
+ try:
+ ret = await self.handlers[event](*args)
+ except asyncio.CancelledError: # pragma: no cover
+ pass
+ except:
+ self.logger.exception(event + ' async handler error')
+ if event == 'connect':
+ # if connect handler raised error we reject the
+ # connection
+ return False
+ else:
+ if run_async:
+ async def async_handler():
+ return self.handlers[event](*args)
+
+ return self.start_background_task(async_handler)
+ else:
+ try:
+ ret = self.handlers[event](*args)
+ except:
+ self.logger.exception(event + ' handler error')
+ if event == 'connect':
+ # if connect handler raised error we reject the
+ # connection
+ return False
+ return ret
+
+ async def _ping_loop(self):
+ """This background task sends a PING to the server at the requested
+ interval.
+ """
+ self.pong_received = True
+ if self.ping_loop_event is None:
+ self.ping_loop_event = self.create_event()
+ else:
+ self.ping_loop_event.clear()
+ while self.state == 'connected':
+ if not self.pong_received:
+ self.logger.info(
+ 'PONG response has not been received, aborting')
+ if self.ws:
+ await self.ws.close()
+ await self.queue.put(None)
+ break
+ self.pong_received = False
+ await self._send_packet(packet.Packet(packet.PING))
+ try:
+ await asyncio.wait_for(self.ping_loop_event.wait(),
+ self.ping_interval)
+ except (asyncio.TimeoutError,
+ asyncio.CancelledError): # pragma: no cover
+ pass
+ self.logger.info('Exiting ping task')
+
+ async def _read_loop_polling(self):
+ """Read packets by polling the Engine.IO server."""
+ while self.state == 'connected':
+ self.logger.info(
+ 'Sending polling GET request to ' + self.base_url)
+ r = await self._send_request(
+ 'GET', self.base_url + self._get_url_timestamp(),
+ timeout=max(self.ping_interval, self.ping_timeout) + 5)
+ if r is None:
+ self.logger.warning(
+ 'Connection refused by the server, aborting')
+ await self.queue.put(None)
+ break
+ if r.status < 200 or r.status >= 300:
+ self.logger.warning('Unexpected status code %s in server '
+ 'response, aborting', r.status)
+ await self.queue.put(None)
+ break
+ try:
+ p = payload.Payload(encoded_payload=await r.read())
+ except ValueError:
+ self.logger.warning(
+ 'Unexpected packet from server, aborting')
+ await self.queue.put(None)
+ break
+ for pkt in p.packets:
+ await self._receive_packet(pkt)
+
+ self.logger.info('Waiting for write loop task to end')
+ await self.write_loop_task
+ self.logger.info('Waiting for ping loop task to end')
+ if self.ping_loop_event: # pragma: no cover
+ self.ping_loop_event.set()
+ await self.ping_loop_task
+ if self.state == 'connected':
+ await self._trigger_event('disconnect', run_async=False)
+ try:
+ client.connected_clients.remove(self)
+ except ValueError: # pragma: no cover
+ pass
+ self._reset()
+ self.logger.info('Exiting read loop task')
+
+ async def _read_loop_websocket(self):
+ """Read packets from the Engine.IO WebSocket connection."""
+ while self.state == 'connected':
+ p = None
+ try:
+ p = (await self.ws.receive()).data
+ if p is None: # pragma: no cover
+ raise RuntimeError('WebSocket read returned None')
+ except aiohttp.client_exceptions.ServerDisconnectedError:
+ self.logger.info(
+ 'Read loop: WebSocket connection was closed, aborting')
+ await self.queue.put(None)
+ break
+ except Exception as e:
+ self.logger.info(
+ 'Unexpected error "%s", aborting', str(e))
+ await self.queue.put(None)
+ break
+ if isinstance(p, six.text_type): # pragma: no cover
+ p = p.encode('utf-8')
+ pkt = packet.Packet(encoded_packet=p)
+ await self._receive_packet(pkt)
+
+ self.logger.info('Waiting for write loop task to end')
+ await self.write_loop_task
+ self.logger.info('Waiting for ping loop task to end')
+ if self.ping_loop_event: # pragma: no cover
+ self.ping_loop_event.set()
+ await self.ping_loop_task
+ if self.state == 'connected':
+ await self._trigger_event('disconnect', run_async=False)
+ try:
+ client.connected_clients.remove(self)
+ except ValueError: # pragma: no cover
+ pass
+ self._reset()
+ self.logger.info('Exiting read loop task')
+
+ async def _write_loop(self):
+ """This background task sends packages to the server as they are
+ pushed to the send queue.
+ """
+ while self.state == 'connected':
+ # to simplify the timeout handling, use the maximum of the
+ # ping interval and ping timeout as timeout, with an extra 5
+ # seconds grace period
+ timeout = max(self.ping_interval, self.ping_timeout) + 5
+ packets = None
+ try:
+ packets = [await asyncio.wait_for(self.queue.get(), timeout)]
+ except (self.queue.Empty, asyncio.TimeoutError,
+ asyncio.CancelledError):
+ self.logger.error('packet queue is empty, aborting')
+ break
+ if packets == [None]:
+ self.queue.task_done()
+ packets = []
+ else:
+ while True:
+ try:
+ packets.append(self.queue.get_nowait())
+ except self.queue.Empty:
+ break
+ if packets[-1] is None:
+ packets = packets[:-1]
+ self.queue.task_done()
+ break
+ if not packets:
+ # empty packet list returned -> connection closed
+ break
+ if self.current_transport == 'polling':
+ p = payload.Payload(packets=packets)
+ r = await self._send_request(
+ 'POST', self.base_url, body=p.encode(),
+ headers={'Content-Type': 'application/octet-stream'},
+ timeout=self.request_timeout)
+ for pkt in packets:
+ self.queue.task_done()
+ if r is None:
+ self.logger.warning(
+ 'Connection refused by the server, aborting')
+ break
+ if r.status < 200 or r.status >= 300:
+ self.logger.warning('Unexpected status code %s in server '
+ 'response, aborting', r.status)
+ self._reset()
+ break
+ else:
+ # websocket
+ try:
+ for pkt in packets:
+ if pkt.binary:
+ await self.ws.send_bytes(pkt.encode(
+ always_bytes=False))
+ else:
+ await self.ws.send_str(pkt.encode(
+ always_bytes=False))
+ self.queue.task_done()
+ except aiohttp.client_exceptions.ServerDisconnectedError:
+ self.logger.info(
+ 'Write loop: WebSocket connection was closed, '
+ 'aborting')
+ break
+ self.logger.info('Exiting write loop task')
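+ # A minimal usage sketch for the asyncio client implemented above
+ # (illustrative only; it assumes this module is importable as
+ # ``engineio`` and that a server is listening on the given URL):
+ #
+ #     import asyncio
+ #     import engineio
+ #
+ #     eio = engineio.AsyncClient()
+ #
+ #     @eio.on('message')
+ #     async def on_message(data):
+ #         print('received:', data)
+ #
+ #     async def main():
+ #         await eio.connect('http://localhost:5000')
+ #         await eio.send('hello')
+ #         await eio.wait()
+ #
+ #     asyncio.get_event_loop().run_until_complete(main())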
diff --git a/libs/engineio/asyncio_server.py b/libs/engineio/asyncio_server.py
new file mode 100644
index 000000000..d52b556db
--- /dev/null
+++ b/libs/engineio/asyncio_server.py
@@ -0,0 +1,472 @@
+import asyncio
+
+import six
+from six.moves import urllib
+
+from . import exceptions
+from . import packet
+from . import server
+from . import asyncio_socket
+
+
+class AsyncServer(server.Server):
+ """An Engine.IO server for asyncio.
+
+ This class implements a fully compliant Engine.IO web server with support
+ for websocket and long-polling transports, compatible with the asyncio
+ framework on Python 3.5 or newer.
+
+ :param async_mode: The asynchronous model to use. See the Deployment
+ section in the documentation for a description of the
+ available options. Valid async modes are "aiohttp",
+ "sanic", "tornado" and "asgi". If this argument is not
+ given, "aiohttp" is tried first, followed by "sanic",
+ "tornado", and finally "asgi". The first async mode that
+ has all its dependencies installed is the one that is
+ chosen.
+ :param ping_timeout: The time in seconds that the client waits for the
+ server to respond before disconnecting.
+ :param ping_interval: The interval in seconds at which the client pings
+ the server. The default is 25 seconds. For advanced
+ control, a two element tuple can be given, where
+ the first number is the ping interval and the second
+ is a grace period added by the server. The default
+ grace period is 5 seconds.
+ :param max_http_buffer_size: The maximum size of a message when using the
+ polling transport.
+ :param allow_upgrades: Whether to allow transport upgrades or not.
+ :param http_compression: Whether to compress packages when using the
+ polling transport.
+ :param compression_threshold: Only compress messages when their byte size
+ is greater than this value.
+ :param cookie: Name of the HTTP cookie that contains the client session
+ id. If set to ``None``, a cookie is not sent to the client.
+ :param cors_allowed_origins: Origin or list of origins that are allowed to
+ connect to this server. Only the same origin
+ is allowed by default. Set this argument to
+ ``'*'`` to allow all origins, or to ``[]`` to
+ disable CORS handling.
+ :param cors_credentials: Whether credentials (cookies, authentication) are
+ allowed in requests to this server.
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+ :param async_handlers: If set to ``True``, run message event handlers in
+ non-blocking threads. To run handlers synchronously,
+ set to ``False``. The default is ``True``.
+ :param kwargs: Reserved for future extensions, any additional parameters
+ given as keyword arguments will be silently ignored.
+ """
+ def is_asyncio_based(self):
+ return True
+
+ def async_modes(self):
+ return ['aiohttp', 'sanic', 'tornado', 'asgi']
+
+ def attach(self, app, engineio_path='engine.io'):
+ """Attach the Engine.IO server to an application."""
+ engineio_path = engineio_path.strip('/')
+ self._async['create_route'](app, self, '/{}/'.format(engineio_path))
+
+ async def send(self, sid, data, binary=None):
+ """Send a message to a client.
+
+ :param sid: The session id of the recipient client.
+ :param data: The data to send to the client. Data can be of type
+ ``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
+ or ``dict``, the data will be serialized as JSON.
+ :param binary: ``True`` to send packet as binary, ``False`` to send
+ as text. If not given, unicode (Python 2) and str
+ (Python 3) are sent as text, and str (Python 2) and
+ bytes (Python 3) are sent as binary.
+
+ Note: this method is a coroutine.
+ """
+ try:
+ socket = self._get_socket(sid)
+ except KeyError:
+ # the socket is not available
+ self.logger.warning('Cannot send to sid %s', sid)
+ return
+ await socket.send(packet.Packet(packet.MESSAGE, data=data,
+ binary=binary))
+
+ async def get_session(self, sid):
+ """Return the user session for a client.
+
+ :param sid: The session id of the client.
+
+ The return value is a dictionary. Modifications made to this
+ dictionary are not guaranteed to be preserved. If you want to modify
+ the user session, use the ``session`` context manager instead.
+ """
+ socket = self._get_socket(sid)
+ return socket.session
+
+ async def save_session(self, sid, session):
+ """Store the user session for a client.
+
+ :param sid: The session id of the client.
+ :param session: The session dictionary.
+ """
+ socket = self._get_socket(sid)
+ socket.session = session
+
+ def session(self, sid):
+ """Return the user session for a client with context manager syntax.
+
+ :param sid: The session id of the client.
+
+ This is a context manager that returns the user session dictionary for
+ the client. Any changes that are made to this dictionary inside the
+ context manager block are saved back to the session. Example usage::
+
+ @eio.on('connect')
+ async def on_connect(sid, environ):
+ username = authenticate_user(environ)
+ if not username:
+ return False
+ async with eio.session(sid) as session:
+ session['username'] = username
+
+ @eio.on('message')
+ async def on_message(sid, msg):
+ async with eio.session(sid) as session:
+ print('received message from ', session['username'])
+ """
+ class _session_context_manager(object):
+ def __init__(self, server, sid):
+ self.server = server
+ self.sid = sid
+ self.session = None
+
+ async def __aenter__(self):
+ self.session = await self.server.get_session(self.sid)
+ return self.session
+
+ async def __aexit__(self, *args):
+ await self.server.save_session(self.sid, self.session)
+
+ return _session_context_manager(self, sid)
+
+ async def disconnect(self, sid=None):
+ """Disconnect a client.
+
+ :param sid: The session id of the client to close. If this parameter
+ is not given, then all clients are closed.
+
+ Note: this method is a coroutine.
+ """
+ if sid is not None:
+ try:
+ socket = self._get_socket(sid)
+ except KeyError: # pragma: no cover
+ # the socket was already closed or gone
+ pass
+ else:
+ await socket.close()
+ if sid in self.sockets: # pragma: no cover
+ del self.sockets[sid]
+ else:
+ await asyncio.wait([client.close()
+ for client in six.itervalues(self.sockets)])
+ self.sockets = {}
+
+ async def handle_request(self, *args, **kwargs):
+ """Handle an HTTP request from the client.
+
+ This is the entry point of the Engine.IO application. This function
+ returns the HTTP response to deliver to the client.
+
+ Note: this method is a coroutine.
+ """
+ translate_request = self._async['translate_request']
+ if asyncio.iscoroutinefunction(translate_request):
+ environ = await translate_request(*args, **kwargs)
+ else:
+ environ = translate_request(*args, **kwargs)
+
+ if self.cors_allowed_origins != []:
+ # Validate the origin header if present
+ # This is important for WebSocket more than for HTTP, since
+ # browsers enforce CORS controls on HTTP requests but not on
+ # WebSocket connections.
+ origin = environ.get('HTTP_ORIGIN')
+ if origin:
+ allowed_origins = self._cors_allowed_origins(environ)
+ if allowed_origins is not None and origin not in \
+ allowed_origins:
+ self.logger.info(origin + ' is not an accepted origin.')
+ r = self._bad_request()
+ make_response = self._async['make_response']
+ if asyncio.iscoroutinefunction(make_response):
+ response = await make_response(
+ r['status'], r['headers'], r['response'], environ)
+ else:
+ response = make_response(r['status'], r['headers'],
+ r['response'], environ)
+ return response
+
+ method = environ['REQUEST_METHOD']
+ query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
+
+ sid = query['sid'][0] if 'sid' in query else None
+ b64 = False
+ jsonp = False
+ jsonp_index = None
+
+ if 'b64' in query:
+ if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
+ b64 = True
+ if 'j' in query:
+ jsonp = True
+ try:
+ jsonp_index = int(query['j'][0])
+ except (ValueError, KeyError, IndexError):
+ # Invalid JSONP index number
+ pass
+
+ if jsonp and jsonp_index is None:
+ self.logger.warning('Invalid JSONP index number')
+ r = self._bad_request()
+ elif method == 'GET':
+ if sid is None:
+ transport = query.get('transport', ['polling'])[0]
+ if transport != 'polling' and transport != 'websocket':
+ self.logger.warning('Invalid transport %s', transport)
+ r = self._bad_request()
+ else:
+ r = await self._handle_connect(environ, transport,
+ b64, jsonp_index)
+ else:
+ if sid not in self.sockets:
+ self.logger.warning('Invalid session %s', sid)
+ r = self._bad_request()
+ else:
+ socket = self._get_socket(sid)
+ try:
+ packets = await socket.handle_get_request(environ)
+ if isinstance(packets, list):
+ r = self._ok(packets, b64=b64,
+ jsonp_index=jsonp_index)
+ else:
+ r = packets
+ except exceptions.EngineIOError:
+ if sid in self.sockets: # pragma: no cover
+ await self.disconnect(sid)
+ r = self._bad_request()
+ if sid in self.sockets and self.sockets[sid].closed:
+ del self.sockets[sid]
+ elif method == 'POST':
+ if sid is None or sid not in self.sockets:
+ self.logger.warning('Invalid session %s', sid)
+ r = self._bad_request()
+ else:
+ socket = self._get_socket(sid)
+ try:
+ await socket.handle_post_request(environ)
+ r = self._ok(jsonp_index=jsonp_index)
+ except exceptions.EngineIOError:
+ if sid in self.sockets: # pragma: no cover
+ await self.disconnect(sid)
+ r = self._bad_request()
+ except: # pragma: no cover
+ # for any other unexpected errors, we log the error
+ # and keep going
+ self.logger.exception('post request handler error')
+ r = self._ok(jsonp_index=jsonp_index)
+ elif method == 'OPTIONS':
+ r = self._ok()
+ else:
+ self.logger.warning('Method %s not supported', method)
+ r = self._method_not_found()
+ if not isinstance(r, dict):
+ return r
+ if self.http_compression and \
+ len(r['response']) >= self.compression_threshold:
+ encodings = [e.split(';')[0].strip() for e in
+ environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
+ for encoding in encodings:
+ if encoding in self.compression_methods:
+ r['response'] = \
+ getattr(self, '_' + encoding)(r['response'])
+ r['headers'] += [('Content-Encoding', encoding)]
+ break
+ cors_headers = self._cors_headers(environ)
+ make_response = self._async['make_response']
+ if asyncio.iscoroutinefunction(make_response):
+ response = await make_response(r['status'],
+ r['headers'] + cors_headers,
+ r['response'], environ)
+ else:
+ response = make_response(r['status'], r['headers'] + cors_headers,
+ r['response'], environ)
+ return response
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task using the appropriate async model.
+
+ This is a utility function that applications can use to start a
+ background task using the method that is compatible with the
+ selected async mode.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ The return value is an ``asyncio.Task`` object.
+ """
+ return asyncio.ensure_future(target(*args, **kwargs))
+
+ async def sleep(self, seconds=0):
+ """Sleep for the requested amount of time using the appropriate async
+ model.
+
+ This is a utility function that applications can use to put a task to
+ sleep without having to worry about using the correct call for the
+ selected async mode.
+
+ Note: this method is a coroutine.
+ """
+ return await asyncio.sleep(seconds)
+
+ def create_queue(self, *args, **kwargs):
+ """Create a queue object using the appropriate async model.
+
+ This is a utility function that applications can use to create a queue
+ without having to worry about using the correct call for the selected
+ async mode. For asyncio based async modes, this returns an instance of
+ ``asyncio.Queue``.
+ """
+ return asyncio.Queue(*args, **kwargs)
+
+ def get_queue_empty_exception(self):
+ """Return the queue empty exception for the appropriate async model.
+
+ This is a utility function that applications can use to work with a
+ queue without having to worry about using the correct call for the
+ selected async mode. For asyncio based async modes, this returns the
+ ``asyncio.QueueEmpty`` exception class.
+ """
+ return asyncio.QueueEmpty
+
+ def create_event(self, *args, **kwargs):
+ """Create an event object using the appropriate async model.
+
+ This is a utility function that applications can use to create an
+ event without having to worry about using the correct call for the
+ selected async mode. For asyncio based async modes, this returns
+ an instance of ``asyncio.Event``.
+ """
+ return asyncio.Event(*args, **kwargs)
+
+ async def _handle_connect(self, environ, transport, b64=False,
+ jsonp_index=None):
+ """Handle a client connection request."""
+ if self.start_service_task:
+ # start the service task to monitor connected clients
+ self.start_service_task = False
+ self.start_background_task(self._service_task)
+
+ sid = self._generate_id()
+ s = asyncio_socket.AsyncSocket(self, sid)
+ self.sockets[sid] = s
+
+ pkt = packet.Packet(
+ packet.OPEN, {'sid': sid,
+ 'upgrades': self._upgrades(sid, transport),
+ 'pingTimeout': int(self.ping_timeout * 1000),
+ 'pingInterval': int(self.ping_interval * 1000)})
+ await s.send(pkt)
+
+ ret = await self._trigger_event('connect', sid, environ,
+ run_async=False)
+ if ret is False:
+ del self.sockets[sid]
+ self.logger.warning('Application rejected connection')
+ return self._unauthorized()
+
+ if transport == 'websocket':
+ ret = await s.handle_get_request(environ)
+ if s.closed:
+ # websocket connection ended, so we are done
+ del self.sockets[sid]
+ return ret
+ else:
+ s.connected = True
+ headers = None
+ if self.cookie:
+ headers = [('Set-Cookie', self.cookie + '=' + sid)]
+ try:
+ return self._ok(await s.poll(), headers=headers, b64=b64,
+ jsonp_index=jsonp_index)
+ except exceptions.QueueEmpty:
+ return self._bad_request()
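+ # For reference, the OPEN packet sent by _handle_connect above carries a
+ # JSON body along these lines (the sid is a hypothetical value; the
+ # timings shown assume a 60s ping timeout and 25s ping interval):
+ #
+ #     {"sid": "abc123", "upgrades": ["websocket"],
+ #      "pingTimeout": 60000, "pingInterval": 25000}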
+
+ async def _trigger_event(self, event, *args, **kwargs):
+ """Invoke an event handler."""
+ run_async = kwargs.pop('run_async', False)
+ ret = None
+ if event in self.handlers:
+ if asyncio.iscoroutinefunction(self.handlers[event]) is True:
+ if run_async:
+ return self.start_background_task(self.handlers[event],
+ *args)
+ else:
+ try:
+ ret = await self.handlers[event](*args)
+ except asyncio.CancelledError: # pragma: no cover
+ pass
+ except:
+ self.logger.exception(event + ' async handler error')
+ if event == 'connect':
+ # if connect handler raised error we reject the
+ # connection
+ return False
+ else:
+ if run_async:
+ async def async_handler():
+ return self.handlers[event](*args)
+
+ return self.start_background_task(async_handler)
+ else:
+ try:
+ ret = self.handlers[event](*args)
+ except:
+ self.logger.exception(event + ' handler error')
+ if event == 'connect':
+ # if connect handler raised error we reject the
+ # connection
+ return False
+ return ret
+
+ async def _service_task(self): # pragma: no cover
+ """Monitor connected clients and clean up those that time out."""
+ while True:
+ if len(self.sockets) == 0:
+ # nothing to do
+ await self.sleep(self.ping_timeout)
+ continue
+
+ # go through the entire client list in a ping timeout cycle
+ sleep_interval = self.ping_timeout / len(self.sockets)
+
+ try:
+ # iterate over the current clients
+ for socket in self.sockets.copy().values():
+ if not socket.closing and not socket.closed:
+ await socket.check_ping_timeout()
+ await self.sleep(sleep_interval)
+ except (SystemExit, KeyboardInterrupt, asyncio.CancelledError):
+ self.logger.info('service task canceled')
+ break
+ except:
+ if asyncio.get_event_loop().is_closed():
+ self.logger.info('event loop is closed, exiting service '
+ 'task')
+ break
+
+ # an unexpected exception has occurred, log it and continue
+ self.logger.exception('service task exception')
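+ # A minimal usage sketch for this server (illustrative only; it assumes
+ # the aiohttp async mode, that this module is importable as ``engineio``,
+ # and the ``on`` registration method inherited from the base Server):
+ #
+ #     from aiohttp import web
+ #     import engineio
+ #
+ #     eio = engineio.AsyncServer(async_mode='aiohttp')
+ #     app = web.Application()
+ #     eio.attach(app)
+ #
+ #     @eio.on('message')
+ #     async def message(sid, data):
+ #         await eio.send(sid, 'echo: ' + str(data))
+ #
+ #     web.run_app(app)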
diff --git a/libs/engineio/asyncio_socket.py b/libs/engineio/asyncio_socket.py
new file mode 100644
index 000000000..7057a6cc3
--- /dev/null
+++ b/libs/engineio/asyncio_socket.py
@@ -0,0 +1,236 @@
+import asyncio
+import six
+import sys
+import time
+
+from . import exceptions
+from . import packet
+from . import payload
+from . import socket
+
+
+class AsyncSocket(socket.Socket):
+ async def poll(self):
+ """Wait for packets to send to the client."""
+ try:
+ packets = [await asyncio.wait_for(self.queue.get(),
+ self.server.ping_timeout)]
+ self.queue.task_done()
+ except (asyncio.TimeoutError, asyncio.CancelledError):
+ raise exceptions.QueueEmpty()
+ if packets == [None]:
+ return []
+ try:
+ packets.append(self.queue.get_nowait())
+ self.queue.task_done()
+ except asyncio.QueueEmpty:
+ pass
+ return packets
+
+ async def receive(self, pkt):
+ """Receive packet from the client."""
+ self.server.logger.info('%s: Received packet %s data %s',
+ self.sid, packet.packet_names[pkt.packet_type],
+ pkt.data if not isinstance(pkt.data, bytes)
+ else '<binary>')
+ if pkt.packet_type == packet.PING:
+ self.last_ping = time.time()
+ await self.send(packet.Packet(packet.PONG, pkt.data))
+ elif pkt.packet_type == packet.MESSAGE:
+ await self.server._trigger_event(
+ 'message', self.sid, pkt.data,
+ run_async=self.server.async_handlers)
+ elif pkt.packet_type == packet.UPGRADE:
+ await self.send(packet.Packet(packet.NOOP))
+ elif pkt.packet_type == packet.CLOSE:
+ await self.close(wait=False, abort=True)
+ else:
+ raise exceptions.UnknownPacketError()
+
+ async def check_ping_timeout(self):
+ """Make sure the client is still sending pings.
+
+ This helps detect disconnections for long-polling clients.
+ """
+ if self.closed:
+ raise exceptions.SocketIsClosedError()
+ if time.time() - self.last_ping > self.server.ping_interval + \
+ self.server.ping_interval_grace_period:
+ self.server.logger.info('%s: Client is gone, closing socket',
+ self.sid)
+ # Passing abort=False here will cause close() to write a
+ # CLOSE packet. This has the effect of updating half-open sockets
+ # to their correct state of disconnected
+ await self.close(wait=False, abort=False)
+ return False
+ return True
+
+ async def send(self, pkt):
+ """Send a packet to the client."""
+ if not await self.check_ping_timeout():
+ return
+ if self.upgrading:
+ self.packet_backlog.append(pkt)
+ else:
+ await self.queue.put(pkt)
+ self.server.logger.info('%s: Sending packet %s data %s',
+ self.sid, packet.packet_names[pkt.packet_type],
+ pkt.data if not isinstance(pkt.data, bytes)
+ else '<binary>')
+
+ async def handle_get_request(self, environ):
+ """Handle a long-polling GET request from the client."""
+ connections = [
+ s.strip()
+ for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
+ transport = environ.get('HTTP_UPGRADE', '').lower()
+ if 'upgrade' in connections and transport in self.upgrade_protocols:
+ self.server.logger.info('%s: Received request to upgrade to %s',
+ self.sid, transport)
+ return await getattr(self, '_upgrade_' + transport)(environ)
+ try:
+ packets = await self.poll()
+ except exceptions.QueueEmpty:
+ exc = sys.exc_info()
+ await self.close(wait=False)
+ six.reraise(*exc)
+ return packets
+
+ async def handle_post_request(self, environ):
+ """Handle a long-polling POST request from the client."""
+ length = int(environ.get('CONTENT_LENGTH', '0'))
+ if length > self.server.max_http_buffer_size:
+ raise exceptions.ContentTooLongError()
+ else:
+ body = await environ['wsgi.input'].read(length)
+ p = payload.Payload(encoded_payload=body)
+ for pkt in p.packets:
+ await self.receive(pkt)
+
+ async def close(self, wait=True, abort=False):
+ """Close the socket connection."""
+ if not self.closed and not self.closing:
+ self.closing = True
+ await self.server._trigger_event('disconnect', self.sid)
+ if not abort:
+ await self.send(packet.Packet(packet.CLOSE))
+ self.closed = True
+ if wait:
+ await self.queue.join()
+
+ async def _upgrade_websocket(self, environ):
+ """Upgrade the connection from polling to websocket."""
+ if self.upgraded:
+ raise IOError('Socket has been upgraded already')
+ if self.server._async['websocket'] is None:
+ # the selected async mode does not support websocket
+ return self.server._bad_request()
+ ws = self.server._async['websocket'](self._websocket_handler)
+ return await ws(environ)
+
+ async def _websocket_handler(self, ws):
+ """Engine.IO handler for websocket transport."""
+ if self.connected:
+ # the socket was already connected, so this is an upgrade
+ self.upgrading = True # hold packet sends during the upgrade
+
+ try:
+ pkt = await ws.wait()
+ except IOError: # pragma: no cover
+ return
+ decoded_pkt = packet.Packet(encoded_packet=pkt)
+ if decoded_pkt.packet_type != packet.PING or \
+ decoded_pkt.data != 'probe':
+ self.server.logger.info(
+ '%s: Failed websocket upgrade, no PING packet', self.sid)
+ return
+ await ws.send(packet.Packet(
+ packet.PONG,
+ data=six.text_type('probe')).encode(always_bytes=False))
+ await self.queue.put(packet.Packet(packet.NOOP)) # end poll
+
+ try:
+ pkt = await ws.wait()
+ except IOError: # pragma: no cover
+ return
+ decoded_pkt = packet.Packet(encoded_packet=pkt)
+ if decoded_pkt.packet_type != packet.UPGRADE:
+ self.upgraded = False
+ self.server.logger.info(
+ ('%s: Failed websocket upgrade, expected UPGRADE packet, '
+ 'received %s instead.'),
+ self.sid, pkt)
+ return
+ self.upgraded = True
+
+ # flush any packets that were sent during the upgrade
+ for pkt in self.packet_backlog:
+ await self.queue.put(pkt)
+ self.packet_backlog = []
+ self.upgrading = False
+ else:
+ self.connected = True
+ self.upgraded = True
+
+ # start separate writer thread
+ async def writer():
+ while True:
+ packets = None
+ try:
+ packets = await self.poll()
+ except exceptions.QueueEmpty:
+ break
+ if not packets:
+ # empty packet list returned -> connection closed
+ break
+ try:
+ for pkt in packets:
+ await ws.send(pkt.encode(always_bytes=False))
+ except:
+ break
+ writer_task = asyncio.ensure_future(writer())
+
+ self.server.logger.info(
+ '%s: Upgrade to websocket successful', self.sid)
+
+ while True:
+ p = None
+ wait_task = asyncio.ensure_future(ws.wait())
+ try:
+ p = await asyncio.wait_for(wait_task, self.server.ping_timeout)
+ except asyncio.CancelledError: # pragma: no cover
+ # there is a bug (https://bugs.python.org/issue30508) in
+ # asyncio that causes a "Task exception never retrieved" error
+ # to appear when wait_task raises an exception before it gets
+ # cancelled. Calling wait_task.exception() prevents the error
+ # from being issued in Python 3.6, but causes other errors in
+ # other versions, so we run it with all errors suppressed and
+ # hope for the best.
+ try:
+ wait_task.exception()
+ except:
+ pass
+ break
+ except:
+ break
+ if p is None:
+ # connection closed by client
+ break
+ if isinstance(p, six.text_type): # pragma: no cover
+ p = p.encode('utf-8')
+ pkt = packet.Packet(encoded_packet=p)
+ try:
+ await self.receive(pkt)
+ except exceptions.UnknownPacketError: # pragma: no cover
+ pass
+ except exceptions.SocketIsClosedError: # pragma: no cover
+ self.server.logger.info('Receive error -- socket is closed')
+ break
+ except: # pragma: no cover
+ # if we get an unexpected exception we log the error and exit
+ # the connection properly
+ self.server.logger.exception('Unknown receive error')
+
+ await self.queue.put(None) # unlock the writer task so it can exit
+ await asyncio.wait_for(writer_task, timeout=None)
+ await self.close(wait=False, abort=True)
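+ # Schematic trace of the upgrade handshake implemented by
+ # _websocket_handler above, using the wire encoding from packet.py
+ # (PING=2, PONG=3, UPGRADE=5, NOOP=6):
+ #
+ #     client -> server   2probe   PING probe over the new websocket
+ #     server -> client   3probe   PONG reply to the probe
+ #     server -> poll     6        NOOP, ends the pending long-poll cycle
+ #     client -> server   5        UPGRADE, websocket becomes the transport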
diff --git a/libs/engineio/client.py b/libs/engineio/client.py
new file mode 100644
index 000000000..b5ab50377
--- /dev/null
+++ b/libs/engineio/client.py
@@ -0,0 +1,680 @@
+import logging
+try:
+ import queue
+except ImportError: # pragma: no cover
+ import Queue as queue
+import signal
+import ssl
+import threading
+import time
+
+import six
+from six.moves import urllib
+try:
+ import requests
+except ImportError: # pragma: no cover
+ requests = None
+try:
+ import websocket
+except ImportError: # pragma: no cover
+ websocket = None
+from . import exceptions
+from . import packet
+from . import payload
+
+default_logger = logging.getLogger('engineio.client')
+connected_clients = []
+
+if six.PY2: # pragma: no cover
+ ConnectionError = OSError
+
+
+def signal_handler(sig, frame):
+ """SIGINT handler.
+
+ Disconnect all active clients and then invoke the original signal handler.
+ """
+ for client in connected_clients[:]:
+ if client.is_asyncio_based():
+ client.start_background_task(client.disconnect, abort=True)
+ else:
+ client.disconnect(abort=True)
+ if callable(original_signal_handler):
+ return original_signal_handler(sig, frame)
+ else: # pragma: no cover
+ # Handle case where no original SIGINT handler was present.
+ return signal.default_int_handler(sig, frame)
+
+
+original_signal_handler = None
+
+
+class Client(object):
+ """An Engine.IO client.
+
+ This class implements a fully compliant Engine.IO web client with support
+ for websocket and long-polling transports.
+
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``. The default is
+ ``False``.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+ :param request_timeout: A timeout in seconds for requests. The default is
+ 5 seconds.
+ :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
+ skip SSL certificate verification, allowing
+ connections to servers with self signed certificates.
+ The default is ``True``.
+ """
+ event_names = ['connect', 'disconnect', 'message']
+
+ def __init__(self,
+ logger=False,
+ json=None,
+ request_timeout=5,
+ ssl_verify=True):
+ global original_signal_handler
+ if original_signal_handler is None:
+ original_signal_handler = signal.signal(signal.SIGINT,
+ signal_handler)
+ self.handlers = {}
+ self.base_url = None
+ self.transports = None
+ self.current_transport = None
+ self.sid = None
+ self.upgrades = None
+ self.ping_interval = None
+ self.ping_timeout = None
+ self.pong_received = True
+ self.http = None
+ self.ws = None
+ self.read_loop_task = None
+ self.write_loop_task = None
+ self.ping_loop_task = None
+ self.ping_loop_event = None
+ self.queue = None
+ self.state = 'disconnected'
+ self.ssl_verify = ssl_verify
+
+ if json is not None:
+ packet.Packet.json = json
+ if not isinstance(logger, bool):
+ self.logger = logger
+ else:
+ self.logger = default_logger
+ if not logging.root.handlers and \
+ self.logger.level == logging.NOTSET:
+ if logger:
+ self.logger.setLevel(logging.INFO)
+ else:
+ self.logger.setLevel(logging.ERROR)
+ self.logger.addHandler(logging.StreamHandler())
+
+ self.request_timeout = request_timeout
+
+ def is_asyncio_based(self):
+ return False
+
+ def on(self, event, handler=None):
+ """Register an event handler.
+
+ :param event: The event name. Can be ``'connect'``, ``'message'`` or
+ ``'disconnect'``.
+ :param handler: The function that should be invoked to handle the
+ event. When this parameter is not given, the method
+ acts as a decorator for the handler function.
+
+ Example usage::
+
+ # as a decorator:
+ @eio.on('connect')
+ def connect_handler():
+ print('Connection request')
+
+ # as a method:
+ def message_handler(msg):
+ print('Received message: ', msg)
+ eio.send('response')
+ eio.on('message', message_handler)
+ """
+ if event not in self.event_names:
+ raise ValueError('Invalid event')
+
+ def set_handler(handler):
+ self.handlers[event] = handler
+ return handler
+
+ if handler is None:
+ return set_handler
+ set_handler(handler)
+
+ def connect(self, url, headers={}, transports=None,
+ engineio_path='engine.io'):
+ """Connect to an Engine.IO server.
+
+ :param url: The URL of the Engine.IO server. It can include custom
+ query string parameters if required by the server.
+ :param headers: A dictionary with custom headers to send with the
+ connection request.
+ :param transports: The list of allowed transports. Valid transports
+ are ``'polling'`` and ``'websocket'``. If not
+ given, the polling transport is connected first,
+ then an upgrade to websocket is attempted.
+ :param engineio_path: The endpoint where the Engine.IO server is
+ installed. The default value is appropriate for
+ most cases.
+
+ Example usage::
+
+ eio = engineio.Client()
+ eio.connect('http://localhost:5000')
+ """
+ if self.state != 'disconnected':
+ raise ValueError('Client is not in a disconnected state')
+ valid_transports = ['polling', 'websocket']
+ if transports is not None:
+ if isinstance(transports, six.string_types):
+ transports = [transports]
+ transports = [transport for transport in transports
+ if transport in valid_transports]
+ if not transports:
+ raise ValueError('No valid transports provided')
+ self.transports = transports or valid_transports
+ self.queue = self.create_queue()
+ return getattr(self, '_connect_' + self.transports[0])(
+ url, headers, engineio_path)
+
+ def wait(self):
+ """Wait until the connection with the server ends.
+
+ Client applications can use this function to block the main thread
+ during the life of the connection.
+ """
+ if self.read_loop_task:
+ self.read_loop_task.join()
+
+ def send(self, data, binary=None):
+ """Send a message to a client.
+
+ :param data: The data to send to the client. Data can be of type
+ ``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
+ or ``dict``, the data will be serialized as JSON.
+ :param binary: ``True`` to send packet as binary, ``False`` to send
+ as text. If not given, unicode (Python 2) and str
+ (Python 3) are sent as text, and str (Python 2) and
+ bytes (Python 3) are sent as binary.
+ """
+ self._send_packet(packet.Packet(packet.MESSAGE, data=data,
+ binary=binary))
+
+ def disconnect(self, abort=False):
+ """Disconnect from the server.
+
+ :param abort: If set to ``True``, do not wait for background tasks
+ associated with the connection to end.
+ """
+ if self.state == 'connected':
+ self._send_packet(packet.Packet(packet.CLOSE))
+ self.queue.put(None)
+ self.state = 'disconnecting'
+ self._trigger_event('disconnect', run_async=False)
+ if self.current_transport == 'websocket':
+ self.ws.close()
+ if not abort:
+ self.read_loop_task.join()
+ self.state = 'disconnected'
+ try:
+ connected_clients.remove(self)
+ except ValueError: # pragma: no cover
+ pass
+ self._reset()
+
+ def transport(self):
+ """Return the name of the transport currently in use.
+
+ The possible values returned by this function are ``'polling'`` and
+ ``'websocket'``.
+ """
+ return self.current_transport
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task.
+
+ This is a utility function that applications can use to start a
+ background task.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ This function returns an object compatible with the `Thread` class in
+ the Python standard library. The `start()` method on this object is
+ already called by this function.
+ """
+ th = threading.Thread(target=target, args=args, kwargs=kwargs)
+ th.start()
+ return th
+
+ def sleep(self, seconds=0):
+ """Sleep for the requested amount of time."""
+ return time.sleep(seconds)
+
+ def create_queue(self, *args, **kwargs):
+ """Create a queue object."""
+ q = queue.Queue(*args, **kwargs)
+ q.Empty = queue.Empty
+ return q
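+ # note: the Empty exception class is attached to the returned queue so
+ # that the write loop can catch self.queue.Empty regardless of which
+ # queue implementation created the instance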
+
+ def create_event(self, *args, **kwargs):
+ """Create an event object."""
+ return threading.Event(*args, **kwargs)
+
+ def _reset(self):
+ self.state = 'disconnected'
+ self.sid = None
+
+ def _connect_polling(self, url, headers, engineio_path):
+ """Establish a long-polling connection to the Engine.IO server."""
+ if requests is None: # pragma: no cover
+ # not installed
+ self.logger.error('requests package is not installed -- cannot '
+ 'send HTTP requests!')
+ return
+ self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
+ self.logger.info('Attempting polling connection to ' + self.base_url)
+ r = self._send_request(
+ 'GET', self.base_url + self._get_url_timestamp(), headers=headers,
+ timeout=self.request_timeout)
+ if r is None:
+ self._reset()
+ raise exceptions.ConnectionError(
+ 'Connection refused by the server')
+ if r.status_code < 200 or r.status_code >= 300:
+ raise exceptions.ConnectionError(
+ 'Unexpected status code {} in server response'.format(
+ r.status_code))
+ try:
+ p = payload.Payload(encoded_payload=r.content)
+ except ValueError:
+ six.raise_from(exceptions.ConnectionError(
+ 'Unexpected response from server'), None)
+ open_packet = p.packets[0]
+ if open_packet.packet_type != packet.OPEN:
+ raise exceptions.ConnectionError(
+ 'OPEN packet not returned by server')
+ self.logger.info(
+ 'Polling connection accepted with ' + str(open_packet.data))
+ self.sid = open_packet.data['sid']
+ self.upgrades = open_packet.data['upgrades']
+ self.ping_interval = open_packet.data['pingInterval'] / 1000.0
+ self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
+ self.current_transport = 'polling'
+ self.base_url += '&sid=' + self.sid
+
+ self.state = 'connected'
+ connected_clients.append(self)
+ self._trigger_event('connect', run_async=False)
+
+ for pkt in p.packets[1:]:
+ self._receive_packet(pkt)
+
+ if 'websocket' in self.upgrades and 'websocket' in self.transports:
+ # attempt to upgrade to websocket
+ if self._connect_websocket(url, headers, engineio_path):
+ # upgrade to websocket succeeded, we're done here
+ return
+
+ # start background tasks associated with this client
+ self.ping_loop_task = self.start_background_task(self._ping_loop)
+ self.write_loop_task = self.start_background_task(self._write_loop)
+ self.read_loop_task = self.start_background_task(
+ self._read_loop_polling)
+
+ def _connect_websocket(self, url, headers, engineio_path):
+ """Establish or upgrade to a WebSocket connection with the server."""
+ if websocket is None: # pragma: no cover
+ # not installed
+ self.logger.warning('websocket-client package not installed, only '
+ 'polling transport is available')
+ return False
+ websocket_url = self._get_engineio_url(url, engineio_path, 'websocket')
+ if self.sid:
+ self.logger.info(
+ 'Attempting WebSocket upgrade to ' + websocket_url)
+ upgrade = True
+ websocket_url += '&sid=' + self.sid
+ else:
+ upgrade = False
+ self.base_url = websocket_url
+ self.logger.info(
+ 'Attempting WebSocket connection to ' + websocket_url)
+
+ # get the cookies from the long-polling connection so that they can
+ # also be sent to the WebSocket route
+ cookies = None
+ if self.http:
+ cookies = '; '.join(["{}={}".format(cookie.name, cookie.value)
+ for cookie in self.http.cookies])
+
+ try:
+ if not self.ssl_verify:
+ ws = websocket.create_connection(
+ websocket_url + self._get_url_timestamp(), header=headers,
+ cookie=cookies, sslopt={"cert_reqs": ssl.CERT_NONE})
+ else:
+ ws = websocket.create_connection(
+ websocket_url + self._get_url_timestamp(), header=headers,
+ cookie=cookies)
+ except (ConnectionError, IOError, websocket.WebSocketException):
+ if upgrade:
+ self.logger.warning(
+ 'WebSocket upgrade failed: connection error')
+ return False
+ else:
+ raise exceptions.ConnectionError('Connection error')
+ if upgrade:
+ p = packet.Packet(packet.PING,
+ data=six.text_type('probe')).encode()
+ try:
+ ws.send(p)
+ except Exception as e: # pragma: no cover
+ self.logger.warning(
+ 'WebSocket upgrade failed: unexpected send exception: %s',
+ str(e))
+ return False
+ try:
+ p = ws.recv()
+ except Exception as e: # pragma: no cover
+ self.logger.warning(
+ 'WebSocket upgrade failed: unexpected recv exception: %s',
+ str(e))
+ return False
+ pkt = packet.Packet(encoded_packet=p)
+ if pkt.packet_type != packet.PONG or pkt.data != 'probe':
+ self.logger.warning(
+ 'WebSocket upgrade failed: no PONG packet')
+ return False
+ p = packet.Packet(packet.UPGRADE).encode()
+ try:
+ ws.send(p)
+ except Exception as e: # pragma: no cover
+ self.logger.warning(
+ 'WebSocket upgrade failed: unexpected send exception: %s',
+ str(e))
+ return False
+ self.current_transport = 'websocket'
+ self.logger.info('WebSocket upgrade was successful')
+ else:
+ try:
+ p = ws.recv()
+ except Exception as e: # pragma: no cover
+ raise exceptions.ConnectionError(
+ 'Unexpected recv exception: ' + str(e))
+ open_packet = packet.Packet(encoded_packet=p)
+ if open_packet.packet_type != packet.OPEN:
+ raise exceptions.ConnectionError('no OPEN packet')
+ self.logger.info(
+ 'WebSocket connection accepted with ' + str(open_packet.data))
+ self.sid = open_packet.data['sid']
+ self.upgrades = open_packet.data['upgrades']
+ self.ping_interval = open_packet.data['pingInterval'] / 1000.0
+ self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
+ self.current_transport = 'websocket'
+
+ self.state = 'connected'
+ connected_clients.append(self)
+ self._trigger_event('connect', run_async=False)
+ self.ws = ws
+
+ # start background tasks associated with this client
+ self.ping_loop_task = self.start_background_task(self._ping_loop)
+ self.write_loop_task = self.start_background_task(self._write_loop)
+ self.read_loop_task = self.start_background_task(
+ self._read_loop_websocket)
+ return True
+
+ def _receive_packet(self, pkt):
+ """Handle incoming packets from the server."""
+ packet_name = packet.packet_names[pkt.packet_type] \
+ if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
+ self.logger.info(
+ 'Received packet %s data %s', packet_name,
+ pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
+ if pkt.packet_type == packet.MESSAGE:
+ self._trigger_event('message', pkt.data, run_async=True)
+ elif pkt.packet_type == packet.PONG:
+ self.pong_received = True
+ elif pkt.packet_type == packet.CLOSE:
+ self.disconnect(abort=True)
+ elif pkt.packet_type == packet.NOOP:
+ pass
+ else:
+ self.logger.error('Received unexpected packet of type %s',
+ pkt.packet_type)
+
+ def _send_packet(self, pkt):
+ """Queue a packet to be sent to the server."""
+ if self.state != 'connected':
+ return
+ self.queue.put(pkt)
+ self.logger.info(
+ 'Sending packet %s data %s',
+ packet.packet_names[pkt.packet_type],
+ pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
+
+ def _send_request(
+ self, method, url, headers=None, body=None,
+ timeout=None): # pragma: no cover
+ if self.http is None:
+ self.http = requests.Session()
+ try:
+ return self.http.request(method, url, headers=headers, data=body,
+ timeout=timeout, verify=self.ssl_verify)
+ except requests.exceptions.RequestException as exc:
+ self.logger.info('HTTP %s request to %s failed with error %s.',
+ method, url, exc)
+
+ def _trigger_event(self, event, *args, **kwargs):
+ """Invoke an event handler."""
+ run_async = kwargs.pop('run_async', False)
+ if event in self.handlers:
+ if run_async:
+ return self.start_background_task(self.handlers[event], *args)
+ else:
+ try:
+ return self.handlers[event](*args)
+ except:
+ self.logger.exception(event + ' handler error')
+
+ def _get_engineio_url(self, url, engineio_path, transport):
+ """Generate the Engine.IO connection URL."""
+ engineio_path = engineio_path.strip('/')
+ parsed_url = urllib.parse.urlparse(url)
+
+ if transport == 'polling':
+ scheme = 'http'
+ elif transport == 'websocket':
+ scheme = 'ws'
+ else: # pragma: no cover
+ raise ValueError('invalid transport')
+ if parsed_url.scheme in ['https', 'wss']:
+ scheme += 's'
+
+ return ('{scheme}://{netloc}/{path}/?{query}'
+ '{sep}transport={transport}&EIO=3').format(
+ scheme=scheme, netloc=parsed_url.netloc,
+ path=engineio_path, query=parsed_url.query,
+ sep='&' if parsed_url.query else '',
+ transport=transport)
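+ # Worked example of the URL construction above:
+ # _get_engineio_url('http://localhost:5000', 'engine.io', 'polling')
+ # returns 'http://localhost:5000/engine.io/?transport=polling&EIO=3',
+ # and the same call with the 'websocket' transport returns
+ # 'ws://localhost:5000/engine.io/?transport=websocket&EIO=3'.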
+
+ def _get_url_timestamp(self):
+ """Generate the Engine.IO query string timestamp."""
+ return '&t=' + str(time.time())
+
+ def _ping_loop(self):
+ """This background task sends a PING to the server at the requested
+ interval.
+ """
+ self.pong_received = True
+ if self.ping_loop_event is None:
+ self.ping_loop_event = self.create_event()
+ else:
+ self.ping_loop_event.clear()
+ while self.state == 'connected':
+ if not self.pong_received:
+ self.logger.info(
+ 'PONG response has not been received, aborting')
+ if self.ws:
+ self.ws.close(timeout=0)
+ self.queue.put(None)
+ break
+ self.pong_received = False
+ self._send_packet(packet.Packet(packet.PING))
+ self.ping_loop_event.wait(timeout=self.ping_interval)
+ self.logger.info('Exiting ping task')
+
+ def _read_loop_polling(self):
+ """Read packets by polling the Engine.IO server."""
+ while self.state == 'connected':
+ self.logger.info(
+ 'Sending polling GET request to ' + self.base_url)
+ r = self._send_request(
+ 'GET', self.base_url + self._get_url_timestamp(),
+ timeout=max(self.ping_interval, self.ping_timeout) + 5)
+ if r is None:
+ self.logger.warning(
+ 'Connection refused by the server, aborting')
+ self.queue.put(None)
+ break
+ if r.status_code < 200 or r.status_code >= 300:
+ self.logger.warning('Unexpected status code %s in server '
+ 'response, aborting', r.status_code)
+ self.queue.put(None)
+ break
+ try:
+ p = payload.Payload(encoded_payload=r.content)
+ except ValueError:
+ self.logger.warning(
+ 'Unexpected packet from server, aborting')
+ self.queue.put(None)
+ break
+ for pkt in p.packets:
+ self._receive_packet(pkt)
+
+ self.logger.info('Waiting for write loop task to end')
+ self.write_loop_task.join()
+ self.logger.info('Waiting for ping loop task to end')
+ if self.ping_loop_event: # pragma: no cover
+ self.ping_loop_event.set()
+ self.ping_loop_task.join()
+ if self.state == 'connected':
+ self._trigger_event('disconnect', run_async=False)
+ try:
+ connected_clients.remove(self)
+ except ValueError: # pragma: no cover
+ pass
+ self._reset()
+ self.logger.info('Exiting read loop task')
+
+ def _read_loop_websocket(self):
+ """Read packets from the Engine.IO WebSocket connection."""
+ while self.state == 'connected':
+ p = None
+ try:
+ p = self.ws.recv()
+ except websocket.WebSocketConnectionClosedException:
+ self.logger.warning(
+ 'WebSocket connection was closed, aborting')
+ self.queue.put(None)
+ break
+ except Exception as e:
+ self.logger.info(
+ 'Unexpected error "%s", aborting', str(e))
+ self.queue.put(None)
+ break
+ if isinstance(p, six.text_type): # pragma: no cover
+ p = p.encode('utf-8')
+ pkt = packet.Packet(encoded_packet=p)
+ self._receive_packet(pkt)
+
+ self.logger.info('Waiting for write loop task to end')
+ self.write_loop_task.join()
+ self.logger.info('Waiting for ping loop task to end')
+ if self.ping_loop_event: # pragma: no cover
+ self.ping_loop_event.set()
+ self.ping_loop_task.join()
+ if self.state == 'connected':
+ self._trigger_event('disconnect', run_async=False)
+ try:
+ connected_clients.remove(self)
+ except ValueError: # pragma: no cover
+ pass
+ self._reset()
+ self.logger.info('Exiting read loop task')
+
+ def _write_loop(self):
+ """This background task sends packages to the server as they are
+ pushed to the send queue.
+ """
+ while self.state == 'connected':
+ # to simplify the timeout handling, use the maximum of the
+ # ping interval and ping timeout as timeout, with an extra 5
+ # seconds grace period
+ timeout = max(self.ping_interval, self.ping_timeout) + 5
+ packets = None
+ try:
+ packets = [self.queue.get(timeout=timeout)]
+ except self.queue.Empty:
+ self.logger.error('packet queue is empty, aborting')
+ break
+ if packets == [None]:
+ self.queue.task_done()
+ packets = []
+ else:
+ while True:
+ try:
+ packets.append(self.queue.get(block=False))
+ except self.queue.Empty:
+ break
+ if packets[-1] is None:
+ packets = packets[:-1]
+ self.queue.task_done()
+ break
+ if not packets:
+ # empty packet list returned -> connection closed
+ break
+ if self.current_transport == 'polling':
+ p = payload.Payload(packets=packets)
+ r = self._send_request(
+ 'POST', self.base_url, body=p.encode(),
+ headers={'Content-Type': 'application/octet-stream'},
+ timeout=self.request_timeout)
+ for pkt in packets:
+ self.queue.task_done()
+ if r is None:
+ self.logger.warning(
+ 'Connection refused by the server, aborting')
+ break
+ if r.status_code < 200 or r.status_code >= 300:
+ self.logger.warning('Unexpected status code %s in server '
+ 'response, aborting', r.status_code)
+ self._reset()
+ break
+ else:
+ # websocket
+ try:
+ for pkt in packets:
+ encoded_packet = pkt.encode(always_bytes=False)
+ if pkt.binary:
+ self.ws.send_binary(encoded_packet)
+ else:
+ self.ws.send(encoded_packet)
+ self.queue.task_done()
+ except websocket.WebSocketConnectionClosedException:
+ self.logger.warning(
+ 'WebSocket connection was closed, aborting')
+ break
+ self.logger.info('Exiting write loop task')
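+ # A minimal usage sketch for the threaded client implemented above
+ # (illustrative only; it assumes this module is importable as
+ # ``engineio`` and that a server is listening on the given URL):
+ #
+ #     import engineio
+ #
+ #     eio = engineio.Client()
+ #
+ #     @eio.on('connect')
+ #     def on_connect():
+ #         print('connected, transport:', eio.transport())
+ #
+ #     eio.connect('http://localhost:5000')
+ #     eio.send('hello')
+ #     eio.wait()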
diff --git a/libs/engineio/exceptions.py b/libs/engineio/exceptions.py
new file mode 100644
index 000000000..fb0b3e057
--- /dev/null
+++ b/libs/engineio/exceptions.py
@@ -0,0 +1,22 @@
+class EngineIOError(Exception):
+ pass
+
+
+class ContentTooLongError(EngineIOError):
+ pass
+
+
+class UnknownPacketError(EngineIOError):
+ pass
+
+
+class QueueEmpty(EngineIOError):
+ pass
+
+
+class SocketIsClosedError(EngineIOError):
+ pass
+
+
+class ConnectionError(EngineIOError):
+ pass
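+ # All of the above derive from EngineIOError, so callers can catch the
+ # whole family with a single except clause; a hedged sketch:
+ #
+ #     try:
+ #         eio.connect('http://localhost:5000')
+ #     except engineio.exceptions.ConnectionError as exc:
+ #         print('connection failed:', exc)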
diff --git a/libs/engineio/middleware.py b/libs/engineio/middleware.py
new file mode 100644
index 000000000..d0bdcc747
--- /dev/null
+++ b/libs/engineio/middleware.py
@@ -0,0 +1,87 @@
+import os
+from engineio.static_files import get_static_file
+
+
+class WSGIApp(object):
+ """WSGI application middleware for Engine.IO.
+
+ This middleware dispatches traffic to an Engine.IO application. It can
+ also serve a list of static files to the client, or forward unrelated
+ HTTP traffic to another WSGI application.
+
+ :param engineio_app: The Engine.IO server. Must be an instance of the
+ ``engineio.Server`` class.
+ :param wsgi_app: The WSGI app that receives all other traffic.
+ :param static_files: A dictionary with static file mapping rules. See the
+ documentation for details on this argument.
+ :param engineio_path: The endpoint where the Engine.IO application should
+ be installed. The default value is appropriate for
+ most cases.
+
+ Example usage::
+
+ import engineio
+ import eventlet
+
+ eio = engineio.Server()
+ app = engineio.WSGIApp(eio, static_files={
+ '/': {'content_type': 'text/html', 'filename': 'index.html'},
+ '/index.html': {'content_type': 'text/html',
+ 'filename': 'index.html'},
+ })
+ eventlet.wsgi.server(eventlet.listen(('', 8000)), app)
+ """
+ def __init__(self, engineio_app, wsgi_app=None, static_files=None,
+ engineio_path='engine.io'):
+ self.engineio_app = engineio_app
+ self.wsgi_app = wsgi_app
+ self.engineio_path = engineio_path.strip('/')
+ self.static_files = static_files or {}
+
+ def __call__(self, environ, start_response):
+ if 'gunicorn.socket' in environ:
+ # gunicorn saves the socket under environ['gunicorn.socket'], while
+ # eventlet saves it under environ['eventlet.input']. Eventlet also
+ # stores the socket inside a wrapper class, while gunicorn writes it
+ # directly into the environment. To give eventlet's WebSocket
+ # module access to this socket when running under gunicorn, here we
+ # copy the socket to the eventlet format.
+ class Input(object):
+ def __init__(self, socket):
+ self.socket = socket
+
+ def get_socket(self):
+ return self.socket
+
+ environ['eventlet.input'] = Input(environ['gunicorn.socket'])
+ path = environ['PATH_INFO']
+ if path is not None and \
+ path.startswith('/{0}/'.format(self.engineio_path)):
+ return self.engineio_app.handle_request(environ, start_response)
+ else:
+ static_file = get_static_file(path, self.static_files) \
+ if self.static_files else None
+ if static_file:
+ if os.path.exists(static_file['filename']):
+ start_response(
+ '200 OK',
+ [('Content-Type', static_file['content_type'])])
+ with open(static_file['filename'], 'rb') as f:
+ return [f.read()]
+ else:
+ return self.not_found(start_response)
+ elif self.wsgi_app is not None:
+ return self.wsgi_app(environ, start_response)
+ return self.not_found(start_response)
+
+ def not_found(self, start_response):
+ start_response("404 Not Found", [('Content-Type', 'text/plain')])
+ return [b'Not Found']
+
+
+class Middleware(WSGIApp):
+ """This class has been renamed to ``WSGIApp`` and is now deprecated."""
+ def __init__(self, engineio_app, wsgi_app=None,
+ engineio_path='engine.io'):
+ super(Middleware, self).__init__(engineio_app, wsgi_app,
+ engineio_path=engineio_path)
diff --git a/libs/engineio/packet.py b/libs/engineio/packet.py
new file mode 100644
index 000000000..a3aa6d476
--- /dev/null
+++ b/libs/engineio/packet.py
@@ -0,0 +1,92 @@
+import base64
+import json as _json
+
+import six
+
+(OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6)
+packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP']
+
+binary_types = (six.binary_type, bytearray)
+
+
+class Packet(object):
+ """Engine.IO packet."""
+
+ json = _json
+
+ def __init__(self, packet_type=NOOP, data=None, binary=None,
+ encoded_packet=None):
+ self.packet_type = packet_type
+ self.data = data
+ if binary is not None:
+ self.binary = binary
+ elif isinstance(data, six.text_type):
+ self.binary = False
+ elif isinstance(data, binary_types):
+ self.binary = True
+ else:
+ self.binary = False
+ if encoded_packet:
+ self.decode(encoded_packet)
+
+ def encode(self, b64=False, always_bytes=True):
+ """Encode the packet for transmission."""
+ if self.binary and not b64:
+ encoded_packet = six.int2byte(self.packet_type)
+ else:
+ encoded_packet = six.text_type(self.packet_type)
+ if self.binary and b64:
+ encoded_packet = 'b' + encoded_packet
+ if self.binary:
+ if b64:
+ encoded_packet += base64.b64encode(self.data).decode('utf-8')
+ else:
+ encoded_packet += self.data
+ elif isinstance(self.data, six.string_types):
+ encoded_packet += self.data
+        elif isinstance(self.data, (dict, list)):
+ encoded_packet += self.json.dumps(self.data,
+ separators=(',', ':'))
+ elif self.data is not None:
+ encoded_packet += str(self.data)
+ if always_bytes and not isinstance(encoded_packet, binary_types):
+ encoded_packet = encoded_packet.encode('utf-8')
+ return encoded_packet
+
+ def decode(self, encoded_packet):
+ """Decode a transmitted package."""
+ b64 = False
+ if not isinstance(encoded_packet, binary_types):
+ encoded_packet = encoded_packet.encode('utf-8')
+ elif not isinstance(encoded_packet, bytes):
+ encoded_packet = bytes(encoded_packet)
+ self.packet_type = six.byte2int(encoded_packet[0:1])
+ if self.packet_type == 98: # 'b' --> binary base64 encoded packet
+ self.binary = True
+ encoded_packet = encoded_packet[1:]
+ self.packet_type = six.byte2int(encoded_packet[0:1])
+ self.packet_type -= 48
+ b64 = True
+ elif self.packet_type >= 48:
+ self.packet_type -= 48
+ self.binary = False
+ else:
+ self.binary = True
+ self.data = None
+ if len(encoded_packet) > 1:
+ if self.binary:
+ if b64:
+ self.data = base64.b64decode(encoded_packet[1:])
+ else:
+ self.data = encoded_packet[1:]
+ else:
+ try:
+ self.data = self.json.loads(
+ encoded_packet[1:].decode('utf-8'))
+ if isinstance(self.data, int):
+ # do not allow integer payloads, see
+ # github.com/miguelgrinberg/python-engineio/issues/75
+ # for background on this decision
+ raise ValueError
+ except ValueError:
+ self.data = encoded_packet[1:].decode('utf-8')
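+
+
+# --- Illustrative sketch (editor's addition, not part of the upstream
+# module): a round trip through the wire format implemented above. Text
+# packets are prefixed with the packet-type digit; binary packets use a raw
+# type byte, or 'b' plus the digit when base64-encoded for polling.
+if __name__ == '__main__':  # pragma: no cover
+    pkt = Packet(MESSAGE, data={'hello': 'world'})
+    encoded = pkt.encode()                     # b'4{"hello":"world"}'
+    decoded = Packet(encoded_packet=encoded)
+    assert decoded.packet_type == MESSAGE
+    assert decoded.data == {'hello': 'world'}
+
+    bin_pkt = Packet(MESSAGE, data=b'\x00\x01\x02')
+    b64 = bin_pkt.encode(b64=True)             # b'b4AAEC'
+    assert Packet(encoded_packet=b64).data == b'\x00\x01\x02'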
diff --git a/libs/engineio/payload.py b/libs/engineio/payload.py
new file mode 100644
index 000000000..fbf9cbd27
--- /dev/null
+++ b/libs/engineio/payload.py
@@ -0,0 +1,81 @@
+import six
+
+from . import packet
+
+from six.moves import urllib
+
+
+class Payload(object):
+ """Engine.IO payload."""
+ max_decode_packets = 16
+
+ def __init__(self, packets=None, encoded_payload=None):
+ self.packets = packets or []
+ if encoded_payload is not None:
+ self.decode(encoded_payload)
+
+ def encode(self, b64=False, jsonp_index=None):
+ """Encode the payload for transmission."""
+ encoded_payload = b''
+ for pkt in self.packets:
+ encoded_packet = pkt.encode(b64=b64)
+ packet_len = len(encoded_packet)
+ if b64:
+ encoded_payload += str(packet_len).encode('utf-8') + b':' + \
+ encoded_packet
+ else:
+ binary_len = b''
+ while packet_len != 0:
+ binary_len = six.int2byte(packet_len % 10) + binary_len
+ packet_len = int(packet_len / 10)
+ if not pkt.binary:
+ encoded_payload += b'\0'
+ else:
+ encoded_payload += b'\1'
+ encoded_payload += binary_len + b'\xff' + encoded_packet
+ if jsonp_index is not None:
+ encoded_payload = b'___eio[' + \
+ str(jsonp_index).encode() + \
+ b']("' + \
+ encoded_payload.replace(b'"', b'\\"') + \
+ b'");'
+ return encoded_payload
+
+ def decode(self, encoded_payload):
+ """Decode a transmitted payload."""
+ self.packets = []
+
+ if len(encoded_payload) == 0:
+ return
+
+ # JSONP POST payload starts with 'd='
+ if encoded_payload.startswith(b'd='):
+ encoded_payload = urllib.parse.parse_qs(
+ encoded_payload)[b'd'][0]
+
+ i = 0
+ if six.byte2int(encoded_payload[0:1]) <= 1:
+ # binary encoding
+ while i < len(encoded_payload):
+ if len(self.packets) >= self.max_decode_packets:
+ raise ValueError('Too many packets in payload')
+ packet_len = 0
+ i += 1
+ while six.byte2int(encoded_payload[i:i + 1]) != 255:
+ packet_len = packet_len * 10 + six.byte2int(
+ encoded_payload[i:i + 1])
+ i += 1
+ self.packets.append(packet.Packet(
+ encoded_packet=encoded_payload[i + 1:i + 1 + packet_len]))
+ i += packet_len + 1
+ else:
+ # assume text encoding
+ encoded_payload = encoded_payload.decode('utf-8')
+ while i < len(encoded_payload):
+ if len(self.packets) >= self.max_decode_packets:
+ raise ValueError('Too many packets in payload')
+ j = encoded_payload.find(':', i)
+ packet_len = int(encoded_payload[i:j])
+ pkt = encoded_payload[j + 1:j + 1 + packet_len]
+ self.packets.append(packet.Packet(encoded_packet=pkt))
+ i = j + 1 + packet_len
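+
+
+# --- Illustrative sketch (editor's addition, not part of the upstream
+# module): the text framing produced by encode(b64=True) is a sequence of
+# '<length>:<packet>' pairs, so two text MESSAGE packets 'hi' and 'bye'
+# frame as b'3:4hi4:4bye'.
+if __name__ == '__main__':  # pragma: no cover
+    pkts = [packet.Packet(packet.MESSAGE, data=six.text_type('hi')),
+            packet.Packet(packet.MESSAGE, data=six.text_type('bye'))]
+    encoded = Payload(packets=pkts).encode(b64=True)
+    assert encoded == b'3:4hi4:4bye'
+    decoded = Payload(encoded_payload=encoded)
+    assert [pkt.data for pkt in decoded.packets] == ['hi', 'bye']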
diff --git a/libs/engineio/server.py b/libs/engineio/server.py
new file mode 100644
index 000000000..e1543c2dc
--- /dev/null
+++ b/libs/engineio/server.py
@@ -0,0 +1,675 @@
+import gzip
+import importlib
+import logging
+import uuid
+import zlib
+
+import six
+from six.moves import urllib
+
+from . import exceptions
+from . import packet
+from . import payload
+from . import socket
+
+default_logger = logging.getLogger('engineio.server')
+
+
+class Server(object):
+ """An Engine.IO server.
+
+ This class implements a fully compliant Engine.IO web server with support
+ for websocket and long-polling transports.
+
+ :param async_mode: The asynchronous model to use. See the Deployment
+ section in the documentation for a description of the
+ available options. Valid async modes are "threading",
+ "eventlet", "gevent" and "gevent_uwsgi". If this
+ argument is not given, "eventlet" is tried first, then
+ "gevent_uwsgi", then "gevent", and finally "threading".
+ The first async mode that has all its dependencies
+ installed is the one that is chosen.
+ :param ping_timeout: The time in seconds that the client waits for the
+ server to respond before disconnecting. The default
+ is 60 seconds.
+ :param ping_interval: The interval in seconds at which the client pings
+ the server. The default is 25 seconds. For advanced
+ control, a two element tuple can be given, where
+ the first number is the ping interval and the second
+ is a grace period added by the server. The default
+ grace period is 5 seconds.
+ :param max_http_buffer_size: The maximum size of a message when using the
+ polling transport. The default is 100,000,000
+ bytes.
+ :param allow_upgrades: Whether to allow transport upgrades or not. The
+ default is ``True``.
+    :param http_compression: Whether to compress packets when using the
+ polling transport. The default is ``True``.
+ :param compression_threshold: Only compress messages when their byte size
+ is greater than this value. The default is
+ 1024 bytes.
+ :param cookie: Name of the HTTP cookie that contains the client session
+ id. If set to ``None``, a cookie is not sent to the client.
+ The default is ``'io'``.
+ :param cors_allowed_origins: Origin or list of origins that are allowed to
+ connect to this server. Only the same origin
+ is allowed by default. Set this argument to
+ ``'*'`` to allow all origins, or to ``[]`` to
+ disable CORS handling.
+ :param cors_credentials: Whether credentials (cookies, authentication) are
+ allowed in requests to this server. The default
+ is ``True``.
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``. The default is
+ ``False``.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+ :param async_handlers: If set to ``True``, run message event handlers in
+ non-blocking threads. To run handlers synchronously,
+ set to ``False``. The default is ``True``.
+ :param monitor_clients: If set to ``True``, a background task will ensure
+ inactive clients are closed. Set to ``False`` to
+ disable the monitoring task (not recommended). The
+ default is ``True``.
+    :param kwargs: Reserved for future extensions; any additional parameters
+ given as keyword arguments will be silently ignored.
+ """
+ compression_methods = ['gzip', 'deflate']
+ event_names = ['connect', 'disconnect', 'message']
+ _default_monitor_clients = True
+
+ def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25,
+ max_http_buffer_size=100000000, allow_upgrades=True,
+ http_compression=True, compression_threshold=1024,
+ cookie='io', cors_allowed_origins=None,
+ cors_credentials=True, logger=False, json=None,
+ async_handlers=True, monitor_clients=None, **kwargs):
+ self.ping_timeout = ping_timeout
+ if isinstance(ping_interval, tuple):
+ self.ping_interval = ping_interval[0]
+ self.ping_interval_grace_period = ping_interval[1]
+ else:
+ self.ping_interval = ping_interval
+ self.ping_interval_grace_period = 5
+ self.max_http_buffer_size = max_http_buffer_size
+ self.allow_upgrades = allow_upgrades
+ self.http_compression = http_compression
+ self.compression_threshold = compression_threshold
+ self.cookie = cookie
+ self.cors_allowed_origins = cors_allowed_origins
+ self.cors_credentials = cors_credentials
+ self.async_handlers = async_handlers
+ self.sockets = {}
+ self.handlers = {}
+ self.start_service_task = monitor_clients \
+ if monitor_clients is not None else self._default_monitor_clients
+ if json is not None:
+ packet.Packet.json = json
+ if not isinstance(logger, bool):
+ self.logger = logger
+ else:
+ self.logger = default_logger
+ if not logging.root.handlers and \
+ self.logger.level == logging.NOTSET:
+ if logger:
+ self.logger.setLevel(logging.INFO)
+ else:
+ self.logger.setLevel(logging.ERROR)
+ self.logger.addHandler(logging.StreamHandler())
+ modes = self.async_modes()
+ if async_mode is not None:
+ modes = [async_mode] if async_mode in modes else []
+ self._async = None
+ self.async_mode = None
+ for mode in modes:
+ try:
+ self._async = importlib.import_module(
+ 'engineio.async_drivers.' + mode)._async
+ asyncio_based = self._async['asyncio'] \
+ if 'asyncio' in self._async else False
+ if asyncio_based != self.is_asyncio_based():
+ continue # pragma: no cover
+ self.async_mode = mode
+ break
+ except ImportError:
+ pass
+ if self.async_mode is None:
+ raise ValueError('Invalid async_mode specified')
+ if self.is_asyncio_based() and \
+ ('asyncio' not in self._async or not
+ self._async['asyncio']): # pragma: no cover
+ raise ValueError('The selected async_mode is not asyncio '
+ 'compatible')
+ if not self.is_asyncio_based() and 'asyncio' in self._async and \
+ self._async['asyncio']: # pragma: no cover
+ raise ValueError('The selected async_mode requires asyncio and '
+ 'must use the AsyncServer class')
+ self.logger.info('Server initialized for %s.', self.async_mode)
+
+ def is_asyncio_based(self):
+ return False
+
+ def async_modes(self):
+ return ['eventlet', 'gevent_uwsgi', 'gevent', 'threading']
+
+ def on(self, event, handler=None):
+ """Register an event handler.
+
+ :param event: The event name. Can be ``'connect'``, ``'message'`` or
+ ``'disconnect'``.
+ :param handler: The function that should be invoked to handle the
+ event. When this parameter is not given, the method
+ acts as a decorator for the handler function.
+
+ Example usage::
+
+ # as a decorator:
+ @eio.on('connect')
+ def connect_handler(sid, environ):
+ print('Connection request')
+ if environ['REMOTE_ADDR'] in blacklisted:
+ return False # reject
+
+ # as a method:
+ def message_handler(sid, msg):
+ print('Received message: ', msg)
+ eio.send(sid, 'response')
+ eio.on('message', message_handler)
+
+ The handler function receives the ``sid`` (session ID) for the
+ client as first argument. The ``'connect'`` event handler receives the
+ WSGI environment as a second argument, and can return ``False`` to
+ reject the connection. The ``'message'`` handler receives the message
+ payload as a second argument. The ``'disconnect'`` handler does not
+ take a second argument.
+ """
+ if event not in self.event_names:
+ raise ValueError('Invalid event')
+
+ def set_handler(handler):
+ self.handlers[event] = handler
+ return handler
+
+ if handler is None:
+ return set_handler
+ set_handler(handler)
+
+ def send(self, sid, data, binary=None):
+ """Send a message to a client.
+
+ :param sid: The session id of the recipient client.
+ :param data: The data to send to the client. Data can be of type
+ ``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
+ or ``dict``, the data will be serialized as JSON.
+ :param binary: ``True`` to send packet as binary, ``False`` to send
+ as text. If not given, unicode (Python 2) and str
+ (Python 3) are sent as text, and str (Python 2) and
+ bytes (Python 3) are sent as binary.
+ """
+ try:
+ socket = self._get_socket(sid)
+ except KeyError:
+ # the socket is not available
+ self.logger.warning('Cannot send to sid %s', sid)
+ return
+ socket.send(packet.Packet(packet.MESSAGE, data=data, binary=binary))
+
+ def get_session(self, sid):
+ """Return the user session for a client.
+
+ :param sid: The session id of the client.
+
+ The return value is a dictionary. Modifications made to this
+ dictionary are not guaranteed to be preserved unless
+ ``save_session()`` is called, or when the ``session`` context manager
+ is used.
+ """
+ socket = self._get_socket(sid)
+ return socket.session
+
+ def save_session(self, sid, session):
+ """Store the user session for a client.
+
+ :param sid: The session id of the client.
+ :param session: The session dictionary.
+ """
+ socket = self._get_socket(sid)
+ socket.session = session
+
+ def session(self, sid):
+ """Return the user session for a client with context manager syntax.
+
+ :param sid: The session id of the client.
+
+ This is a context manager that returns the user session dictionary for
+ the client. Any changes that are made to this dictionary inside the
+ context manager block are saved back to the session. Example usage::
+
+ @eio.on('connect')
+ def on_connect(sid, environ):
+ username = authenticate_user(environ)
+ if not username:
+ return False
+ with eio.session(sid) as session:
+ session['username'] = username
+
+ @eio.on('message')
+ def on_message(sid, msg):
+ with eio.session(sid) as session:
+ print('received message from ', session['username'])
+ """
+ class _session_context_manager(object):
+ def __init__(self, server, sid):
+ self.server = server
+ self.sid = sid
+ self.session = None
+
+ def __enter__(self):
+                self.session = self.server.get_session(self.sid)
+ return self.session
+
+ def __exit__(self, *args):
+                self.server.save_session(self.sid, self.session)
+
+ return _session_context_manager(self, sid)
+
+ def disconnect(self, sid=None):
+ """Disconnect a client.
+
+ :param sid: The session id of the client to close. If this parameter
+ is not given, then all clients are closed.
+ """
+ if sid is not None:
+ try:
+ socket = self._get_socket(sid)
+ except KeyError: # pragma: no cover
+ # the socket was already closed or gone
+ pass
+ else:
+ socket.close()
+ if sid in self.sockets: # pragma: no cover
+ del self.sockets[sid]
+ else:
+ for client in six.itervalues(self.sockets):
+ client.close()
+ self.sockets = {}
+
+ def transport(self, sid):
+ """Return the name of the transport used by the client.
+
+ The two possible values returned by this function are ``'polling'``
+ and ``'websocket'``.
+
+        :param sid: The session id of the client.
+ """
+ return 'websocket' if self._get_socket(sid).upgraded else 'polling'
+
+ def handle_request(self, environ, start_response):
+ """Handle an HTTP request from the client.
+
+ This is the entry point of the Engine.IO application, using the same
+ interface as a WSGI application. For the typical usage, this function
+        is invoked by the :class:`WSGIApp` instance, but it can be invoked
+ directly when the middleware is not used.
+
+ :param environ: The WSGI environment.
+ :param start_response: The WSGI ``start_response`` function.
+
+ This function returns the HTTP response body to deliver to the client
+ as a byte sequence.
+ """
+ if self.cors_allowed_origins != []:
+ # Validate the origin header if present
+ # This is important for WebSocket more than for HTTP, since
+ # browsers only apply CORS controls to HTTP.
+ origin = environ.get('HTTP_ORIGIN')
+ if origin:
+ allowed_origins = self._cors_allowed_origins(environ)
+ if allowed_origins is not None and origin not in \
+ allowed_origins:
+ self.logger.info(origin + ' is not an accepted origin.')
+ r = self._bad_request()
+ start_response(r['status'], r['headers'])
+ return [r['response']]
+
+ method = environ['REQUEST_METHOD']
+ query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
+
+ sid = query['sid'][0] if 'sid' in query else None
+ b64 = False
+ jsonp = False
+ jsonp_index = None
+
+ if 'b64' in query:
+ if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
+ b64 = True
+ if 'j' in query:
+ jsonp = True
+ try:
+ jsonp_index = int(query['j'][0])
+ except (ValueError, KeyError, IndexError):
+ # Invalid JSONP index number
+ pass
+
+ if jsonp and jsonp_index is None:
+ self.logger.warning('Invalid JSONP index number')
+ r = self._bad_request()
+ elif method == 'GET':
+ if sid is None:
+ transport = query.get('transport', ['polling'])[0]
+ if transport != 'polling' and transport != 'websocket':
+ self.logger.warning('Invalid transport %s', transport)
+ r = self._bad_request()
+ else:
+ r = self._handle_connect(environ, start_response,
+ transport, b64, jsonp_index)
+ else:
+ if sid not in self.sockets:
+ self.logger.warning('Invalid session %s', sid)
+ r = self._bad_request()
+ else:
+ socket = self._get_socket(sid)
+ try:
+ packets = socket.handle_get_request(
+ environ, start_response)
+ if isinstance(packets, list):
+ r = self._ok(packets, b64=b64,
+ jsonp_index=jsonp_index)
+ else:
+ r = packets
+ except exceptions.EngineIOError:
+ if sid in self.sockets: # pragma: no cover
+ self.disconnect(sid)
+ r = self._bad_request()
+ if sid in self.sockets and self.sockets[sid].closed:
+ del self.sockets[sid]
+ elif method == 'POST':
+ if sid is None or sid not in self.sockets:
+ self.logger.warning('Invalid session %s', sid)
+ r = self._bad_request()
+ else:
+ socket = self._get_socket(sid)
+ try:
+ socket.handle_post_request(environ)
+ r = self._ok(jsonp_index=jsonp_index)
+ except exceptions.EngineIOError:
+ if sid in self.sockets: # pragma: no cover
+ self.disconnect(sid)
+ r = self._bad_request()
+ except: # pragma: no cover
+ # for any other unexpected errors, we log the error
+ # and keep going
+ self.logger.exception('post request handler error')
+ r = self._ok(jsonp_index=jsonp_index)
+ elif method == 'OPTIONS':
+ r = self._ok()
+ else:
+ self.logger.warning('Method %s not supported', method)
+ r = self._method_not_found()
+
+ if not isinstance(r, dict):
+ return r or []
+ if self.http_compression and \
+ len(r['response']) >= self.compression_threshold:
+ encodings = [e.split(';')[0].strip() for e in
+ environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
+ for encoding in encodings:
+ if encoding in self.compression_methods:
+ r['response'] = \
+ getattr(self, '_' + encoding)(r['response'])
+ r['headers'] += [('Content-Encoding', encoding)]
+ break
+ cors_headers = self._cors_headers(environ)
+ start_response(r['status'], r['headers'] + cors_headers)
+ return [r['response']]
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task using the appropriate async model.
+
+ This is a utility function that applications can use to start a
+ background task using the method that is compatible with the
+ selected async mode.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ This function returns an object compatible with the `Thread` class in
+ the Python standard library. The `start()` method on this object is
+ already called by this function.
+ """
+ th = self._async['thread'](target=target, args=args, kwargs=kwargs)
+ th.start()
+ return th # pragma: no cover
+
+ def sleep(self, seconds=0):
+ """Sleep for the requested amount of time using the appropriate async
+ model.
+
+ This is a utility function that applications can use to put a task to
+ sleep without having to worry about using the correct call for the
+ selected async mode.
+ """
+ return self._async['sleep'](seconds)
+
+ def create_queue(self, *args, **kwargs):
+ """Create a queue object using the appropriate async model.
+
+ This is a utility function that applications can use to create a queue
+ without having to worry about using the correct call for the selected
+ async mode.
+ """
+ return self._async['queue'](*args, **kwargs)
+
+ def get_queue_empty_exception(self):
+ """Return the queue empty exception for the appropriate async model.
+
+ This is a utility function that applications can use to work with a
+ queue without having to worry about using the correct call for the
+ selected async mode.
+ """
+ return self._async['queue_empty']
+
+ def create_event(self, *args, **kwargs):
+ """Create an event object using the appropriate async model.
+
+ This is a utility function that applications can use to create an
+ event without having to worry about using the correct call for the
+ selected async mode.
+ """
+ return self._async['event'](*args, **kwargs)
+
+ def _generate_id(self):
+ """Generate a unique session id."""
+ return uuid.uuid4().hex
+
+ def _handle_connect(self, environ, start_response, transport, b64=False,
+ jsonp_index=None):
+ """Handle a client connection request."""
+ if self.start_service_task:
+ # start the service task to monitor connected clients
+ self.start_service_task = False
+ self.start_background_task(self._service_task)
+
+ sid = self._generate_id()
+ s = socket.Socket(self, sid)
+ self.sockets[sid] = s
+
+ pkt = packet.Packet(
+ packet.OPEN, {'sid': sid,
+ 'upgrades': self._upgrades(sid, transport),
+ 'pingTimeout': int(self.ping_timeout * 1000),
+ 'pingInterval': int(self.ping_interval * 1000)})
+ s.send(pkt)
+
+ ret = self._trigger_event('connect', sid, environ, run_async=False)
+ if ret is False:
+ del self.sockets[sid]
+ self.logger.warning('Application rejected connection')
+ return self._unauthorized()
+
+ if transport == 'websocket':
+ ret = s.handle_get_request(environ, start_response)
+ if s.closed:
+ # websocket connection ended, so we are done
+ del self.sockets[sid]
+ return ret
+ else:
+ s.connected = True
+ headers = None
+ if self.cookie:
+ headers = [('Set-Cookie', self.cookie + '=' + sid)]
+ try:
+ return self._ok(s.poll(), headers=headers, b64=b64,
+ jsonp_index=jsonp_index)
+ except exceptions.QueueEmpty:
+ return self._bad_request()
+
+ def _upgrades(self, sid, transport):
+ """Return the list of possible upgrades for a client connection."""
+ if not self.allow_upgrades or self._get_socket(sid).upgraded or \
+ self._async['websocket'] is None or transport == 'websocket':
+ return []
+ return ['websocket']
+
+ def _trigger_event(self, event, *args, **kwargs):
+ """Invoke an event handler."""
+ run_async = kwargs.pop('run_async', False)
+ if event in self.handlers:
+ if run_async:
+ return self.start_background_task(self.handlers[event], *args)
+ else:
+ try:
+ return self.handlers[event](*args)
+ except:
+ self.logger.exception(event + ' handler error')
+ if event == 'connect':
+ # if connect handler raised error we reject the
+ # connection
+ return False
+
+ def _get_socket(self, sid):
+ """Return the socket object for a given session."""
+ try:
+ s = self.sockets[sid]
+ except KeyError:
+ raise KeyError('Session not found')
+ if s.closed:
+ del self.sockets[sid]
+ raise KeyError('Session is disconnected')
+ return s
+
+ def _ok(self, packets=None, headers=None, b64=False, jsonp_index=None):
+ """Generate a successful HTTP response."""
+ if packets is not None:
+ if headers is None:
+ headers = []
+ if b64:
+ headers += [('Content-Type', 'text/plain; charset=UTF-8')]
+ else:
+ headers += [('Content-Type', 'application/octet-stream')]
+ return {'status': '200 OK',
+ 'headers': headers,
+ 'response': payload.Payload(packets=packets).encode(
+ b64=b64, jsonp_index=jsonp_index)}
+ else:
+ return {'status': '200 OK',
+ 'headers': [('Content-Type', 'text/plain')],
+ 'response': b'OK'}
+
+ def _bad_request(self):
+ """Generate a bad request HTTP error response."""
+ return {'status': '400 BAD REQUEST',
+ 'headers': [('Content-Type', 'text/plain')],
+ 'response': b'Bad Request'}
+
+ def _method_not_found(self):
+ """Generate a method not found HTTP error response."""
+ return {'status': '405 METHOD NOT FOUND',
+ 'headers': [('Content-Type', 'text/plain')],
+ 'response': b'Method Not Found'}
+
+ def _unauthorized(self):
+ """Generate a unauthorized HTTP error response."""
+ return {'status': '401 UNAUTHORIZED',
+ 'headers': [('Content-Type', 'text/plain')],
+ 'response': b'Unauthorized'}
+
+ def _cors_allowed_origins(self, environ):
+ default_origins = []
+ if 'wsgi.url_scheme' in environ and 'HTTP_HOST' in environ:
+ default_origins.append('{scheme}://{host}'.format(
+ scheme=environ['wsgi.url_scheme'], host=environ['HTTP_HOST']))
+ if 'HTTP_X_FORWARDED_HOST' in environ:
+ scheme = environ.get(
+ 'HTTP_X_FORWARDED_PROTO',
+ environ['wsgi.url_scheme']).split(',')[0].strip()
+ default_origins.append('{scheme}://{host}'.format(
+ scheme=scheme, host=environ['HTTP_X_FORWARDED_HOST'].split(
+ ',')[0].strip()))
+ if self.cors_allowed_origins is None:
+ allowed_origins = default_origins
+ elif self.cors_allowed_origins == '*':
+ allowed_origins = None
+ elif isinstance(self.cors_allowed_origins, six.string_types):
+ allowed_origins = [self.cors_allowed_origins]
+ else:
+ allowed_origins = self.cors_allowed_origins
+ return allowed_origins
+
+ def _cors_headers(self, environ):
+ """Return the cross-origin-resource-sharing headers."""
+ if self.cors_allowed_origins == []:
+ # special case, CORS handling is completely disabled
+ return []
+ headers = []
+ allowed_origins = self._cors_allowed_origins(environ)
+ if 'HTTP_ORIGIN' in environ and \
+ (allowed_origins is None or environ['HTTP_ORIGIN'] in
+ allowed_origins):
+ headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])]
+ if environ['REQUEST_METHOD'] == 'OPTIONS':
+ headers += [('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')]
+ if 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS' in environ:
+ headers += [('Access-Control-Allow-Headers',
+ environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
+ if self.cors_credentials:
+ headers += [('Access-Control-Allow-Credentials', 'true')]
+ return headers
+
+ def _gzip(self, response):
+ """Apply gzip compression to a response."""
+ bytesio = six.BytesIO()
+ with gzip.GzipFile(fileobj=bytesio, mode='w') as gz:
+ gz.write(response)
+ return bytesio.getvalue()
+
+ def _deflate(self, response):
+ """Apply deflate compression to a response."""
+ return zlib.compress(response)
+
+ def _service_task(self): # pragma: no cover
+ """Monitor connected clients and clean up those that time out."""
+ while True:
+ if len(self.sockets) == 0:
+ # nothing to do
+ self.sleep(self.ping_timeout)
+ continue
+
+ # go through the entire client list in a ping interval cycle
+ sleep_interval = float(self.ping_timeout) / len(self.sockets)
+
+ try:
+ # iterate over the current clients
+ for s in self.sockets.copy().values():
+ if not s.closing and not s.closed:
+ s.check_ping_timeout()
+ self.sleep(sleep_interval)
+ except (SystemExit, KeyboardInterrupt):
+ self.logger.info('service task canceled')
+ break
+ except:
+ # an unexpected exception has occurred, log it and continue
+ self.logger.exception('service task exception')
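+
+
+# --- Illustrative sketch (editor's addition, not part of the upstream
+# module): minimal wiring of the Server above with the 'threading' async
+# mode, served through wsgiref. The echo handler and port are hypothetical.
+if __name__ == '__main__':  # pragma: no cover
+    from wsgiref.simple_server import make_server
+
+    from .middleware import WSGIApp
+
+    eio = Server(async_mode='threading')
+
+    @eio.on('message')
+    def message(sid, data):
+        # echo the payload back to the sending client
+        eio.send(sid, data)
+
+    make_server('', 8000, WSGIApp(eio)).serve_forever()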
diff --git a/libs/engineio/socket.py b/libs/engineio/socket.py
new file mode 100644
index 000000000..38593e7c7
--- /dev/null
+++ b/libs/engineio/socket.py
@@ -0,0 +1,248 @@
+import six
+import sys
+import time
+
+from . import exceptions
+from . import packet
+from . import payload
+
+
+class Socket(object):
+ """An Engine.IO socket."""
+ upgrade_protocols = ['websocket']
+
+ def __init__(self, server, sid):
+ self.server = server
+ self.sid = sid
+ self.queue = self.server.create_queue()
+ self.last_ping = time.time()
+ self.connected = False
+ self.upgrading = False
+ self.upgraded = False
+ self.packet_backlog = []
+ self.closing = False
+ self.closed = False
+ self.session = {}
+
+ def poll(self):
+ """Wait for packets to send to the client."""
+ queue_empty = self.server.get_queue_empty_exception()
+ try:
+ packets = [self.queue.get(timeout=self.server.ping_timeout)]
+ self.queue.task_done()
+ except queue_empty:
+ raise exceptions.QueueEmpty()
+ if packets == [None]:
+ return []
+ while True:
+ try:
+ packets.append(self.queue.get(block=False))
+ self.queue.task_done()
+ except queue_empty:
+ break
+ return packets
+
+ def receive(self, pkt):
+ """Receive packet from the client."""
+ packet_name = packet.packet_names[pkt.packet_type] \
+ if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
+ self.server.logger.info('%s: Received packet %s data %s',
+ self.sid, packet_name,
+ pkt.data if not isinstance(pkt.data, bytes)
+ else '<binary>')
+ if pkt.packet_type == packet.PING:
+ self.last_ping = time.time()
+ self.send(packet.Packet(packet.PONG, pkt.data))
+ elif pkt.packet_type == packet.MESSAGE:
+ self.server._trigger_event('message', self.sid, pkt.data,
+ run_async=self.server.async_handlers)
+ elif pkt.packet_type == packet.UPGRADE:
+ self.send(packet.Packet(packet.NOOP))
+ elif pkt.packet_type == packet.CLOSE:
+ self.close(wait=False, abort=True)
+ else:
+ raise exceptions.UnknownPacketError()
+
+ def check_ping_timeout(self):
+ """Make sure the client is still sending pings.
+
+ This helps detect disconnections for long-polling clients.
+ """
+ if self.closed:
+ raise exceptions.SocketIsClosedError()
+ if time.time() - self.last_ping > self.server.ping_interval + \
+ self.server.ping_interval_grace_period:
+ self.server.logger.info('%s: Client is gone, closing socket',
+ self.sid)
+ # Passing abort=False here will cause close() to write a
+ # CLOSE packet. This has the effect of updating half-open sockets
+ # to their correct state of disconnected
+ self.close(wait=False, abort=False)
+ return False
+ return True
+
+ def send(self, pkt):
+ """Send a packet to the client."""
+ if not self.check_ping_timeout():
+ return
+ if self.upgrading:
+ self.packet_backlog.append(pkt)
+ else:
+ self.queue.put(pkt)
+ self.server.logger.info('%s: Sending packet %s data %s',
+ self.sid, packet.packet_names[pkt.packet_type],
+ pkt.data if not isinstance(pkt.data, bytes)
+ else '<binary>')
+
+ def handle_get_request(self, environ, start_response):
+ """Handle a long-polling GET request from the client."""
+ connections = [
+ s.strip()
+ for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
+ transport = environ.get('HTTP_UPGRADE', '').lower()
+ if 'upgrade' in connections and transport in self.upgrade_protocols:
+ self.server.logger.info('%s: Received request to upgrade to %s',
+ self.sid, transport)
+ return getattr(self, '_upgrade_' + transport)(environ,
+ start_response)
+ try:
+ packets = self.poll()
+ except exceptions.QueueEmpty:
+ exc = sys.exc_info()
+ self.close(wait=False)
+ six.reraise(*exc)
+ return packets
+
+ def handle_post_request(self, environ):
+ """Handle a long-polling POST request from the client."""
+ length = int(environ.get('CONTENT_LENGTH', '0'))
+ if length > self.server.max_http_buffer_size:
+ raise exceptions.ContentTooLongError()
+ else:
+ body = environ['wsgi.input'].read(length)
+ p = payload.Payload(encoded_payload=body)
+ for pkt in p.packets:
+ self.receive(pkt)
+
+ def close(self, wait=True, abort=False):
+ """Close the socket connection."""
+ if not self.closed and not self.closing:
+ self.closing = True
+ self.server._trigger_event('disconnect', self.sid, run_async=False)
+ if not abort:
+ self.send(packet.Packet(packet.CLOSE))
+ self.closed = True
+ self.queue.put(None)
+ if wait:
+ self.queue.join()
+
+ def _upgrade_websocket(self, environ, start_response):
+ """Upgrade the connection from polling to websocket."""
+ if self.upgraded:
+ raise IOError('Socket has been upgraded already')
+ if self.server._async['websocket'] is None:
+ # the selected async mode does not support websocket
+ return self.server._bad_request()
+ ws = self.server._async['websocket'](self._websocket_handler)
+ return ws(environ, start_response)
+
+ def _websocket_handler(self, ws):
+ """Engine.IO handler for websocket transport."""
+ # try to set a socket timeout matching the configured ping interval
+ for attr in ['_sock', 'socket']: # pragma: no cover
+ if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'):
+ getattr(ws, attr).settimeout(self.server.ping_timeout)
+
+ if self.connected:
+ # the socket was already connected, so this is an upgrade
+ self.upgrading = True # hold packet sends during the upgrade
+
+ pkt = ws.wait()
+ decoded_pkt = packet.Packet(encoded_packet=pkt)
+ if decoded_pkt.packet_type != packet.PING or \
+ decoded_pkt.data != 'probe':
+ self.server.logger.info(
+ '%s: Failed websocket upgrade, no PING packet', self.sid)
+ return []
+ ws.send(packet.Packet(
+ packet.PONG,
+ data=six.text_type('probe')).encode(always_bytes=False))
+ self.queue.put(packet.Packet(packet.NOOP)) # end poll
+
+ pkt = ws.wait()
+ decoded_pkt = packet.Packet(encoded_packet=pkt)
+ if decoded_pkt.packet_type != packet.UPGRADE:
+ self.upgraded = False
+ self.server.logger.info(
+ ('%s: Failed websocket upgrade, expected UPGRADE packet, '
+ 'received %s instead.'),
+ self.sid, pkt)
+ return []
+ self.upgraded = True
+
+ # flush any packets that were sent during the upgrade
+ for pkt in self.packet_backlog:
+ self.queue.put(pkt)
+ self.packet_backlog = []
+ self.upgrading = False
+ else:
+ self.connected = True
+ self.upgraded = True
+
+ # start separate writer thread
+ def writer():
+ while True:
+ packets = None
+ try:
+ packets = self.poll()
+ except exceptions.QueueEmpty:
+ break
+ if not packets:
+ # empty packet list returned -> connection closed
+ break
+ try:
+ for pkt in packets:
+ ws.send(pkt.encode(always_bytes=False))
+ except:
+ break
+ writer_task = self.server.start_background_task(writer)
+
+ self.server.logger.info(
+ '%s: Upgrade to websocket successful', self.sid)
+
+ while True:
+ p = None
+ try:
+ p = ws.wait()
+ except Exception as e:
+ # if the socket is already closed, we can assume this is a
+ # downstream error of that
+ if not self.closed: # pragma: no cover
+ self.server.logger.info(
+ '%s: Unexpected error "%s", closing connection',
+ self.sid, str(e))
+ break
+ if p is None:
+ # connection closed by client
+ break
+ if isinstance(p, six.text_type): # pragma: no cover
+ p = p.encode('utf-8')
+ pkt = packet.Packet(encoded_packet=p)
+ try:
+ self.receive(pkt)
+ except exceptions.UnknownPacketError: # pragma: no cover
+ pass
+ except exceptions.SocketIsClosedError: # pragma: no cover
+ self.server.logger.info('Receive error -- socket is closed')
+ break
+ except: # pragma: no cover
+ # if we get an unexpected exception we log the error and exit
+ # the connection properly
+ self.server.logger.exception('Unknown receive error')
+ break
+
+ self.queue.put(None) # unlock the writer task so that it can exit
+ writer_task.join()
+ self.close(wait=False, abort=True)
+
+ return []
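+
+
+# --- Illustrative sketch (editor's addition, not part of the upstream
+# module): the polling -> websocket upgrade handled by _websocket_handler()
+# starts with a probe exchange; these are the encoded packets involved.
+if __name__ == '__main__':  # pragma: no cover
+    # the client sends PING 'probe' over the new websocket connection...
+    probe = packet.Packet(packet.PING, data=six.text_type('probe'))
+    assert probe.encode(always_bytes=False) == '2probe'
+    # ...the server answers with PONG 'probe'...
+    pong = packet.Packet(packet.PONG, data=six.text_type('probe'))
+    assert pong.encode(always_bytes=False) == '3probe'
+    # ...and the client confirms with UPGRADE before switching transports.
+    assert packet.Packet(packet.UPGRADE).encode(always_bytes=False) == '5'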
diff --git a/libs/engineio/static_files.py b/libs/engineio/static_files.py
new file mode 100644
index 000000000..3058f6ea4
--- /dev/null
+++ b/libs/engineio/static_files.py
@@ -0,0 +1,55 @@
+content_types = {
+ 'css': 'text/css',
+ 'gif': 'image/gif',
+ 'html': 'text/html',
+ 'jpg': 'image/jpeg',
+ 'js': 'application/javascript',
+ 'json': 'application/json',
+ 'png': 'image/png',
+ 'txt': 'text/plain',
+}
+
+
+def get_static_file(path, static_files):
+ """Return the local filename and content type for the requested static
+ file URL.
+
+ :param path: the path portion of the requested URL.
+ :param static_files: a static file configuration dictionary.
+
+ This function returns a dictionary with two keys, "filename" and
+ "content_type". If the requested URL does not match any static file, the
+ return value is None.
+ """
+ if path in static_files:
+ f = static_files[path]
+ else:
+ f = None
+ rest = ''
+ while path != '':
+ path, last = path.rsplit('/', 1)
+ rest = '/' + last + rest
+ if path in static_files:
+ f = static_files[path] + rest
+ break
+ elif path + '/' in static_files:
+ f = static_files[path + '/'] + rest[1:]
+ break
+ if f:
+ if isinstance(f, str):
+ f = {'filename': f}
+ if f['filename'].endswith('/'):
+ if '' in static_files:
+ if isinstance(static_files[''], str):
+ f['filename'] += static_files['']
+ else:
+ f['filename'] += static_files['']['filename']
+ if 'content_type' in static_files['']:
+ f['content_type'] = static_files['']['content_type']
+ else:
+ f['filename'] += 'index.html'
+ if 'content_type' not in f:
+ ext = f['filename'].rsplit('.')[-1]
+ f['content_type'] = content_types.get(
+ ext, 'application/octet-stream')
+ return f
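+
+
+# --- Illustrative sketch (editor's addition, not part of the upstream
+# module): resolution against a hypothetical configuration dictionary,
+# showing the exact-match shorthand and the directory-prefix fallback.
+if __name__ == '__main__':  # pragma: no cover
+    static_files = {
+        '/': 'latency.html',    # exact match, plain string shorthand
+        '/static': './public',  # prefix match, the rest of the URL appended
+    }
+    assert get_static_file('/', static_files) == {
+        'filename': 'latency.html', 'content_type': 'text/html'}
+    assert get_static_file('/static/app.js', static_files) == {
+        'filename': './public/app.js',
+        'content_type': 'application/javascript'}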
diff --git a/libs/flask_socketio/__init__.py b/libs/flask_socketio/__init__.py
new file mode 100644
index 000000000..e4209f1e9
--- /dev/null
+++ b/libs/flask_socketio/__init__.py
@@ -0,0 +1,922 @@
+from functools import wraps
+import os
+import sys
+
+# make sure gevent-socketio is not installed, as it conflicts with
+# python-socketio
+gevent_socketio_found = True
+try:
+ from socketio import socketio_manage
+except ImportError:
+ gevent_socketio_found = False
+if gevent_socketio_found:
+ print('The gevent-socketio package is incompatible with this version of '
+ 'the Flask-SocketIO extension. Please uninstall it, and then '
+ 'install the latest version of python-socketio in its place.')
+ sys.exit(1)
+
+import flask
+from flask import _request_ctx_stack, json as flask_json
+from flask.sessions import SessionMixin
+import socketio
+from socketio.exceptions import ConnectionRefusedError
+from werkzeug.debug import DebuggedApplication
+from werkzeug.serving import run_with_reloader
+
+from .namespace import Namespace
+from .test_client import SocketIOTestClient
+
+__version__ = '4.2.1'
+
+
+class _SocketIOMiddleware(socketio.WSGIApp):
+ """This WSGI middleware simply exposes the Flask application in the WSGI
+ environment before executing the request.
+ """
+ def __init__(self, socketio_app, flask_app, socketio_path='socket.io'):
+ self.flask_app = flask_app
+ super(_SocketIOMiddleware, self).__init__(socketio_app,
+ flask_app.wsgi_app,
+ socketio_path=socketio_path)
+
+ def __call__(self, environ, start_response):
+ environ = environ.copy()
+ environ['flask.app'] = self.flask_app
+ return super(_SocketIOMiddleware, self).__call__(environ,
+ start_response)
+
+
+class _ManagedSession(dict, SessionMixin):
+ """This class is used for user sessions that are managed by
+    Flask-SocketIO. It is a simple dict, expanded with the Flask session
+ attributes."""
+ pass
+
+
+class SocketIO(object):
+ """Create a Flask-SocketIO server.
+
+ :param app: The flask application instance. If the application instance
+ isn't known at the time this class is instantiated, then call
+ ``socketio.init_app(app)`` once the application instance is
+ available.
+ :param manage_session: If set to ``True``, this extension manages the user
+ session for Socket.IO events. If set to ``False``,
+ Flask's own session management is used. When using
+ Flask's cookie based sessions it is recommended that
+ you leave this set to the default of ``True``. When
+ using server-side sessions, a ``False`` setting
+ enables sharing the user session between HTTP routes
+ and Socket.IO events.
+ :param message_queue: A connection URL for a message queue service the
+ server can use for multi-process communication. A
+ message queue is not required when using a single
+ server process.
+ :param channel: The channel name, when using a message queue. If a channel
+ isn't specified, a default channel will be used. If
+ multiple clusters of SocketIO processes need to use the
+ same message queue without interfering with each other, then
+ each cluster should use a different channel.
+ :param path: The path where the Socket.IO server is exposed. Defaults to
+ ``'socket.io'``. Leave this as is unless you know what you are
+ doing.
+ :param resource: Alias to ``path``.
+ :param kwargs: Socket.IO and Engine.IO server options.
+
+ The Socket.IO server options are detailed below:
+
+ :param client_manager: The client manager instance that will manage the
+ client list. When this is omitted, the client list
+ is stored in an in-memory structure, so the use of
+ multiple connected servers is not possible. In most
+ cases, this argument does not need to be set
+ explicitly.
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``. The default is
+ ``False``.
+ :param binary: ``True`` to support binary payloads, ``False`` to treat all
+ payloads as text. On Python 2, if this is set to ``True``,
+ ``unicode`` values are treated as text, and ``str`` and
+ ``bytes`` values are treated as binary. This option has no
+ effect on Python 3, where text and binary payloads are
+ always automatically discovered.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions. To use the same json encoder and decoder as a Flask
+ application, use ``flask.json``.
+ :param async_handlers: If set to ``True``, event handlers for a client are
+ executed in separate threads. To run handlers for a
+ client synchronously, set to ``False``. The default
+ is ``True``.
+    :param always_connect: When set to ``False``, new connections are
+                           provisional until the connect handler returns
+ something other than ``False``, at which point they
+ are accepted. When set to ``True``, connections are
+ immediately accepted, and then if the connect
+ handler returns ``False`` a disconnect is issued.
+ Set to ``True`` if you need to emit events from the
+ connect handler and your client is confused when it
+ receives events before the connection acceptance.
+ In any other case use the default of ``False``.
+
+ The Engine.IO server configuration supports the following settings:
+
+    :param async_mode: The asynchronous model to use. See the Deployment
+                       section in the documentation for a description of the
+                       available options. Valid async modes are
+                       ``threading``, ``eventlet``, ``gevent`` and
+                       ``gevent_uwsgi``. If this argument is not given,
+                       ``eventlet`` is tried first, then ``gevent_uwsgi``,
+                       then ``gevent``, and finally ``threading``. The
+                       first async mode that has all its dependencies installed
+                       is the one that is chosen.
+ :param ping_timeout: The time in seconds that the client waits for the
+ server to respond before disconnecting. The default is
+ 60 seconds.
+ :param ping_interval: The interval in seconds at which the client pings
+ the server. The default is 25 seconds.
+ :param max_http_buffer_size: The maximum size of a message when using the
+ polling transport. The default is 100,000,000
+ bytes.
+ :param allow_upgrades: Whether to allow transport upgrades or not. The
+ default is ``True``.
+    :param http_compression: Whether to compress packets when using the
+ polling transport. The default is ``True``.
+ :param compression_threshold: Only compress messages when their byte size
+ is greater than this value. The default is
+ 1024 bytes.
+ :param cookie: Name of the HTTP cookie that contains the client session
+ id. If set to ``None``, a cookie is not sent to the client.
+ The default is ``'io'``.
+ :param cors_allowed_origins: Origin or list of origins that are allowed to
+ connect to this server. Only the same origin
+ is allowed by default. Set this argument to
+ ``'*'`` to allow all origins, or to ``[]`` to
+ disable CORS handling.
+ :param cors_credentials: Whether credentials (cookies, authentication) are
+ allowed in requests to this server. The default is
+ ``True``.
+ :param monitor_clients: If set to ``True``, a background task will ensure
+ inactive clients are closed. Set to ``False`` to
+ disable the monitoring task (not recommended). The
+ default is ``True``.
+ :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
+ a logger object to use. To disable logging set to
+ ``False``. The default is ``False``.
+ """
+
+ def __init__(self, app=None, **kwargs):
+ self.server = None
+ self.server_options = {}
+ self.wsgi_server = None
+ self.handlers = []
+ self.namespace_handlers = []
+ self.exception_handlers = {}
+ self.default_exception_handler = None
+ self.manage_session = True
+ # We can call init_app when:
+ # - we were given the Flask app instance (standard initialization)
+ # - we were not given the app, but we were given a message_queue
+ # (standard initialization for auxiliary process)
+ # In all other cases we collect the arguments and assume the client
+ # will call init_app from an app factory function.
+ if app is not None or 'message_queue' in kwargs:
+ self.init_app(app, **kwargs)
+ else:
+ self.server_options.update(kwargs)
+
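+    # Illustrative sketch (editor's addition, not upstream code): the app
+    # factory pattern referenced in the comment above, with hypothetical
+    # names:
+    #
+    #     socketio = SocketIO()
+    #
+    #     def create_app():
+    #         app = Flask(__name__)
+    #         socketio.init_app(app, async_mode='threading')
+    #         return app
+    #
+    # Constructing SocketIO() with no arguments defers server creation
+    # until init_app() is called with the application instance.
+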
+ def init_app(self, app, **kwargs):
+ if app is not None:
+ if not hasattr(app, 'extensions'):
+ app.extensions = {} # pragma: no cover
+ app.extensions['socketio'] = self
+ self.server_options.update(kwargs)
+ self.manage_session = self.server_options.pop('manage_session',
+ self.manage_session)
+
+ if 'client_manager' not in self.server_options:
+ url = self.server_options.pop('message_queue', None)
+ channel = self.server_options.pop('channel', 'flask-socketio')
+ write_only = app is None
+ if url:
+                if url.startswith(('redis://', 'rediss://')):
+                    queue_class = socketio.RedisManager
+                elif url.startswith('kafka://'):
+ queue_class = socketio.KafkaManager
+ elif url.startswith('zmq'):
+ queue_class = socketio.ZmqManager
+ else:
+ queue_class = socketio.KombuManager
+ queue = queue_class(url, channel=channel,
+ write_only=write_only)
+ self.server_options['client_manager'] = queue
+
+ if 'json' in self.server_options and \
+ self.server_options['json'] == flask_json:
+ # flask's json module is tricky to use because its output
+ # changes when it is invoked inside or outside the app context
+ # so here to prevent any ambiguities we replace it with wrappers
+ # that ensure that the app context is always present
+ class FlaskSafeJSON(object):
+ @staticmethod
+ def dumps(*args, **kwargs):
+ with app.app_context():
+ return flask_json.dumps(*args, **kwargs)
+
+ @staticmethod
+ def loads(*args, **kwargs):
+ with app.app_context():
+ return flask_json.loads(*args, **kwargs)
+
+ self.server_options['json'] = FlaskSafeJSON
+
+ resource = self.server_options.pop('path', None) or \
+ self.server_options.pop('resource', None) or 'socket.io'
+ if resource.startswith('/'):
+ resource = resource[1:]
+ if os.environ.get('FLASK_RUN_FROM_CLI'):
+ if self.server_options.get('async_mode') is None:
+ if app is not None:
+                    app.logger.warning(
+                        'Flask-SocketIO is running under Werkzeug, WebSocket '
+ 'is not available.')
+ self.server_options['async_mode'] = 'threading'
+ self.server = socketio.Server(**self.server_options)
+ self.async_mode = self.server.async_mode
+ for handler in self.handlers:
+ self.server.on(handler[0], handler[1], namespace=handler[2])
+ for namespace_handler in self.namespace_handlers:
+ self.server.register_namespace(namespace_handler)
+
+ if app is not None:
+            # here we attach the SocketIO middleware to the SocketIO object so it
+ # can be referenced later if debug middleware needs to be inserted
+ self.sockio_mw = _SocketIOMiddleware(self.server, app,
+ socketio_path=resource)
+ app.wsgi_app = self.sockio_mw
+
+ def on(self, message, namespace=None):
+ """Decorator to register a SocketIO event handler.
+
+ This decorator must be applied to SocketIO event handlers. Example::
+
+ @socketio.on('my event', namespace='/chat')
+ def handle_my_custom_event(json):
+ print('received json: ' + str(json))
+
+ :param message: The name of the event. This is normally a user defined
+ string, but a few event names are already defined. Use
+ ``'message'`` to define a handler that takes a string
+ payload, ``'json'`` to define a handler that takes a
+ JSON blob payload, ``'connect'`` or ``'disconnect'``
+ to create handlers for connection and disconnection
+ events.
+ :param namespace: The namespace on which the handler is to be
+ registered. Defaults to the global namespace.
+ """
+ namespace = namespace or '/'
+
+ def decorator(handler):
+ @wraps(handler)
+ def _handler(sid, *args):
+ return self._handle_event(handler, message, namespace, sid,
+ *args)
+
+ if self.server:
+ self.server.on(message, _handler, namespace=namespace)
+ else:
+ self.handlers.append((message, _handler, namespace))
+ return handler
+ return decorator
+
+ def on_error(self, namespace=None):
+ """Decorator to define a custom error handler for SocketIO events.
+
+ This decorator can be applied to a function that acts as an error
+ handler for a namespace. This handler will be invoked when a SocketIO
+ event handler raises an exception. The handler function must accept one
+ argument, which is the exception raised. Example::
+
+ @socketio.on_error(namespace='/chat')
+ def chat_error_handler(e):
+ print('An error has occurred: ' + str(e))
+
+ :param namespace: The namespace for which to register the error
+ handler. Defaults to the global namespace.
+ """
+ namespace = namespace or '/'
+
+ def decorator(exception_handler):
+ if not callable(exception_handler):
+ raise ValueError('exception_handler must be callable')
+ self.exception_handlers[namespace] = exception_handler
+ return exception_handler
+ return decorator
+
+ def on_error_default(self, exception_handler):
+ """Decorator to define a default error handler for SocketIO events.
+
+ This decorator can be applied to a function that acts as a default
+ error handler for any namespaces that do not have a specific handler.
+ Example::
+
+ @socketio.on_error_default
+ def error_handler(e):
+ print('An error has occurred: ' + str(e))
+ """
+ if not callable(exception_handler):
+ raise ValueError('exception_handler must be callable')
+ self.default_exception_handler = exception_handler
+ return exception_handler
+
+ def on_event(self, message, handler, namespace=None):
+ """Register a SocketIO event handler.
+
+        ``on_event`` is the non-decorator version of ``on``.
+
+ Example::
+
+ def on_foo_event(json):
+ print('received json: ' + str(json))
+
+ socketio.on_event('my event', on_foo_event, namespace='/chat')
+
+ :param message: The name of the event. This is normally a user defined
+ string, but a few event names are already defined. Use
+ ``'message'`` to define a handler that takes a string
+ payload, ``'json'`` to define a handler that takes a
+ JSON blob payload, ``'connect'`` or ``'disconnect'``
+ to create handlers for connection and disconnection
+ events.
+ :param handler: The function that handles the event.
+ :param namespace: The namespace on which the handler is to be
+ registered. Defaults to the global namespace.
+ """
+ self.on(message, namespace=namespace)(handler)
+
+ def on_namespace(self, namespace_handler):
+ if not isinstance(namespace_handler, Namespace):
+ raise ValueError('Not a namespace instance.')
+ namespace_handler._set_socketio(self)
+ if self.server:
+ self.server.register_namespace(namespace_handler)
+ else:
+ self.namespace_handlers.append(namespace_handler)
+
+ def emit(self, event, *args, **kwargs):
+ """Emit a server generated SocketIO event.
+
+ This function emits a SocketIO event to one or more connected clients.
+ A JSON blob can be attached to the event as payload. This function can
+ be used outside of a SocketIO event context, so it is appropriate to
+ use when the server is the originator of an event, outside of any
+ client context, such as in a regular HTTP request handler or a
+ background task. Example::
+
+ @app.route('/ping')
+ def ping():
+ socketio.emit('ping event', {'data': 42}, namespace='/chat')
+
+ :param event: The name of the user event to emit.
+ :param args: A dictionary with the JSON data to send as payload.
+ :param namespace: The namespace under which the message is to be sent.
+ Defaults to the global namespace.
+ :param room: Send the message to all the users in the given room. If
+ this parameter is not included, the event is sent to
+ all connected users.
+        :param skip_sid: The session id of a client to ignore when broadcasting
+                         or addressing a room. This is typically set to the
+                         originator of the message, so that everyone except
+                         that client receives the message. To skip multiple sids
+                         pass a list.
+ :param callback: If given, this function will be called to acknowledge
+ that the client has received the message. The
+ arguments that will be passed to the function are
+ those provided by the client. Callback functions can
+ only be used when addressing an individual client.
+ """
+ namespace = kwargs.pop('namespace', '/')
+ room = kwargs.pop('room', None)
+ include_self = kwargs.pop('include_self', True)
+ skip_sid = kwargs.pop('skip_sid', None)
+ if not include_self and not skip_sid:
+ skip_sid = flask.request.sid
+ callback = kwargs.pop('callback', None)
+ if callback:
+            # wrap the callback so that it sets the app and request contexts
+ sid = flask.request.sid
+ original_callback = callback
+
+ def _callback_wrapper(*args):
+ return self._handle_event(original_callback, None, namespace,
+ sid, *args)
+
+ callback = _callback_wrapper
+ self.server.emit(event, *args, namespace=namespace, room=room,
+ skip_sid=skip_sid, callback=callback, **kwargs)
+
+ def send(self, data, json=False, namespace=None, room=None,
+ callback=None, include_self=True, skip_sid=None, **kwargs):
+ """Send a server-generated SocketIO message.
+
+ This function sends a simple SocketIO message to one or more connected
+ clients. The message can be a string or a JSON blob. This is a simpler
+ version of ``emit()``, which should be preferred. This function can be
+ used outside of a SocketIO event context, so it is appropriate to use
+ when the server is the originator of an event.
+
+ :param data: The message to send, either a string or a JSON blob.
+ :param json: ``True`` if ``message`` is a JSON blob, ``False``
+ otherwise.
+ :param namespace: The namespace under which the message is to be sent.
+ Defaults to the global namespace.
+ :param room: Send the message only to the users in the given room. If
+ this parameter is not included, the message is sent to
+ all connected users.
+        :param skip_sid: The session id of a client to ignore when broadcasting
+                         or addressing a room. This is typically set to the
+                         originator of the message, so that everyone except
+                         that client receives the message. To skip multiple sids
+                         pass a list.
+ :param callback: If given, this function will be called to acknowledge
+ that the client has received the message. The
+ arguments that will be passed to the function are
+ those provided by the client. Callback functions can
+ only be used when addressing an individual client.
+ """
+ skip_sid = flask.request.sid if not include_self else skip_sid
+ if json:
+ self.emit('json', data, namespace=namespace, room=room,
+ skip_sid=skip_sid, callback=callback, **kwargs)
+ else:
+ self.emit('message', data, namespace=namespace, room=room,
+ skip_sid=skip_sid, callback=callback, **kwargs)
+
+ def close_room(self, room, namespace=None):
+ """Close a room.
+
+ This function removes any users that are in the given room and then
+ deletes the room from the server. This function can be used outside
+ of a SocketIO event context.
+
+ :param room: The name of the room to close.
+ :param namespace: The namespace under which the room exists. Defaults
+ to the global namespace.
+ """
+ self.server.close_room(room, namespace)
+
+ def run(self, app, host=None, port=None, **kwargs):
+ """Run the SocketIO web server.
+
+ :param app: The Flask application instance.
+ :param host: The hostname or IP address for the server to listen on.
+ Defaults to 127.0.0.1.
+ :param port: The port number for the server to listen on. Defaults to
+ 5000.
+ :param debug: ``True`` to start the server in debug mode, ``False`` to
+ start in normal mode.
+ :param use_reloader: ``True`` to enable the Flask reloader, ``False``
+ to disable it.
+ :param extra_files: A list of additional files that the Flask
+ reloader should watch. Defaults to ``None``
+    :param log_output: If ``True``, the server logs all incoming
+ connections. If ``False`` logging is disabled.
+ Defaults to ``True`` in debug mode, ``False``
+ in normal mode. Unused when the threading async
+ mode is used.
+ :param kwargs: Additional web server options. The web server options
+ are specific to the server used in each of the supported
+ async modes. Note that options provided here will
+ not be seen when using an external web server such
+ as gunicorn, since this method is not called in that
+ case.
+ """
+ if host is None:
+ host = '127.0.0.1'
+ if port is None:
+ server_name = app.config['SERVER_NAME']
+ if server_name and ':' in server_name:
+ port = int(server_name.rsplit(':', 1)[1])
+ else:
+ port = 5000
+
+ debug = kwargs.pop('debug', app.debug)
+ log_output = kwargs.pop('log_output', debug)
+ use_reloader = kwargs.pop('use_reloader', debug)
+ extra_files = kwargs.pop('extra_files', None)
+
+ app.debug = debug
+ if app.debug and self.server.eio.async_mode != 'threading':
+ # put the debug middleware between the SocketIO middleware
+ # and the Flask application instance
+ #
+ # mw1 mw2 mw3 Flask app
+ # o ---- o ---- o ---- o
+ # /
+ # o Flask-SocketIO
+ # \ middleware
+ # o
+ # Flask-SocketIO WebSocket handler
+ #
+ # BECOMES
+ #
+ # dbg-mw mw1 mw2 mw3 Flask app
+ # o ---- o ---- o ---- o ---- o
+ # /
+ # o Flask-SocketIO
+ # \ middleware
+ # o
+ # Flask-SocketIO WebSocket handler
+ #
+ self.sockio_mw.wsgi_app = DebuggedApplication(self.sockio_mw.wsgi_app,
+ evalex=True)
+
+ if self.server.eio.async_mode == 'threading':
+ from werkzeug._internal import _log
+ _log('warning', 'WebSocket transport not available. Install '
+ 'eventlet or gevent and gevent-websocket for '
+ 'improved performance.')
+ app.run(host=host, port=port, threaded=True,
+ use_reloader=use_reloader, **kwargs)
+ elif self.server.eio.async_mode == 'eventlet':
+ def run_server():
+ import eventlet
+ import eventlet.wsgi
+ import eventlet.green
+ addresses = eventlet.green.socket.getaddrinfo(host, port)
+ if not addresses:
+ raise RuntimeError('Could not resolve host to a valid address')
+ eventlet_socket = eventlet.listen(addresses[0][4], addresses[0][0])
+
+ # if SSL arguments were provided, use an SSL socket
+ ssl_args = ['keyfile', 'certfile', 'server_side', 'cert_reqs',
+ 'ssl_version', 'ca_certs',
+ 'do_handshake_on_connect', 'suppress_ragged_eofs',
+ 'ciphers']
+ ssl_params = {k: kwargs[k] for k in kwargs if k in ssl_args}
+ if len(ssl_params) > 0:
+ for k in ssl_params:
+ kwargs.pop(k)
+ ssl_params['server_side'] = True  # listening sockets require server_side=True
+ eventlet_socket = eventlet.wrap_ssl(eventlet_socket,
+ **ssl_params)
+
+ eventlet.wsgi.server(eventlet_socket, app,
+ log_output=log_output, **kwargs)
+
+ if use_reloader:
+ run_with_reloader(run_server, extra_files=extra_files)
+ else:
+ run_server()
+ elif self.server.eio.async_mode == 'gevent':
+ from gevent import pywsgi
+ try:
+ from geventwebsocket.handler import WebSocketHandler
+ websocket = True
+ except ImportError:
+ websocket = False
+
+ log = 'default'
+ if not log_output:
+ log = None
+ if websocket:
+ self.wsgi_server = pywsgi.WSGIServer(
+ (host, port), app, handler_class=WebSocketHandler,
+ log=log, **kwargs)
+ else:
+ self.wsgi_server = pywsgi.WSGIServer((host, port), app,
+ log=log, **kwargs)
+
+ if use_reloader:
+ # monkey patching is required by the reloader
+ from gevent import monkey
+ monkey.patch_thread()
+ monkey.patch_time()
+
+ def run_server():
+ self.wsgi_server.serve_forever()
+
+ run_with_reloader(run_server, extra_files=extra_files)
+ else:
+ self.wsgi_server.serve_forever()
+
+ def stop(self):
+ """Stop a running SocketIO web server.
+
+ This method must be called from an HTTP or SocketIO handler function.
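+
+ A shutdown sketch (the event name is illustrative)::
+
+     @socketio.on('shutdown')
+     def shutdown():
+         socketio.stop()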
+ """
+ if self.server.eio.async_mode == 'threading':
+ func = flask.request.environ.get('werkzeug.server.shutdown')
+ if func:
+ func()
+ else:
+ raise RuntimeError('Cannot stop unknown web server')
+ elif self.server.eio.async_mode == 'eventlet':
+ raise SystemExit
+ elif self.server.eio.async_mode == 'gevent':
+ self.wsgi_server.stop()
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task using the appropriate async model.
+
+ This is a utility function that applications can use to start a
+ background task using the method that is compatible with the
+ selected async mode.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ This function returns an object compatible with the `Thread` class in
+ the Python standard library. The `start()` method on this object is
+ already called by this function.
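+
+ A usage sketch (the task body is illustrative; ``socketio.sleep()`` keeps
+ the task cooperative under any async mode)::
+
+     def background_ping():
+         while True:
+             socketio.sleep(10)
+             socketio.emit('ping')
+
+     socketio.start_background_task(background_ping)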
+ """
+ return self.server.start_background_task(target, *args, **kwargs)
+
+ def sleep(self, seconds=0):
+ """Sleep for the requested amount of time using the appropriate async
+ model.
+
+ This is a utility function that applications can use to put a task to
+ sleep without having to worry about using the correct call for the
+ selected async mode.
+ """
+ return self.server.sleep(seconds)
+
+ def test_client(self, app, namespace=None, query_string=None,
+ headers=None, flask_test_client=None):
+ """The Socket.IO test client is useful for testing a Flask-SocketIO
+ server. It works in a similar way to the Flask Test Client, but
+ adapted to the Socket.IO server.
+
+ :param app: The Flask application instance.
+ :param namespace: The namespace for the client. If not provided, the
+ client connects to the server on the global
+ namespace.
+ :param query_string: A string with custom query string arguments.
+ :param headers: A dictionary with custom HTTP headers.
+ :param flask_test_client: The instance of the Flask test client
+ currently in use. Passing the Flask test
+ client is optional, but is necessary if you
+ want the Flask user session and any other
+ cookies set in HTTP routes accessible from
+ Socket.IO events.
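+
+ A test sketch (event name and payload are illustrative)::
+
+     client = socketio.test_client(app)
+     client.emit('my event', {'data': 42})
+     received = client.get_received()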
+ """
+ return SocketIOTestClient(app, self, namespace=namespace,
+ query_string=query_string, headers=headers,
+ flask_test_client=flask_test_client)
+
+ def _handle_event(self, handler, message, namespace, sid, *args):
+ if sid not in self.server.environ:
+ # we have no record of this client, so ignore this event
+ return '', 400
+ app = self.server.environ[sid]['flask.app']
+ with app.request_context(self.server.environ[sid]):
+ if self.manage_session:
+ # manage a separate session for this client's Socket.IO events
+ # created as a copy of the regular user session
+ if 'saved_session' not in self.server.environ[sid]:
+ self.server.environ[sid]['saved_session'] = \
+ _ManagedSession(flask.session)
+ session_obj = self.server.environ[sid]['saved_session']
+ else:
+ # let Flask handle the user session
+ # for cookie based sessions, this effectively freezes the
+ # session to its state at connection time
+ # for server-side sessions, this allows HTTP and Socket.IO to
+ # share the session, with both having read/write access to it
+ session_obj = flask.session._get_current_object()
+ _request_ctx_stack.top.session = session_obj
+ flask.request.sid = sid
+ flask.request.namespace = namespace
+ flask.request.event = {'message': message, 'args': args}
+ try:
+ if message == 'connect':
+ ret = handler()
+ else:
+ ret = handler(*args)
+ except:
+ err_handler = self.exception_handlers.get(
+ namespace, self.default_exception_handler)
+ if err_handler is None:
+ raise
+ type, value, traceback = sys.exc_info()
+ return err_handler(value)
+ if not self.manage_session:
+ # when Flask is managing the user session, it needs to save it
+ if not hasattr(session_obj, 'modified') or session_obj.modified:
+ resp = app.response_class()
+ app.session_interface.save_session(app, session_obj, resp)
+ return ret
+
+
+def emit(event, *args, **kwargs):
+ """Emit a SocketIO event.
+
+ This function emits a SocketIO event to one or more connected clients. A
+ JSON blob can be attached to the event as payload. This is a function that
+ can only be called from a SocketIO event handler, as it obtains some
+ information from the current client context. Example::
+
+ @socketio.on('my event')
+ def handle_my_custom_event(json):
+ emit('my response', {'data': 42})
+
+ :param event: The name of the user event to emit.
+ :param args: A dictionary with the JSON data to send as payload.
+ :param namespace: The namespace under which the message is to be sent.
+ Defaults to the namespace used by the originating event.
+ A ``'/'`` can be used to explicitly specify the global
+ namespace.
+ :param callback: Callback function to invoke with the client's
+ acknowledgement.
+ :param broadcast: ``True`` to send the message to all clients, or ``False``
+ to only reply to the sender of the originating event.
+ :param room: Send the message to all the users in the given room. If this
+ argument is set, then broadcast is implied to be ``True``.
+ :param include_self: ``True`` to include the sender when broadcasting or
+ addressing a room, or ``False`` to send to everyone
+ but the sender.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ clients directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used, or when there is a
+ single addressee. It is recommended to always leave
+ this parameter with its default value of ``False``.
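+
+ A broadcast sketch (the event name is illustrative)::
+
+     emit('status update', {'users': 10}, broadcast=True)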
+ """
+ if 'namespace' in kwargs:
+ namespace = kwargs['namespace']
+ else:
+ namespace = flask.request.namespace
+ callback = kwargs.get('callback')
+ broadcast = kwargs.get('broadcast')
+ room = kwargs.get('room')
+ if room is None and not broadcast:
+ room = flask.request.sid
+ include_self = kwargs.get('include_self', True)
+ ignore_queue = kwargs.get('ignore_queue', False)
+
+ socketio = flask.current_app.extensions['socketio']
+ return socketio.emit(event, *args, namespace=namespace, room=room,
+ include_self=include_self, callback=callback,
+ ignore_queue=ignore_queue)
+
+
+def send(message, **kwargs):
+ """Send a SocketIO message.
+
+ This function sends a simple SocketIO message to one or more connected
+ clients. The message can be a string or a JSON blob. This is a simpler
+ version of ``emit()``; ``emit()`` should be preferred in most cases. This is
+ a function that can only be called from a SocketIO event handler.
+
+ :param message: The message to send, either a string or a JSON blob.
+ :param json: ``True`` if ``message`` is a JSON blob, ``False``
+ otherwise.
+ :param namespace: The namespace under which the message is to be sent.
+ Defaults to the namespace used by the originating event.
+ An empty string can be used to select the global namespace.
+ :param callback: Callback function to invoke with the client's
+ acknowledgement.
+ :param broadcast: ``True`` to send the message to all connected clients, or
+ ``False`` to only reply to the sender of the originating
+ event.
+ :param room: Send the message to all the users in the given room.
+ :param include_self: ``True`` to include the sender when broadcasting or
+ addressing a room, or ``False`` to send to everyone
+ but the sender.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ clients directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used, or when there is a
+ single addressee. It is recommended to always leave
+ this parameter with its default value of ``False``.
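+
+ A minimal echo sketch (the handler is illustrative)::
+
+     @socketio.on('message')
+     def handle_message(msg):
+         send('echo: ' + msg)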
+ """
+ json = kwargs.get('json', False)
+ if 'namespace' in kwargs:
+ namespace = kwargs['namespace']
+ else:
+ namespace = flask.request.namespace
+ callback = kwargs.get('callback')
+ broadcast = kwargs.get('broadcast')
+ room = kwargs.get('room')
+ if room is None and not broadcast:
+ room = flask.request.sid
+ include_self = kwargs.get('include_self', True)
+ ignore_queue = kwargs.get('ignore_queue', False)
+
+ socketio = flask.current_app.extensions['socketio']
+ return socketio.send(message, json=json, namespace=namespace, room=room,
+ include_self=include_self, callback=callback,
+ ignore_queue=ignore_queue)
+
+
+def join_room(room, sid=None, namespace=None):
+ """Join a room.
+
+ This function puts the user in a room, under the current namespace. The
+ user and the namespace are obtained from the event context. This is a
+ function that can only be called from a SocketIO event handler. Example::
+
+ @socketio.on('join')
+ def on_join(data):
+ username = session['username']
+ room = data['room']
+ join_room(room)
+ send(username + ' has entered the room.', room=room)
+
+ :param room: The name of the room to join.
+ :param sid: The session id of the client. If not provided, the client is
+ obtained from the request context.
+ :param namespace: The namespace for the room. If not provided, the
+ namespace is obtained from the request context.
+ """
+ socketio = flask.current_app.extensions['socketio']
+ sid = sid or flask.request.sid
+ namespace = namespace or flask.request.namespace
+ socketio.server.enter_room(sid, room, namespace=namespace)
+
+
+def leave_room(room, sid=None, namespace=None):
+ """Leave a room.
+
+ This function removes the user from a room, under the current namespace.
+ The user and the namespace are obtained from the event context. Example::
+
+ @socketio.on('leave')
+ def on_leave(data):
+ username = session['username']
+ room = data['room']
+ leave_room(room)
+ send(username + ' has left the room.', room=room)
+
+ :param room: The name of the room to leave.
+ :param sid: The session id of the client. If not provided, the client is
+ obtained from the request context.
+ :param namespace: The namespace for the room. If not provided, the
+ namespace is obtained from the request context.
+ """
+ socketio = flask.current_app.extensions['socketio']
+ sid = sid or flask.request.sid
+ namespace = namespace or flask.request.namespace
+ socketio.server.leave_room(sid, room, namespace=namespace)
+
+
+def close_room(room, namespace=None):
+ """Close a room.
+
+ This function removes any users that are in the given room and then deletes
+ the room from the server.
+
+ :param room: The name of the room to close.
+ :param namespace: The namespace for the room. If not provided, the
+ namespace is obtained from the request context.
+ """
+ socketio = flask.current_app.extensions['socketio']
+ namespace = namespace or flask.request.namespace
+ socketio.server.close_room(room, namespace=namespace)
+
+
+def rooms(sid=None, namespace=None):
+ """Return a list of the rooms the client is in.
+
+ This function returns all the rooms the client has entered, including its
+ own room, assigned by the Socket.IO server.
+
+ :param sid: The session id of the client. If not provided, the client is
+ obtained from the request context.
+ :param namespace: The namespace for the room. If not provided, the
+ namespace is obtained from the request context.
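+
+ A usage sketch (the event names are illustrative)::
+
+     @socketio.on('which rooms')
+     def which_rooms():
+         emit('room list', {'rooms': rooms()})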
+ """
+ socketio = flask.current_app.extensions['socketio']
+ sid = sid or flask.request.sid
+ namespace = namespace or flask.request.namespace
+ return socketio.server.rooms(sid, namespace=namespace)
+
+
+def disconnect(sid=None, namespace=None, silent=False):
+ """Disconnect the client.
+
+ This function terminates the connection with the client. As a result of
+ this call the client will receive a disconnect event. Example::
+
+ @socketio.on('message')
+ def receive_message(msg):
+ if is_banned(session['username']):
+ disconnect()
+ else:
+ # ...
+
+ :param sid: The session id of the client. If not provided, the client is
+ obtained from the request context.
+ :param namespace: The namespace for the room. If not provided, the
+ namespace is obtained from the request context.
+ :param silent: this option is deprecated.
+ """
+ socketio = flask.current_app.extensions['socketio']
+ sid = sid or flask.request.sid
+ namespace = namespace or flask.request.namespace
+ return socketio.server.disconnect(sid, namespace=namespace)
diff --git a/libs/flask_socketio/namespace.py b/libs/flask_socketio/namespace.py
new file mode 100644
index 000000000..914ff3816
--- /dev/null
+++ b/libs/flask_socketio/namespace.py
@@ -0,0 +1,47 @@
+from socketio import Namespace as _Namespace
+
+
+class Namespace(_Namespace):
+ def __init__(self, namespace=None):
+ super(Namespace, self).__init__(namespace)
+ self.socketio = None
+
+ def _set_socketio(self, socketio):
+ self.socketio = socketio
+
+ def trigger_event(self, event, *args):
+ """Dispatch an event to the proper handler method.
+
+ In the most common usage, this method is not overloaded by subclasses,
+ as it performs the routing of events to methods. However, this
+ method can be overridden if special dispatching rules are needed, or if
+ having a single method that catches all events is desired.
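+
+ For example, an event named ``chat_message`` is routed to a method
+ named ``on_chat_message`` (a sketch; the class and event are
+ illustrative)::
+
+     class ChatNamespace(Namespace):
+         def on_chat_message(self, data):
+             self.emit('chat_message', data)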
+ """
+ handler_name = 'on_' + event
+ if not hasattr(self, handler_name):
+ # there is no handler for this event, so we ignore it
+ return
+ handler = getattr(self, handler_name)
+ return self.socketio._handle_event(handler, event, self.namespace,
+ *args)
+
+ def emit(self, event, data=None, room=None, include_self=True,
+ namespace=None, callback=None):
+ """Emit a custom event to one or more connected clients."""
+ return self.socketio.emit(event, data, room=room,
+ include_self=include_self,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ def send(self, data, room=None, include_self=True, namespace=None,
+ callback=None):
+ """Send a message to one or more connected clients."""
+ return self.socketio.send(data, room=room, include_self=include_self,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ def close_room(self, room, namespace=None):
+ """Close a room."""
+ return self.socketio.close_room(room=room,
+ namespace=namespace or self.namespace)
+
diff --git a/libs/flask_socketio/test_client.py b/libs/flask_socketio/test_client.py
new file mode 100644
index 000000000..0c4592034
--- /dev/null
+++ b/libs/flask_socketio/test_client.py
@@ -0,0 +1,205 @@
+import uuid
+
+from socketio import packet
+from socketio.pubsub_manager import PubSubManager
+from werkzeug.test import EnvironBuilder
+
+
+class SocketIOTestClient(object):
+ """
+ This class is useful for testing a Flask-SocketIO server. It works in a
+ similar way to the Flask Test Client, but adapted to the Socket.IO server.
+
+ :param app: The Flask application instance.
+ :param socketio: The application's ``SocketIO`` instance.
+ :param namespace: The namespace for the client. If not provided, the client
+ connects to the server on the global namespace.
+ :param query_string: A string with custom query string arguments.
+ :param headers: A dictionary with custom HTTP headers.
+ :param flask_test_client: The instance of the Flask test client
+ currently in use. Passing the Flask test
+ client is optional, but is necessary if you
+ want the Flask user session and any other
+ cookies set in HTTP routes accessible from
+ Socket.IO events.
+ """
+ queue = {}
+ acks = {}
+
+ def __init__(self, app, socketio, namespace=None, query_string=None,
+ headers=None, flask_test_client=None):
+ def _mock_send_packet(sid, pkt):
+ if pkt.packet_type == packet.EVENT or \
+ pkt.packet_type == packet.BINARY_EVENT:
+ if sid not in self.queue:
+ self.queue[sid] = []
+ if pkt.data[0] == 'message' or pkt.data[0] == 'json':
+ self.queue[sid].append({'name': pkt.data[0],
+ 'args': pkt.data[1],
+ 'namespace': pkt.namespace or '/'})
+ else:
+ self.queue[sid].append({'name': pkt.data[0],
+ 'args': pkt.data[1:],
+ 'namespace': pkt.namespace or '/'})
+ elif pkt.packet_type == packet.ACK or \
+ pkt.packet_type == packet.BINARY_ACK:
+ self.acks[sid] = {'args': pkt.data,
+ 'namespace': pkt.namespace or '/'}
+ elif pkt.packet_type == packet.DISCONNECT:
+ self.connected[pkt.namespace or '/'] = False
+
+ self.app = app
+ self.flask_test_client = flask_test_client
+ self.sid = uuid.uuid4().hex
+ self.queue[self.sid] = []
+ self.acks[self.sid] = None
+ self.callback_counter = 0
+ self.socketio = socketio
+ self.connected = {}
+ socketio.server._send_packet = _mock_send_packet
+ socketio.server.environ[self.sid] = {}
+ socketio.server.async_handlers = False # easier to test when
+ socketio.server.eio.async_handlers = False # events are sync
+ if isinstance(socketio.server.manager, PubSubManager):
+ raise RuntimeError('Test client cannot be used with a message '
+ 'queue. Disable the queue on your test '
+ 'configuration.')
+ socketio.server.manager.initialize()
+ self.connect(namespace=namespace, query_string=query_string,
+ headers=headers)
+
+ def is_connected(self, namespace=None):
+ """Check if a namespace is connected.
+
+ :param namespace: The namespace to check. The global namespace is
+ assumed if this argument is not provided.
+ """
+ return self.connected.get(namespace or '/', False)
+
+ def connect(self, namespace=None, query_string=None, headers=None):
+ """Connect the client.
+
+ :param namespace: The namespace for the client. If not provided, the
+ client connects to the server on the global
+ namespace.
+ :param query_string: A string with custom query string arguments.
+ :param headers: A dictionary with custom HTTP headers.
+
+ Note that it is usually not necessary to explicitly call this method,
+ since a connection is automatically established when an instance of
+ this class is created. An example where this method would be useful
+ is when the application accepts multiple namespace connections.
+ """
+ url = '/socket.io'
+ if query_string:
+ if query_string[0] != '?':
+ query_string = '?' + query_string
+ url += query_string
+ environ = EnvironBuilder(url, headers=headers).get_environ()
+ environ['flask.app'] = self.app
+ if self.flask_test_client:
+ # inject cookies from Flask
+ self.flask_test_client.cookie_jar.inject_wsgi(environ)
+ self.connected['/'] = True
+ if self.socketio.server._handle_eio_connect(
+ self.sid, environ) is False:
+ del self.connected['/']
+ if namespace is not None and namespace != '/':
+ self.connected[namespace] = True
+ pkt = packet.Packet(packet.CONNECT, namespace=namespace)
+ with self.app.app_context():
+ if self.socketio.server._handle_eio_message(
+ self.sid, pkt.encode()) is False:
+ del self.connected[namespace]
+
+ def disconnect(self, namespace=None):
+ """Disconnect the client.
+
+ :param namespace: The namespace to disconnect. The global namespace is
+ assumed if this argument is not provided.
+ """
+ if not self.is_connected(namespace):
+ raise RuntimeError('not connected')
+ pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
+ with self.app.app_context():
+ self.socketio.server._handle_eio_message(self.sid, pkt.encode())
+ del self.connected[namespace or '/']
+
+ def emit(self, event, *args, **kwargs):
+ """Emit an event to the server.
+
+ :param event: The event name.
+ :param *args: The event arguments.
+ :param callback: ``True`` if the client requests a callback, ``False``
+ if not. Note that client-side callbacks are not
+ implemented, a callback request will just tell the
+ server to provide the arguments to invoke the
+ callback, but no callback is invoked. Instead, the
+ arguments that the server provided for the callback
+ are returned by this function.
+ :param namespace: The namespace of the event. The global namespace is
+ assumed if this argument is not provided.
+ """
+ namespace = kwargs.pop('namespace', None)
+ if not self.is_connected(namespace):
+ raise RuntimeError('not connected')
+ callback = kwargs.pop('callback', False)
+ id = None
+ if callback:
+ self.callback_counter += 1
+ id = self.callback_counter
+ pkt = packet.Packet(packet.EVENT, data=[event] + list(args),
+ namespace=namespace, id=id)
+ with self.app.app_context():
+ encoded_pkt = pkt.encode()
+ if isinstance(encoded_pkt, list):
+ for epkt in encoded_pkt:
+ self.socketio.server._handle_eio_message(self.sid, epkt)
+ else:
+ self.socketio.server._handle_eio_message(self.sid, encoded_pkt)
+ ack = self.acks.pop(self.sid, None)
+ if ack is not None:
+ return ack['args'][0] if len(ack['args']) == 1 \
+ else ack['args']
+
+ def send(self, data, json=False, callback=False, namespace=None):
+ """Send a text or JSON message to the server.
+
+ :param data: A string, dictionary or list to send to the server.
+ :param json: ``True`` to send a JSON message, ``False`` to send a text
+ message.
+ :param callback: ``True`` if the client requests a callback, ``False``
+ if not. Note that client-side callbacks are not
+ implemented, a callback request will just tell the
+ server to provide the arguments to invoke the
+ callback, but no callback is invoked. Instead, the
+ arguments that the server provided for the callback
+ are returned by this function.
+ :param namespace: The namespace of the event. The global namespace is
+ assumed if this argument is not provided.
+ """
+ if json:
+ msg = 'json'
+ else:
+ msg = 'message'
+ return self.emit(msg, data, callback=callback, namespace=namespace)
+
+ def get_received(self, namespace=None):
+ """Return the list of messages received from the server.
+
+ Since this is not a real client, any time the server emits an event,
+ the event is simply stored. The test code can invoke this method to
+ obtain the list of events that were received since the last call.
+
+ :param namespace: The namespace to get events from. The global
+ namespace is assumed if this argument is not
+ provided.
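+
+ A test sketch (event names are illustrative; it assumes the server
+ answers ``'my event'`` with a ``'my response'`` event)::
+
+     client.emit('my event', {'data': 1})
+     received = client.get_received()
+     assert received[0]['name'] == 'my response'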
+ """
+ if not self.is_connected(namespace):
+ raise RuntimeError('not connected')
+ namespace = namespace or '/'
+ r = [pkt for pkt in self.queue[self.sid]
+ if pkt['namespace'] == namespace]
+ self.queue[self.sid] = [pkt for pkt in self.queue[self.sid]
+ if pkt not in r]
+ return r
diff --git a/libs/socketio/__init__.py b/libs/socketio/__init__.py
new file mode 100644
index 000000000..d3ee7242b
--- /dev/null
+++ b/libs/socketio/__init__.py
@@ -0,0 +1,38 @@
+import sys
+
+from .client import Client
+from .base_manager import BaseManager
+from .pubsub_manager import PubSubManager
+from .kombu_manager import KombuManager
+from .redis_manager import RedisManager
+from .kafka_manager import KafkaManager
+from .zmq_manager import ZmqManager
+from .server import Server
+from .namespace import Namespace, ClientNamespace
+from .middleware import WSGIApp, Middleware
+from .tornado import get_tornado_handler
+if sys.version_info >= (3, 5): # pragma: no cover
+ from .asyncio_client import AsyncClient
+ from .asyncio_server import AsyncServer
+ from .asyncio_manager import AsyncManager
+ from .asyncio_namespace import AsyncNamespace, AsyncClientNamespace
+ from .asyncio_redis_manager import AsyncRedisManager
+ from .asyncio_aiopika_manager import AsyncAioPikaManager
+ from .asgi import ASGIApp
+else: # pragma: no cover
+ AsyncClient = None
+ AsyncServer = None
+ AsyncManager = None
+ AsyncNamespace = None
+ AsyncRedisManager = None
+ AsyncAioPikaManager = None
+
+__version__ = '4.4.0'
+
+__all__ = ['__version__', 'Client', 'Server', 'BaseManager', 'PubSubManager',
+ 'KombuManager', 'RedisManager', 'ZmqManager', 'KafkaManager',
+ 'Namespace', 'ClientNamespace', 'WSGIApp', 'Middleware']
+if AsyncServer is not None: # pragma: no cover
+ __all__ += ['AsyncClient', 'AsyncServer', 'AsyncNamespace',
+ 'AsyncClientNamespace', 'AsyncManager', 'AsyncRedisManager',
+ 'ASGIApp', 'get_tornado_handler', 'AsyncAioPikaManager']
diff --git a/libs/socketio/asgi.py b/libs/socketio/asgi.py
new file mode 100644
index 000000000..9bcdd03ba
--- /dev/null
+++ b/libs/socketio/asgi.py
@@ -0,0 +1,36 @@
+import engineio
+
+
+class ASGIApp(engineio.ASGIApp): # pragma: no cover
+ """ASGI application middleware for Socket.IO.
+
+ This middleware dispatches traffic to a Socket.IO application. It can
+ also serve a list of static files to the client, or forward unrelated
+ HTTP traffic to another ASGI application.
+
+ :param socketio_server: The Socket.IO server. Must be an instance of the
+ ``socketio.AsyncServer`` class.
+ :param static_files: A dictionary with static file mapping rules. See the
+ documentation for details on this argument.
+ :param other_asgi_app: A separate ASGI app that receives all other traffic.
+ :param socketio_path: The endpoint where the Socket.IO application should
+ be installed. The default value is appropriate for
+ most cases.
+
+ Example usage::
+
+ import socketio
+ import uvicorn
+
+ sio = socketio.AsyncServer()
+ app = socketio.ASGIApp(sio, static_files={
+ '/': 'index.html',
+ '/static': './public',
+ })
+ uvicorn.run(app, host='127.0.0.1', port=5000)
+ """
+ def __init__(self, socketio_server, other_asgi_app=None,
+ static_files=None, socketio_path='socket.io'):
+ super().__init__(socketio_server, other_asgi_app,
+ static_files=static_files,
+ engineio_path=socketio_path)
diff --git a/libs/socketio/asyncio_aiopika_manager.py b/libs/socketio/asyncio_aiopika_manager.py
new file mode 100644
index 000000000..b20d6afd9
--- /dev/null
+++ b/libs/socketio/asyncio_aiopika_manager.py
@@ -0,0 +1,105 @@
+import asyncio
+import pickle
+
+from socketio.asyncio_pubsub_manager import AsyncPubSubManager
+
+try:
+ import aio_pika
+except ImportError:
+ aio_pika = None
+
+
+class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover
+ """Client manager that uses aio_pika for inter-process messaging under
+ asyncio.
+
+ This class implements a client manager backend for event sharing across
+ multiple processes, using RabbitMQ.
+
+ To use an aio_pika backend, initialize the :class:`AsyncServer` instance as
+ follows::
+
+ url = 'amqp://user:password@hostname:port//'
+ server = socketio.AsyncServer(client_manager=socketio.AsyncAioPikaManager(
+ url))
+
+ :param url: The connection URL for the backend messaging queue. Example
+ connection URLs are ``'amqp://guest:guest@localhost:5672//'``
+ for RabbitMQ.
+ :param channel: The channel name on which the server sends and receives
+ notifications. Must be the same in all the servers.
+ With this manager, the channel name is the exchange name
+ in RabbitMQ.
+ :param write_only: If set to ``True``, only initialize to emit events. The
+ default of ``False`` initializes the class for emitting
+ and receiving.
+ """
+
+ name = 'asyncaiopika'
+
+ def __init__(self, url='amqp://guest:guest@localhost:5672//',
+ channel='socketio', write_only=False, logger=None):
+ if aio_pika is None:
+ raise RuntimeError('aio_pika package is not installed '
+ '(Run "pip install aio_pika" in your '
+ 'virtualenv).')
+ self.url = url
+ self.listener_connection = None
+ self.listener_channel = None
+ self.listener_queue = None
+ super().__init__(channel=channel, write_only=write_only, logger=logger)
+
+ async def _connection(self):
+ return await aio_pika.connect_robust(self.url)
+
+ async def _channel(self, connection):
+ return await connection.channel()
+
+ async def _exchange(self, channel):
+ return await channel.declare_exchange(self.channel,
+ aio_pika.ExchangeType.FANOUT)
+
+ async def _queue(self, channel, exchange):
+ queue = await channel.declare_queue(durable=False,
+ arguments={'x-expires': 300000})
+ await queue.bind(exchange)
+ return queue
+
+ async def _publish(self, data):
+ connection = await self._connection()
+ channel = await self._channel(connection)
+ exchange = await self._exchange(channel)
+ await exchange.publish(
+ aio_pika.Message(body=pickle.dumps(data),
+ delivery_mode=aio_pika.DeliveryMode.PERSISTENT),
+ routing_key='*'
+ )
+
+ async def _listen(self):
+ retry_sleep = 1
+ while True:
+ try:
+ if self.listener_connection is None:
+ self.listener_connection = await self._connection()
+ self.listener_channel = await self._channel(
+ self.listener_connection
+ )
+ await self.listener_channel.set_qos(prefetch_count=1)
+ exchange = await self._exchange(self.listener_channel)
+ self.listener_queue = await self._queue(
+ self.listener_channel, exchange
+ )
+
+ async with self.listener_queue.iterator() as queue_iter:
+ async for message in queue_iter:
+ with message.process():
+ return pickle.loads(message.body)
+ except Exception:
+ self._get_logger().error('Cannot receive from rabbitmq... '
+ 'retrying in '
+ '{} secs'.format(retry_sleep))
+ self.listener_connection = None
+ await asyncio.sleep(retry_sleep)
+ retry_sleep *= 2
+ if retry_sleep > 60:
+ retry_sleep = 60
diff --git a/libs/socketio/asyncio_client.py b/libs/socketio/asyncio_client.py
new file mode 100644
index 000000000..2b10434ae
--- /dev/null
+++ b/libs/socketio/asyncio_client.py
@@ -0,0 +1,475 @@
+import asyncio
+import logging
+import random
+
+import engineio
+import six
+
+from . import client
+from . import exceptions
+from . import packet
+
+default_logger = logging.getLogger('socketio.client')
+
+
+class AsyncClient(client.Client):
+ """A Socket.IO client for asyncio.
+
+ This class implements a fully compliant Socket.IO web client with support
+ for websocket and long-polling transports.
+
+ :param reconnection: ``True`` if the client should automatically attempt to
+ reconnect to the server after an interruption, or
+ ``False`` to not reconnect. The default is ``True``.
+ :param reconnection_attempts: How many reconnection attempts to issue
+ before giving up, or 0 for unlimited attempts.
+ The default is 0.
+ :param reconnection_delay: How long to wait in seconds before the first
+ reconnection attempt. Each successive attempt
+ doubles this delay.
+ :param reconnection_delay_max: The maximum delay between reconnection
+ attempts.
+ :param randomization_factor: Randomization amount for each delay between
+ reconnection attempts. The default is 0.5,
+ which means that each delay is randomly
+ adjusted by +/- 50%.
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``. The default is
+ ``False``.
+ :param binary: ``True`` to support binary payloads, ``False`` to treat all
+ payloads as text. On Python 2, if this is set to ``True``,
+ ``unicode`` values are treated as text, and ``str`` and
+ ``bytes`` values are treated as binary. This option has no
+ effect on Python 3, where text and binary payloads are
+ always automatically discovered.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+
+ The Engine.IO configuration supports the following settings:
+
+ :param request_timeout: A timeout in seconds for requests. The default is
+ 5 seconds.
+ :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
+ skip SSL certificate verification, allowing
+ connections to servers with self-signed certificates.
+ The default is ``True``.
+ :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
+ a logger object to use. To disable logging set to
+ ``False``. The default is ``False``.
+ """
+ def is_asyncio_based(self):
+ return True
+
+ async def connect(self, url, headers={}, transports=None,
+ namespaces=None, socketio_path='socket.io'):
+ """Connect to a Socket.IO server.
+
+ :param url: The URL of the Socket.IO server. It can include custom
+ query string parameters if required by the server.
+ :param headers: A dictionary with custom headers to send with the
+ connection request.
+ :param transports: The list of allowed transports. Valid transports
+ are ``'polling'`` and ``'websocket'``. If not
+ given, the polling transport is connected first,
+ then an upgrade to websocket is attempted.
+ :param namespaces: The list of custom namespaces to connect, in
+ addition to the default namespace. If not given,
+ the namespace list is obtained from the registered
+ event handlers.
+ :param socketio_path: The endpoint where the Socket.IO server is
+ installed. The default value is appropriate for
+ most cases.
+
+ Note: this method is a coroutine.
+
+ Example usage::
+
+ sio = socketio.Client()
+ sio.connect('http://localhost:5000')
+ """
+ self.connection_url = url
+ self.connection_headers = headers
+ self.connection_transports = transports
+ self.connection_namespaces = namespaces
+ self.socketio_path = socketio_path
+
+ if namespaces is None:
+ namespaces = set(self.handlers.keys()).union(
+ set(self.namespace_handlers.keys()))
+ elif isinstance(namespaces, six.string_types):
+ namespaces = [namespaces]
+ self.connection_namespaces = namespaces
+ self.namespaces = [n for n in namespaces if n != '/']
+ try:
+ await self.eio.connect(url, headers=headers,
+ transports=transports,
+ engineio_path=socketio_path)
+ except engineio.exceptions.ConnectionError as exc:
+ six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
+ self.connected = True
+
+ async def wait(self):
+ """Wait until the connection with the server ends.
+
+ Client applications can use this function to block the main thread
+ during the life of the connection.
+
+ Note: this method is a coroutine.
+ """
+ while True:
+ await self.eio.wait()
+ await self.sleep(1) # give the reconnect task time to start up
+ if not self._reconnect_task:
+ break
+ await self._reconnect_task
+ if self.eio.state != 'connected':
+ break
+
+ async def emit(self, event, data=None, namespace=None, callback=None):
+ """Emit a custom event to one or more connected clients.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param callback: If given, this function will be called to acknowledge
+ that the client has received the message. The arguments
+ that will be passed to the function are those provided
+ by the client. Callback functions can only be used
+ when addressing an individual client.
+
+ Note: this method is a coroutine.
+ """
+ namespace = namespace or '/'
+ if namespace != '/' and namespace not in self.namespaces:
+ raise exceptions.BadNamespaceError(
+ namespace + ' is not a connected namespace.')
+ self.logger.info('Emitting event "%s" [%s]', event, namespace)
+ if callback is not None:
+ id = self._generate_ack_id(namespace, callback)
+ else:
+ id = None
+ if six.PY2 and not self.binary:
+ binary = False # pragma: nocover
+ else:
+ binary = None
+ # tuples are expanded to multiple arguments, everything else is sent
+ # as a single argument
+ if isinstance(data, tuple):
+ data = list(data)
+ elif data is not None:
+ data = [data]
+ else:
+ data = []
+ await self._send_packet(packet.Packet(
+ packet.EVENT, namespace=namespace, data=[event] + data, id=id,
+ binary=binary))
+
+ async def send(self, data, namespace=None, callback=None):
+ """Send a message to one or more connected clients.
+
+ This function emits an event with the name ``'message'``. Use
+ :func:`emit` to issue custom event names.
+
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param callback: If given, this function will be called to acknowledge
+ that the client has received the message. The arguments
+ that will be passed to the function are those provided
+ by the client. Callback functions can only be used
+ when addressing an individual client.
+
+ Note: this method is a coroutine.
+ """
+ await self.emit('message', data=data, namespace=namespace,
+ callback=callback)
+
+ async def call(self, event, data=None, namespace=None, timeout=60):
+ """Emit a custom event to a client and wait for the response.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param timeout: The waiting timeout. If the timeout is reached before
+ the client acknowledges the event, then a
+ ``TimeoutError`` exception is raised.
+
+ Note: this method is a coroutine.
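+
+ A usage sketch (``sio`` is an ``AsyncClient`` instance and the event
+ name is illustrative)::
+
+     response = await sio.call('get status', timeout=5)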
+ """
+ callback_event = self.eio.create_event()
+ callback_args = []
+
+ def event_callback(*args):
+ callback_args.append(args)
+ callback_event.set()
+
+ await self.emit(event, data=data, namespace=namespace,
+ callback=event_callback)
+ try:
+ await asyncio.wait_for(callback_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ six.raise_from(exceptions.TimeoutError(), None)
+ return callback_args[0] if len(callback_args[0]) > 1 \
+ else callback_args[0][0] if len(callback_args[0]) == 1 \
+ else None
+
+ async def disconnect(self):
+ """Disconnect from the server.
+
+ Note: this method is a coroutine.
+ """
+ # here we just request the disconnection
+ # later in _handle_eio_disconnect we invoke the disconnect handler
+ for n in self.namespaces:
+ await self._send_packet(packet.Packet(packet.DISCONNECT,
+ namespace=n))
+ await self._send_packet(packet.Packet(
+ packet.DISCONNECT, namespace='/'))
+ self.connected = False
+ await self.eio.disconnect(abort=True)
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task using the appropriate async model.
+
+ This is a utility function that applications can use to start a
+ background task using the method that is compatible with the
+ selected async mode.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ This function returns an object compatible with the `Thread` class in
+ the Python standard library. The `start()` method on this object is
+ already called by this function.
+ """
+ return self.eio.start_background_task(target, *args, **kwargs)
+
+ async def sleep(self, seconds=0):
+ """Sleep for the requested amount of time using the appropriate async
+ model.
+
+ This is a utility function that applications can use to put a task to
+ sleep without having to worry about using the correct call for the
+ selected async mode.
+
+ Note: this method is a coroutine.
+ """
+ return await self.eio.sleep(seconds)
+
+ async def _send_packet(self, pkt):
+ """Send a Socket.IO packet to the server."""
+ encoded_packet = pkt.encode()
+ if isinstance(encoded_packet, list):
+ binary = False
+ for ep in encoded_packet:
+ await self.eio.send(ep, binary=binary)
+ binary = True
+ else:
+ await self.eio.send(encoded_packet, binary=False)
+
+ async def _handle_connect(self, namespace):
+ namespace = namespace or '/'
+ self.logger.info('Namespace {} is connected'.format(namespace))
+ await self._trigger_event('connect', namespace=namespace)
+ if namespace == '/':
+ for n in self.namespaces:
+ await self._send_packet(packet.Packet(packet.CONNECT,
+ namespace=n))
+ elif namespace not in self.namespaces:
+ self.namespaces.append(namespace)
+
+ async def _handle_disconnect(self, namespace):
+ if not self.connected:
+ return
+ namespace = namespace or '/'
+ if namespace == '/':
+ for n in self.namespaces:
+ await self._trigger_event('disconnect', namespace=n)
+ self.namespaces = []
+ await self._trigger_event('disconnect', namespace=namespace)
+ if namespace in self.namespaces:
+ self.namespaces.remove(namespace)
+ if namespace == '/':
+ self.connected = False
+
+ async def _handle_event(self, namespace, id, data):
+ namespace = namespace or '/'
+ self.logger.info('Received event "%s" [%s]', data[0], namespace)
+ r = await self._trigger_event(data[0], namespace, *data[1:])
+ if id is not None:
+ # send ACK packet with the response returned by the handler
+ # tuples are expanded as multiple arguments
+ if r is None:
+ data = []
+ elif isinstance(r, tuple):
+ data = list(r)
+ else:
+ data = [r]
+ if six.PY2 and not self.binary:
+ binary = False # pragma: nocover
+ else:
+ binary = None
+ await self._send_packet(packet.Packet(
+ packet.ACK, namespace=namespace, id=id, data=data,
+ binary=binary))
+
+ async def _handle_ack(self, namespace, id, data):
+ namespace = namespace or '/'
+ self.logger.info('Received ack [%s]', namespace)
+ callback = None
+ try:
+ callback = self.callbacks[namespace][id]
+ except KeyError:
+ # if we get an unknown callback we just ignore it
+ self.logger.warning('Unknown callback received, ignoring.')
+ else:
+ del self.callbacks[namespace][id]
+ if callback is not None:
+ if asyncio.iscoroutinefunction(callback):
+ await callback(*data)
+ else:
+ callback(*data)
+
+ async def _handle_error(self, namespace, data):
+ namespace = namespace or '/'
+ self.logger.info('Connection to namespace {} was rejected'.format(
+ namespace))
+ if data is None:
+ data = tuple()
+ elif not isinstance(data, (tuple, list)):
+ data = (data,)
+ await self._trigger_event('connect_error', namespace, *data)
+ if namespace in self.namespaces:
+ self.namespaces.remove(namespace)
+ if namespace == '/':
+ self.namespaces = []
+ self.connected = False
+
+ async def _trigger_event(self, event, namespace, *args):
+ """Invoke an application event handler."""
+ # first see if we have an explicit handler for the event
+ if namespace in self.handlers and event in self.handlers[namespace]:
+ if asyncio.iscoroutinefunction(self.handlers[namespace][event]):
+ try:
+ ret = await self.handlers[namespace][event](*args)
+ except asyncio.CancelledError: # pragma: no cover
+ ret = None
+ else:
+ ret = self.handlers[namespace][event](*args)
+ return ret
+
+ # or else, forward the event to a namespace handler if one exists
+ elif namespace in self.namespace_handlers:
+ return await self.namespace_handlers[namespace].trigger_event(
+ event, *args)
+
+ async def _handle_reconnect(self):
+ self._reconnect_abort.clear()
+ client.reconnecting_clients.append(self)
+ attempt_count = 0
+ current_delay = self.reconnection_delay
+ while True:
+ delay = current_delay
+ current_delay *= 2
+ if delay > self.reconnection_delay_max:
+ delay = self.reconnection_delay_max
+ delay += self.randomization_factor * (2 * random.random() - 1)
+ self.logger.info(
+ 'Connection failed, new attempt in {:.02f} seconds'.format(
+ delay))
+ try:
+ await asyncio.wait_for(self._reconnect_abort.wait(), delay)
+ self.logger.info('Reconnect task aborted')
+ break
+ except (asyncio.TimeoutError, asyncio.CancelledError):
+ pass
+ attempt_count += 1
+ try:
+ await self.connect(self.connection_url,
+ headers=self.connection_headers,
+ transports=self.connection_transports,
+ namespaces=self.connection_namespaces,
+ socketio_path=self.socketio_path)
+ except (exceptions.ConnectionError, ValueError):
+ pass
+ else:
+ self.logger.info('Reconnection successful')
+ self._reconnect_task = None
+ break
+ if self.reconnection_attempts and \
+ attempt_count >= self.reconnection_attempts:
+ self.logger.info(
+ 'Maximum reconnection attempts reached, giving up')
+ break
+ client.reconnecting_clients.remove(self)
+
+ def _handle_eio_connect(self):
+ """Handle the Engine.IO connection event."""
+ self.logger.info('Engine.IO connection established')
+ self.sid = self.eio.sid
+
+ async def _handle_eio_message(self, data):
+ """Dispatch Engine.IO messages."""
+ if self._binary_packet:
+ pkt = self._binary_packet
+ if pkt.add_attachment(data):
+ self._binary_packet = None
+ if pkt.packet_type == packet.BINARY_EVENT:
+ await self._handle_event(pkt.namespace, pkt.id, pkt.data)
+ else:
+ await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
+ else:
+ pkt = packet.Packet(encoded_packet=data)
+ if pkt.packet_type == packet.CONNECT:
+ await self._handle_connect(pkt.namespace)
+ elif pkt.packet_type == packet.DISCONNECT:
+ await self._handle_disconnect(pkt.namespace)
+ elif pkt.packet_type == packet.EVENT:
+ await self._handle_event(pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.ACK:
+ await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.BINARY_EVENT or \
+ pkt.packet_type == packet.BINARY_ACK:
+ self._binary_packet = pkt
+ elif pkt.packet_type == packet.ERROR:
+ await self._handle_error(pkt.namespace, pkt.data)
+ else:
+ raise ValueError('Unknown packet type.')
+
+ async def _handle_eio_disconnect(self):
+ """Handle the Engine.IO disconnection event."""
+ self.logger.info('Engine.IO connection dropped')
+ self._reconnect_abort.set()
+ if self.connected:
+ for n in self.namespaces:
+ await self._trigger_event('disconnect', namespace=n)
+ await self._trigger_event('disconnect', namespace='/')
+ self.namespaces = []
+ self.connected = False
+ self.callbacks = {}
+ self._binary_packet = None
+ self.sid = None
+ if self.eio.state == 'connected' and self.reconnection:
+ self._reconnect_task = self.start_background_task(
+ self._handle_reconnect)
+
+ def _engineio_client_class(self):
+ return engineio.AsyncClient
diff --git a/libs/socketio/asyncio_manager.py b/libs/socketio/asyncio_manager.py
new file mode 100644
index 000000000..f4496ec7f
--- /dev/null
+++ b/libs/socketio/asyncio_manager.py
@@ -0,0 +1,58 @@
+import asyncio
+
+from .base_manager import BaseManager
+
+
+class AsyncManager(BaseManager):
+ """Manage a client list for an asyncio server."""
+ async def emit(self, event, data, namespace, room=None, skip_sid=None,
+ callback=None, **kwargs):
+ """Emit a message to a single client, a room, or all the clients
+ connected to the namespace.
+
+ Note: this method is a coroutine.
+ """
+ if namespace not in self.rooms or room not in self.rooms[namespace]:
+ return
+ tasks = []
+ if not isinstance(skip_sid, list):
+ skip_sid = [skip_sid]
+ for sid in self.get_participants(namespace, room):
+ if sid not in skip_sid:
+ if callback is not None:
+ id = self._generate_ack_id(sid, namespace, callback)
+ else:
+ id = None
+ tasks.append(self.server._emit_internal(sid, event, data,
+ namespace, id))
+ if tasks == []: # pragma: no cover
+ return
+ await asyncio.wait(tasks)
+
+ async def close_room(self, room, namespace):
+ """Remove all participants from a room.
+
+ Note: this method is a coroutine.
+ """
+ return super().close_room(room, namespace)
+
+ async def trigger_callback(self, sid, namespace, id, data):
+ """Invoke an application callback.
+
+ Note: this method is a coroutine.
+ """
+ callback = None
+ try:
+ callback = self.callbacks[sid][namespace][id]
+ except KeyError:
+ # if we get an unknown callback we just ignore it
+ self._get_logger().warning('Unknown callback received, ignoring.')
+ else:
+ del self.callbacks[sid][namespace][id]
+ if callback is not None:
+ ret = callback(*data)
+ if asyncio.iscoroutine(ret):
+ try:
+ await ret
+ except asyncio.CancelledError: # pragma: no cover
+ pass
diff --git a/libs/socketio/asyncio_namespace.py b/libs/socketio/asyncio_namespace.py
new file mode 100644
index 000000000..12e9c0fe6
--- /dev/null
+++ b/libs/socketio/asyncio_namespace.py
@@ -0,0 +1,204 @@
+import asyncio
+
+from socketio import namespace
+
+
+class AsyncNamespace(namespace.Namespace):
+ """Base class for asyncio server-side class-based namespaces.
+
+ A class-based namespace is a class that contains all the event handlers
+ for a Socket.IO namespace. The event handlers are methods of the class
+ with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
+ ``on_message``, ``on_json``, and so on. These can be regular functions or
+ coroutines.
+
+ :param namespace: The Socket.IO namespace to be used with all the event
+ handlers defined in this class. If this argument is
+ omitted, the default namespace is used.
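+
+ A minimal sketch (the event handler is illustrative)::
+
+     class MyNamespace(AsyncNamespace):
+         async def on_my_event(self, sid, data):
+             await self.emit('my response', data, room=sid)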
+ """
+ def is_asyncio_based(self):
+ return True
+
+ async def trigger_event(self, event, *args):
+ """Dispatch an event to the proper handler method.
+
+ In the most common usage, this method is not overloaded by subclasses,
+ as it performs the routing of events to methods. However, this
+ method can be overridden if special dispatching rules are needed, or if
+ having a single method that catches all events is desired.
+
+ Note: this method is a coroutine.
+ """
+ handler_name = 'on_' + event
+ if hasattr(self, handler_name):
+ handler = getattr(self, handler_name)
+ if asyncio.iscoroutinefunction(handler) is True:
+ try:
+ ret = await handler(*args)
+ except asyncio.CancelledError: # pragma: no cover
+ ret = None
+ else:
+ ret = handler(*args)
+ return ret
+
+ async def emit(self, event, data=None, room=None, skip_sid=None,
+ namespace=None, callback=None):
+ """Emit a custom event to one or more connected clients.
+
+ The only difference with the :func:`socketio.Server.emit` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.server.emit(event, data=data, room=room,
+ skip_sid=skip_sid,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ async def send(self, data, room=None, skip_sid=None, namespace=None,
+ callback=None):
+ """Send a message to one or more connected clients.
+
+ The only difference with the :func:`socketio.Server.send` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.server.send(data, room=room, skip_sid=skip_sid,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ async def close_room(self, room, namespace=None):
+ """Close a room.
+
+ The only difference with the :func:`socketio.Server.close_room` method
+ is that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.server.close_room(
+ room, namespace=namespace or self.namespace)
+
+ async def get_session(self, sid, namespace=None):
+ """Return the user session for a client.
+
+ The only difference with the :func:`socketio.Server.get_session`
+ method is that when the ``namespace`` argument is not given the
+ namespace associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.server.get_session(
+ sid, namespace=namespace or self.namespace)
+
+ async def save_session(self, sid, session, namespace=None):
+ """Store the user session for a client.
+
+ The only difference with the :func:`socketio.Server.save_session`
+ method is that when the ``namespace`` argument is not given the
+ namespace associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.server.save_session(
+ sid, session, namespace=namespace or self.namespace)
+
+ def session(self, sid, namespace=None):
+ """Return the user session for a client with context manager syntax.
+
+ The only difference with the :func:`socketio.Server.session` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.session(sid, namespace=namespace or self.namespace)
+
+ async def disconnect(self, sid, namespace=None):
+ """Disconnect a client.
+
+ The only difference with the :func:`socketio.Server.disconnect` method
+ is that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.server.disconnect(
+ sid, namespace=namespace or self.namespace)
+
+
+class AsyncClientNamespace(namespace.ClientNamespace):
+ """Base class for asyncio client-side class-based namespaces.
+
+ A class-based namespace is a class that contains all the event handlers
+ for a Socket.IO namespace. The event handlers are methods of the class
+ with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
+ ``on_message``, ``on_json``, and so on. These can be regular functions or
+ coroutines.
+
+ :param namespace: The Socket.IO namespace to be used with all the event
+ handlers defined in this class. If this argument is
+ omitted, the default namespace is used.
+ """
+ def is_asyncio_based(self):
+ return True
+
+ async def trigger_event(self, event, *args):
+ """Dispatch an event to the proper handler method.
+
+ In the most common usage, this method is not overloaded by subclasses,
+ as it performs the routing of events to methods. However, this
+ method can be overridden if special dispatching rules are needed, or if
+ having a single method that catches all events is desired.
+
+ Note: this method is a coroutine.
+ """
+ handler_name = 'on_' + event
+ if hasattr(self, handler_name):
+ handler = getattr(self, handler_name)
+ if asyncio.iscoroutinefunction(handler) is True:
+ try:
+ ret = await handler(*args)
+ except asyncio.CancelledError: # pragma: no cover
+ ret = None
+ else:
+ ret = handler(*args)
+ return ret
+
+ async def emit(self, event, data=None, namespace=None, callback=None):
+ """Emit a custom event to the server.
+
+ The only difference with the :func:`socketio.Client.emit` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.client.emit(event, data=data,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ async def send(self, data, namespace=None, callback=None):
+ """Send a message to the server.
+
+ The only difference with the :func:`socketio.Client.send` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.client.send(data,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ async def disconnect(self):
+ """Disconnect a client.
+
+ The only difference with the :func:`socketio.Client.disconnect` method
+ is that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+
+ Note: this method is a coroutine.
+ """
+ return await self.client.disconnect()
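+
+
+# Illustrative usage sketch (not part of the library): a client-side
+# class-based namespace with coroutine handlers. The '/chat' namespace and
+# the event names are hypothetical.
+#
+#     class ChatNamespace(AsyncClientNamespace):
+#         async def on_connect(self):
+#             await self.emit('hello', {'from': 'client'})
+#
+#         async def on_my_event(self, data):
+#             print('received:', data)
+#
+#     sio = socketio.AsyncClient()
+#     sio.register_namespace(ChatNamespace('/chat'))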
diff --git a/libs/socketio/asyncio_pubsub_manager.py b/libs/socketio/asyncio_pubsub_manager.py
new file mode 100644
index 000000000..6fdba6d0c
--- /dev/null
+++ b/libs/socketio/asyncio_pubsub_manager.py
@@ -0,0 +1,163 @@
+from functools import partial
+import uuid
+
+import json
+import pickle
+import six
+
+from .asyncio_manager import AsyncManager
+
+
+class AsyncPubSubManager(AsyncManager):
+ """Manage a client list attached to a pub/sub backend under asyncio.
+
+ This is a base class that enables multiple servers to share the list of
+ clients, with the servers communicating events through a pub/sub backend.
+ The use of a pub/sub backend also allows any client connected to the
+ backend to emit events addressed to Socket.IO clients.
+
+    The actual backends must be implemented by subclasses; this class only
+    provides a generic pub/sub framework for asyncio applications.
+
+ :param channel: The channel name on which the server sends and receives
+ notifications.
+ """
+ name = 'asyncpubsub'
+
+ def __init__(self, channel='socketio', write_only=False, logger=None):
+ super().__init__()
+ self.channel = channel
+ self.write_only = write_only
+ self.host_id = uuid.uuid4().hex
+ self.logger = logger
+
+ def initialize(self):
+ super().initialize()
+ if not self.write_only:
+ self.thread = self.server.start_background_task(self._thread)
+ self._get_logger().info(self.name + ' backend initialized.')
+
+ async def emit(self, event, data, namespace=None, room=None, skip_sid=None,
+ callback=None, **kwargs):
+ """Emit a message to a single client, a room, or all the clients
+ connected to the namespace.
+
+        This method takes care of propagating the message to all the servers
+ that are connected through the message queue.
+
+ The parameters are the same as in :meth:`.Server.emit`.
+
+ Note: this method is a coroutine.
+ """
+ if kwargs.get('ignore_queue'):
+ return await super().emit(
+ event, data, namespace=namespace, room=room, skip_sid=skip_sid,
+ callback=callback)
+ namespace = namespace or '/'
+ if callback is not None:
+ if self.server is None:
+ raise RuntimeError('Callbacks can only be issued from the '
+ 'context of a server.')
+ if room is None:
+ raise ValueError('Cannot use callback without a room set.')
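+            # a callable cannot be serialized and published, so only its
+            # address (room, namespace, ack id) travels with the message;
+            # the ack later comes back as a 'callback' message directed at
+            # this host_id (see _handle_callback and _return_callback)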
+ id = self._generate_ack_id(room, namespace, callback)
+ callback = (room, namespace, id)
+ else:
+ callback = None
+ await self._publish({'method': 'emit', 'event': event, 'data': data,
+ 'namespace': namespace, 'room': room,
+ 'skip_sid': skip_sid, 'callback': callback,
+ 'host_id': self.host_id})
+
+ async def close_room(self, room, namespace=None):
+ await self._publish({'method': 'close_room', 'room': room,
+ 'namespace': namespace or '/'})
+
+ async def _publish(self, data):
+ """Publish a message on the Socket.IO channel.
+
+ This method needs to be implemented by the different subclasses that
+ support pub/sub backends.
+ """
+ raise NotImplementedError('This method must be implemented in a '
+ 'subclass.') # pragma: no cover
+
+ async def _listen(self):
+ """Return the next message published on the Socket.IO channel,
+ blocking until a message is available.
+
+ This method needs to be implemented by the different subclasses that
+ support pub/sub backends.
+ """
+ raise NotImplementedError('This method must be implemented in a '
+ 'subclass.') # pragma: no cover
+
+ async def _handle_emit(self, message):
+ # Events with callbacks are very tricky to handle across hosts
+        # Here on the receiving end we set up a local callback that preserves
+ # the callback host and id from the sender
+ remote_callback = message.get('callback')
+ remote_host_id = message.get('host_id')
+ if remote_callback is not None and len(remote_callback) == 3:
+ callback = partial(self._return_callback, remote_host_id,
+ *remote_callback)
+ else:
+ callback = None
+ await super().emit(message['event'], message['data'],
+ namespace=message.get('namespace'),
+ room=message.get('room'),
+ skip_sid=message.get('skip_sid'),
+ callback=callback)
+
+ async def _handle_callback(self, message):
+ if self.host_id == message.get('host_id'):
+ try:
+ sid = message['sid']
+ namespace = message['namespace']
+ id = message['id']
+ args = message['args']
+ except KeyError:
+ return
+ await self.trigger_callback(sid, namespace, id, args)
+
+ async def _return_callback(self, host_id, sid, namespace, callback_id,
+ *args):
+ # When an event callback is received, the callback is returned back
+        # to the sender, which is identified by the host_id
+ await self._publish({'method': 'callback', 'host_id': host_id,
+ 'sid': sid, 'namespace': namespace,
+ 'id': callback_id, 'args': args})
+
+ async def _handle_close_room(self, message):
+ await super().close_room(
+ room=message.get('room'), namespace=message.get('namespace'))
+
+ async def _thread(self):
+ while True:
+ try:
+ message = await self._listen()
+ except:
+ import traceback
+ traceback.print_exc()
+ break
+ data = None
+ if isinstance(message, dict):
+ data = message
+ else:
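+                # queue payloads arrive either pickled (bytes) or as JSON
+                # text: try pickle first, then fall back to json; if both
+                # fail, data stays None and the message is dropped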
+ if isinstance(message, six.binary_type): # pragma: no cover
+ try:
+ data = pickle.loads(message)
+ except:
+ pass
+ if data is None:
+ try:
+ data = json.loads(message)
+ except:
+ pass
+ if data and 'method' in data:
+ if data['method'] == 'emit':
+ await self._handle_emit(data)
+ elif data['method'] == 'callback':
+ await self._handle_callback(data)
+ elif data['method'] == 'close_room':
+ await self._handle_close_room(data)
diff --git a/libs/socketio/asyncio_redis_manager.py b/libs/socketio/asyncio_redis_manager.py
new file mode 100644
index 000000000..21499c26c
--- /dev/null
+++ b/libs/socketio/asyncio_redis_manager.py
@@ -0,0 +1,107 @@
+import asyncio
+import pickle
+from urllib.parse import urlparse
+
+try:
+ import aioredis
+except ImportError:
+ aioredis = None
+
+from .asyncio_pubsub_manager import AsyncPubSubManager
+
+
+def _parse_redis_url(url):
+ p = urlparse(url)
+ if p.scheme not in {'redis', 'rediss'}:
+ raise ValueError('Invalid redis url')
+ ssl = p.scheme == 'rediss'
+ host = p.hostname or 'localhost'
+ port = p.port or 6379
+ password = p.password
+ if p.path:
+ db = int(p.path[1:])
+ else:
+ db = 0
+ return host, port, password, db, ssl
+
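+# Example (illustrative, not part of the library):
+#     _parse_redis_url('rediss://:secret@example.com:6380/2')
+# returns ('example.com', 6380, 'secret', 2, True).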
+
+class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover
+ """Redis based client manager for asyncio servers.
+
+ This class implements a Redis backend for event sharing across multiple
+    processes. It is kept here as one more example of how to build a custom
+ backend, since the kombu backend is perfectly adequate to support a Redis
+ message queue.
+
+    To use a Redis backend, initialize the :class:`AsyncServer` instance as
+    follows::
+
+        server = socketio.AsyncServer(client_manager=socketio.AsyncRedisManager(
+ 'redis://hostname:port/0'))
+
+ :param url: The connection URL for the Redis server. For a default Redis
+ store running on the same host, use ``redis://``. To use an
+ SSL connection, use ``rediss://``.
+ :param channel: The channel name on which the server sends and receives
+ notifications. Must be the same in all the servers.
+    :param write_only: If set to ``True``, only initialize to emit events. The
+ default of ``False`` initializes the class for emitting
+ and receiving.
+ """
+ name = 'aioredis'
+
+ def __init__(self, url='redis://localhost:6379/0', channel='socketio',
+ write_only=False, logger=None):
+ if aioredis is None:
+ raise RuntimeError('Redis package is not installed '
+ '(Run "pip install aioredis" in your '
+ 'virtualenv).')
+ (
+ self.host, self.port, self.password, self.db, self.ssl
+ ) = _parse_redis_url(url)
+ self.pub = None
+ self.sub = None
+ super().__init__(channel=channel, write_only=write_only, logger=logger)
+
+ async def _publish(self, data):
+ retry = True
+ while True:
+ try:
+ if self.pub is None:
+ self.pub = await aioredis.create_redis(
+ (self.host, self.port), db=self.db,
+ password=self.password, ssl=self.ssl
+ )
+ return await self.pub.publish(self.channel,
+ pickle.dumps(data))
+ except (aioredis.RedisError, OSError):
+ if retry:
+ self._get_logger().error('Cannot publish to redis... '
+ 'retrying')
+ self.pub = None
+ retry = False
+ else:
+ self._get_logger().error('Cannot publish to redis... '
+ 'giving up')
+ break
+
+ async def _listen(self):
+ retry_sleep = 1
+ while True:
+ try:
+ if self.sub is None:
+ self.sub = await aioredis.create_redis(
+ (self.host, self.port), db=self.db,
+ password=self.password, ssl=self.ssl
+ )
+ self.ch = (await self.sub.subscribe(self.channel))[0]
+ return await self.ch.get()
+ except (aioredis.RedisError, OSError):
+ self._get_logger().error('Cannot receive from redis... '
+ 'retrying in '
+ '{} secs'.format(retry_sleep))
+ self.sub = None
+ await asyncio.sleep(retry_sleep)
+ retry_sleep *= 2
+ if retry_sleep > 60:
+ retry_sleep = 60
diff --git a/libs/socketio/asyncio_server.py b/libs/socketio/asyncio_server.py
new file mode 100644
index 000000000..251d58180
--- /dev/null
+++ b/libs/socketio/asyncio_server.py
@@ -0,0 +1,526 @@
+import asyncio
+
+import engineio
+import six
+
+from . import asyncio_manager
+from . import exceptions
+from . import packet
+from . import server
+
+
+class AsyncServer(server.Server):
+ """A Socket.IO server for asyncio.
+
+ This class implements a fully compliant Socket.IO web server with support
+ for websocket and long-polling transports, compatible with the asyncio
+ framework on Python 3.5 or newer.
+
+ :param client_manager: The client manager instance that will manage the
+ client list. When this is omitted, the client list
+ is stored in an in-memory structure, so the use of
+ multiple connected servers is not possible.
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+    :param async_handlers: If set to ``True``, event handlers are executed in
+                           separate asyncio tasks. To run handlers
+                           synchronously, set to ``False``. The default is
+                           ``True``.
+ :param kwargs: Connection parameters for the underlying Engine.IO server.
+
+ The Engine.IO configuration supports the following settings:
+
+ :param async_mode: The asynchronous model to use. See the Deployment
+ section in the documentation for a description of the
+                       available options. Valid async modes are "aiohttp",
+                       "sanic", "tornado" and "asgi". If this argument is not
+                       given, an async mode is chosen based on the installed
+                       packages.
+ :param ping_timeout: The time in seconds that the client waits for the
+ server to respond before disconnecting.
+ :param ping_interval: The interval in seconds at which the client pings
+ the server.
+ :param max_http_buffer_size: The maximum size of a message when using the
+ polling transport.
+ :param allow_upgrades: Whether to allow transport upgrades or not.
+ :param http_compression: Whether to compress packages when using the
+ polling transport.
+ :param compression_threshold: Only compress messages when their byte size
+ is greater than this value.
+ :param cookie: Name of the HTTP cookie that contains the client session
+ id. If set to ``None``, a cookie is not sent to the client.
+ :param cors_allowed_origins: Origin or list of origins that are allowed to
+ connect to this server. Only the same origin
+ is allowed by default. Set this argument to
+ ``'*'`` to allow all origins, or to ``[]`` to
+ disable CORS handling.
+ :param cors_credentials: Whether credentials (cookies, authentication) are
+ allowed in requests to this server.
+ :param monitor_clients: If set to ``True``, a background task will ensure
+ inactive clients are closed. Set to ``False`` to
+ disable the monitoring task (not recommended). The
+ default is ``True``.
+ :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
+ a logger object to use. To disable logging set to
+ ``False``.
+ """
+ def __init__(self, client_manager=None, logger=False, json=None,
+ async_handlers=True, **kwargs):
+ if client_manager is None:
+ client_manager = asyncio_manager.AsyncManager()
+ super().__init__(client_manager=client_manager, logger=logger,
+ binary=False, json=json,
+ async_handlers=async_handlers, **kwargs)
+
+ def is_asyncio_based(self):
+ return True
+
+ def attach(self, app, socketio_path='socket.io'):
+ """Attach the Socket.IO server to an application."""
+ self.eio.attach(app, socketio_path)
+
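+    # Illustrative usage sketch (assumes the aiohttp async mode and that
+    # aiohttp is installed; 'app' and the run call are standard aiohttp,
+    # not part of this library):
+    #
+    #     app = aiohttp.web.Application()
+    #     sio = socketio.AsyncServer(async_mode='aiohttp')
+    #     sio.attach(app)
+    #     aiohttp.web.run_app(app)
+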
+ async def emit(self, event, data=None, to=None, room=None, skip_sid=None,
+ namespace=None, callback=None, **kwargs):
+ """Emit a custom event to one or more connected clients.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param to: The recipient of the message. This can be set to the
+                   session ID of a client to address only that client, or to
+                   any custom room created by the application to address all
+                   the clients in that room. If this argument is omitted the
+                   event is broadcast to all connected clients.
+ :param room: Alias for the ``to`` parameter.
+ :param skip_sid: The session ID of a client to skip when broadcasting
+ to a room or to all clients. This can be used to
+ prevent a message from being sent to the sender.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param callback: If given, this function will be called to acknowledge
+                         that the client has received the message. The arguments
+ that will be passed to the function are those provided
+ by the client. Callback functions can only be used
+ when addressing an individual client.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ clients directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used. It is recommended
+ to always leave this parameter with its default
+ value of ``False``.
+
+ Note: this method is a coroutine.
+ """
+ namespace = namespace or '/'
+ room = to or room
+ self.logger.info('emitting event "%s" to %s [%s]', event,
+ room or 'all', namespace)
+ await self.manager.emit(event, data, namespace, room=room,
+ skip_sid=skip_sid, callback=callback,
+ **kwargs)
+
+ async def send(self, data, to=None, room=None, skip_sid=None,
+ namespace=None, callback=None, **kwargs):
+ """Send a message to one or more connected clients.
+
+ This function emits an event with the name ``'message'``. Use
+ :func:`emit` to issue custom event names.
+
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param to: The recipient of the message. This can be set to the
+                   session ID of a client to address only that client, or to
+                   any custom room created by the application to address all
+                   the clients in that room. If this argument is omitted the
+                   event is broadcast to all connected clients.
+ :param room: Alias for the ``to`` parameter.
+ :param skip_sid: The session ID of a client to skip when broadcasting
+ to a room or to all clients. This can be used to
+ prevent a message from being sent to the sender.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param callback: If given, this function will be called to acknowledge
+                         that the client has received the message. The arguments
+ that will be passed to the function are those provided
+ by the client. Callback functions can only be used
+ when addressing an individual client.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ clients directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used. It is recommended
+ to always leave this parameter with its default
+ value of ``False``.
+
+ Note: this method is a coroutine.
+ """
+ await self.emit('message', data=data, to=to, room=room,
+ skip_sid=skip_sid, namespace=namespace,
+ callback=callback, **kwargs)
+
+ async def call(self, event, data=None, to=None, sid=None, namespace=None,
+ timeout=60, **kwargs):
+ """Emit a custom event to a client and wait for the response.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param to: The session ID of the recipient client.
+ :param sid: Alias for the ``to`` parameter.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param timeout: The waiting timeout. If the timeout is reached before
+ the client acknowledges the event, then a
+ ``TimeoutError`` exception is raised.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ client directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used. It is recommended
+ to always leave this parameter with its default
+ value of ``False``.
+ """
+ if not self.async_handlers:
+ raise RuntimeError(
+ 'Cannot use call() when async_handlers is False.')
+ callback_event = self.eio.create_event()
+ callback_args = []
+
+ def event_callback(*args):
+ callback_args.append(args)
+ callback_event.set()
+
+ await self.emit(event, data=data, room=to or sid, namespace=namespace,
+ callback=event_callback, **kwargs)
+ try:
+ await asyncio.wait_for(callback_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ six.raise_from(exceptions.TimeoutError(), None)
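+        # unpack the acknowledgement: multiple callback arguments come back
+        # as a sequence, a single argument comes back bare, and an empty
+        # ack yields None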
+ return callback_args[0] if len(callback_args[0]) > 1 \
+ else callback_args[0][0] if len(callback_args[0]) == 1 \
+ else None
+
+ async def close_room(self, room, namespace=None):
+ """Close a room.
+
+ This function removes all the clients from the given room.
+
+ :param room: Room name.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the default namespace is used.
+
+ Note: this method is a coroutine.
+ """
+ namespace = namespace or '/'
+ self.logger.info('room %s is closing [%s]', room, namespace)
+ await self.manager.close_room(room, namespace)
+
+ async def get_session(self, sid, namespace=None):
+ """Return the user session for a client.
+
+ :param sid: The session id of the client.
+ :param namespace: The Socket.IO namespace. If this argument is omitted
+ the default namespace is used.
+
+ The return value is a dictionary. Modifications made to this
+ dictionary are not guaranteed to be preserved. If you want to modify
+ the user session, use the ``session`` context manager instead.
+ """
+ namespace = namespace or '/'
+ eio_session = await self.eio.get_session(sid)
+ return eio_session.setdefault(namespace, {})
+
+ async def save_session(self, sid, session, namespace=None):
+ """Store the user session for a client.
+
+ :param sid: The session id of the client.
+ :param session: The session dictionary.
+ :param namespace: The Socket.IO namespace. If this argument is omitted
+ the default namespace is used.
+ """
+ namespace = namespace or '/'
+ eio_session = await self.eio.get_session(sid)
+ eio_session[namespace] = session
+
+ def session(self, sid, namespace=None):
+ """Return the user session for a client with context manager syntax.
+
+ :param sid: The session id of the client.
+
+ This is a context manager that returns the user session dictionary for
+ the client. Any changes that are made to this dictionary inside the
+ context manager block are saved back to the session. Example usage::
+
+            @eio.on('connect')
+            async def on_connect(sid, environ):
+                username = authenticate_user(environ)
+                if not username:
+                    return False
+                async with eio.session(sid) as session:
+                    session['username'] = username
+
+            @eio.on('message')
+            async def on_message(sid, msg):
+                async with eio.session(sid) as session:
+                    print('received message from ', session['username'])
+ """
+ class _session_context_manager(object):
+ def __init__(self, server, sid, namespace):
+ self.server = server
+ self.sid = sid
+ self.namespace = namespace
+ self.session = None
+
+            async def __aenter__(self):
+                self.session = await self.server.get_session(
+                    self.sid, namespace=self.namespace)
+                return self.session
+
+            async def __aexit__(self, *args):
+                await self.server.save_session(self.sid, self.session,
+                                               namespace=self.namespace)
+
+ return _session_context_manager(self, sid, namespace)
+
+ async def disconnect(self, sid, namespace=None):
+ """Disconnect a client.
+
+ :param sid: Session ID of the client.
+ :param namespace: The Socket.IO namespace to disconnect. If this
+ argument is omitted the default namespace is used.
+
+ Note: this method is a coroutine.
+ """
+ namespace = namespace or '/'
+ if self.manager.is_connected(sid, namespace=namespace):
+ self.logger.info('Disconnecting %s [%s]', sid, namespace)
+ self.manager.pre_disconnect(sid, namespace=namespace)
+ await self._send_packet(sid, packet.Packet(packet.DISCONNECT,
+ namespace=namespace))
+ await self._trigger_event('disconnect', namespace, sid)
+ self.manager.disconnect(sid, namespace=namespace)
+ if namespace == '/':
+ await self.eio.disconnect(sid)
+
+ async def handle_request(self, *args, **kwargs):
+ """Handle an HTTP request from the client.
+
+ This is the entry point of the Socket.IO application. This function
+ returns the HTTP response body to deliver to the client.
+
+ Note: this method is a coroutine.
+ """
+ return await self.eio.handle_request(*args, **kwargs)
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task using the appropriate async model.
+
+ This is a utility function that applications can use to start a
+ background task using the method that is compatible with the
+ selected async mode.
+
+ :param target: the target function to execute. Must be a coroutine.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+        The return value is an ``asyncio.Task`` object.
+        """
+ return self.eio.start_background_task(target, *args, **kwargs)
+
+ async def sleep(self, seconds=0):
+ """Sleep for the requested amount of time using the appropriate async
+ model.
+
+ This is a utility function that applications can use to put a task to
+ sleep without having to worry about using the correct call for the
+ selected async mode.
+
+ Note: this method is a coroutine.
+ """
+ return await self.eio.sleep(seconds)
+
+ async def _emit_internal(self, sid, event, data, namespace=None, id=None):
+ """Send a message to a client."""
+ # tuples are expanded to multiple arguments, everything else is sent
+ # as a single argument
+ if isinstance(data, tuple):
+ data = list(data)
+ else:
+ data = [data]
+ await self._send_packet(sid, packet.Packet(
+ packet.EVENT, namespace=namespace, data=[event] + data, id=id,
+ binary=None))
+
+ async def _send_packet(self, sid, pkt):
+ """Send a Socket.IO packet to a client."""
+ encoded_packet = pkt.encode()
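+        # encode() returns a list when the packet carries binary
+        # attachments: the first element is the text header and the rest
+        # are the attachments, which go out as binary frames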
+ if isinstance(encoded_packet, list):
+ binary = False
+ for ep in encoded_packet:
+ await self.eio.send(sid, ep, binary=binary)
+ binary = True
+ else:
+ await self.eio.send(sid, encoded_packet, binary=False)
+
+ async def _handle_connect(self, sid, namespace):
+ """Handle a client connection request."""
+ namespace = namespace or '/'
+ self.manager.connect(sid, namespace)
+ if self.always_connect:
+ await self._send_packet(sid, packet.Packet(packet.CONNECT,
+ namespace=namespace))
+ fail_reason = None
+ try:
+ success = await self._trigger_event('connect', namespace, sid,
+ self.environ[sid])
+ except exceptions.ConnectionRefusedError as exc:
+ fail_reason = exc.error_args
+ success = False
+
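+        # with always_connect the client was already told the connection was
+        # accepted, so a refusal has to be delivered as a DISCONNECT packet;
+        # otherwise an ERROR packet is sent instead of the CONNECT response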
+ if success is False:
+ if self.always_connect:
+ self.manager.pre_disconnect(sid, namespace)
+ await self._send_packet(sid, packet.Packet(
+ packet.DISCONNECT, data=fail_reason, namespace=namespace))
+ self.manager.disconnect(sid, namespace)
+ if not self.always_connect:
+ await self._send_packet(sid, packet.Packet(
+ packet.ERROR, data=fail_reason, namespace=namespace))
+ if sid in self.environ: # pragma: no cover
+ del self.environ[sid]
+ elif not self.always_connect:
+ await self._send_packet(sid, packet.Packet(packet.CONNECT,
+ namespace=namespace))
+
+ async def _handle_disconnect(self, sid, namespace):
+ """Handle a client disconnect."""
+ namespace = namespace or '/'
+ if namespace == '/':
+ namespace_list = list(self.manager.get_namespaces())
+ else:
+ namespace_list = [namespace]
+ for n in namespace_list:
+ if n != '/' and self.manager.is_connected(sid, n):
+ await self._trigger_event('disconnect', n, sid)
+ self.manager.disconnect(sid, n)
+ if namespace == '/' and self.manager.is_connected(sid, namespace):
+ await self._trigger_event('disconnect', '/', sid)
+ self.manager.disconnect(sid, '/')
+
+ async def _handle_event(self, sid, namespace, id, data):
+ """Handle an incoming client event."""
+ namespace = namespace or '/'
+ self.logger.info('received event "%s" from %s [%s]', data[0], sid,
+ namespace)
+ if not self.manager.is_connected(sid, namespace):
+ self.logger.warning('%s is not connected to namespace %s',
+ sid, namespace)
+ return
+ if self.async_handlers:
+ self.start_background_task(self._handle_event_internal, self, sid,
+ data, namespace, id)
+ else:
+ await self._handle_event_internal(self, sid, data, namespace, id)
+
+ async def _handle_event_internal(self, server, sid, data, namespace, id):
+ r = await server._trigger_event(data[0], namespace, sid, *data[1:])
+ if id is not None:
+ # send ACK packet with the response returned by the handler
+ # tuples are expanded as multiple arguments
+ if r is None:
+ data = []
+ elif isinstance(r, tuple):
+ data = list(r)
+ else:
+ data = [r]
+ await server._send_packet(sid, packet.Packet(packet.ACK,
+ namespace=namespace,
+ id=id, data=data,
+ binary=None))
+
+ async def _handle_ack(self, sid, namespace, id, data):
+ """Handle ACK packets from the client."""
+ namespace = namespace or '/'
+ self.logger.info('received ack from %s [%s]', sid, namespace)
+ await self.manager.trigger_callback(sid, namespace, id, data)
+
+ async def _trigger_event(self, event, namespace, *args):
+ """Invoke an application event handler."""
+ # first see if we have an explicit handler for the event
+ if namespace in self.handlers and event in self.handlers[namespace]:
+            if asyncio.iscoroutinefunction(self.handlers[namespace][event]):
+ try:
+ ret = await self.handlers[namespace][event](*args)
+ except asyncio.CancelledError: # pragma: no cover
+ ret = None
+ else:
+ ret = self.handlers[namespace][event](*args)
+ return ret
+
+        # or else, forward the event to a namespace handler if one exists
+ elif namespace in self.namespace_handlers:
+ return await self.namespace_handlers[namespace].trigger_event(
+ event, *args)
+
+ async def _handle_eio_connect(self, sid, environ):
+ """Handle the Engine.IO connection event."""
+ if not self.manager_initialized:
+ self.manager_initialized = True
+ self.manager.initialize()
+ self.environ[sid] = environ
+ return await self._handle_connect(sid, '/')
+
+ async def _handle_eio_message(self, sid, data):
+ """Dispatch Engine.IO messages."""
+ if sid in self._binary_packet:
+ pkt = self._binary_packet[sid]
+ if pkt.add_attachment(data):
+ del self._binary_packet[sid]
+ if pkt.packet_type == packet.BINARY_EVENT:
+ await self._handle_event(sid, pkt.namespace, pkt.id,
+ pkt.data)
+ else:
+ await self._handle_ack(sid, pkt.namespace, pkt.id,
+ pkt.data)
+ else:
+ pkt = packet.Packet(encoded_packet=data)
+ if pkt.packet_type == packet.CONNECT:
+ await self._handle_connect(sid, pkt.namespace)
+ elif pkt.packet_type == packet.DISCONNECT:
+ await self._handle_disconnect(sid, pkt.namespace)
+ elif pkt.packet_type == packet.EVENT:
+ await self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.ACK:
+ await self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.BINARY_EVENT or \
+ pkt.packet_type == packet.BINARY_ACK:
+ self._binary_packet[sid] = pkt
+ elif pkt.packet_type == packet.ERROR:
+ raise ValueError('Unexpected ERROR packet.')
+ else:
+ raise ValueError('Unknown packet type.')
+
+ async def _handle_eio_disconnect(self, sid):
+ """Handle Engine.IO disconnect event."""
+ await self._handle_disconnect(sid, '/')
+ if sid in self.environ:
+ del self.environ[sid]
+
+ def _engineio_server_class(self):
+ return engineio.AsyncServer
diff --git a/libs/socketio/base_manager.py b/libs/socketio/base_manager.py
new file mode 100644
index 000000000..3cccb8569
--- /dev/null
+++ b/libs/socketio/base_manager.py
@@ -0,0 +1,178 @@
+import itertools
+import logging
+
+import six
+
+default_logger = logging.getLogger('socketio')
+
+
+class BaseManager(object):
+ """Manage client connections.
+
+ This class keeps track of all the clients and the rooms they are in, to
+ support the broadcasting of messages. The data used by this class is
+    stored in an in-memory structure, making it appropriate only for single-process
+ services. More sophisticated storage backends can be implemented by
+ subclasses.
+ """
+ def __init__(self):
+ self.logger = None
+ self.server = None
+ self.rooms = {}
+ self.callbacks = {}
+ self.pending_disconnect = {}
+
+ def set_server(self, server):
+ self.server = server
+
+ def initialize(self):
+ """Invoked before the first request is received. Subclasses can add
+ their initialization code here.
+ """
+ pass
+
+ def get_namespaces(self):
+ """Return an iterable with the active namespace names."""
+ return six.iterkeys(self.rooms)
+
+ def get_participants(self, namespace, room):
+ """Return an iterable with the active participants in a room."""
+ for sid, active in six.iteritems(self.rooms[namespace][room].copy()):
+ yield sid
+
+ def connect(self, sid, namespace):
+ """Register a client connection to a namespace."""
+ self.enter_room(sid, namespace, None)
+ self.enter_room(sid, namespace, sid)
+
+ def is_connected(self, sid, namespace):
+ if namespace in self.pending_disconnect and \
+ sid in self.pending_disconnect[namespace]:
+ # the client is in the process of being disconnected
+ return False
+ try:
+ return self.rooms[namespace][None][sid]
+ except KeyError:
+ pass
+
+ def pre_disconnect(self, sid, namespace):
+ """Put the client in the to-be-disconnected list.
+
+ This allows the client data structures to be present while the
+ disconnect handler is invoked, but still recognize the fact that the
+ client is soon going away.
+ """
+ if namespace not in self.pending_disconnect:
+ self.pending_disconnect[namespace] = []
+ self.pending_disconnect[namespace].append(sid)
+
+ def disconnect(self, sid, namespace):
+ """Register a client disconnect from a namespace."""
+ if namespace not in self.rooms:
+ return
+ rooms = []
+ for room_name, room in six.iteritems(self.rooms[namespace].copy()):
+ if sid in room:
+ rooms.append(room_name)
+ for room in rooms:
+ self.leave_room(sid, namespace, room)
+ if sid in self.callbacks and namespace in self.callbacks[sid]:
+ del self.callbacks[sid][namespace]
+ if len(self.callbacks[sid]) == 0:
+ del self.callbacks[sid]
+ if namespace in self.pending_disconnect and \
+ sid in self.pending_disconnect[namespace]:
+ self.pending_disconnect[namespace].remove(sid)
+ if len(self.pending_disconnect[namespace]) == 0:
+ del self.pending_disconnect[namespace]
+
+ def enter_room(self, sid, namespace, room):
+ """Add a client to a room."""
+ if namespace not in self.rooms:
+ self.rooms[namespace] = {}
+ if room not in self.rooms[namespace]:
+ self.rooms[namespace][room] = {}
+ self.rooms[namespace][room][sid] = True
+
+ def leave_room(self, sid, namespace, room):
+ """Remove a client from a room."""
+ try:
+ del self.rooms[namespace][room][sid]
+ if len(self.rooms[namespace][room]) == 0:
+ del self.rooms[namespace][room]
+ if len(self.rooms[namespace]) == 0:
+ del self.rooms[namespace]
+ except KeyError:
+ pass
+
+ def close_room(self, room, namespace):
+ """Remove all participants from a room."""
+ try:
+ for sid in self.get_participants(namespace, room):
+ self.leave_room(sid, namespace, room)
+ except KeyError:
+ pass
+
+ def get_rooms(self, sid, namespace):
+ """Return the rooms a client is in."""
+ r = []
+ try:
+ for room_name, room in six.iteritems(self.rooms[namespace]):
+ if room_name is not None and sid in room and room[sid]:
+ r.append(room_name)
+ except KeyError:
+ pass
+ return r
+
+ def emit(self, event, data, namespace, room=None, skip_sid=None,
+ callback=None, **kwargs):
+ """Emit a message to a single client, a room, or all the clients
+ connected to the namespace."""
+ if namespace not in self.rooms or room not in self.rooms[namespace]:
+ return
+ if not isinstance(skip_sid, list):
+ skip_sid = [skip_sid]
+ for sid in self.get_participants(namespace, room):
+ if sid not in skip_sid:
+ if callback is not None:
+ id = self._generate_ack_id(sid, namespace, callback)
+ else:
+ id = None
+ self.server._emit_internal(sid, event, data, namespace, id)
+
+ def trigger_callback(self, sid, namespace, id, data):
+ """Invoke an application callback."""
+ callback = None
+ try:
+ callback = self.callbacks[sid][namespace][id]
+ except KeyError:
+ # if we get an unknown callback we just ignore it
+ self._get_logger().warning('Unknown callback received, ignoring.')
+ else:
+ del self.callbacks[sid][namespace][id]
+ if callback is not None:
+ callback(*data)
+
+ def _generate_ack_id(self, sid, namespace, callback):
+ """Generate a unique identifier for an ACK packet."""
+ namespace = namespace or '/'
+ if sid not in self.callbacks:
+ self.callbacks[sid] = {}
+ if namespace not in self.callbacks[sid]:
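+            # key 0 of the namespace dict holds an itertools counter that
+            # hands out ack ids; the callbacks themselves are stored under
+            # the integer ids the counter produces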
+ self.callbacks[sid][namespace] = {0: itertools.count(1)}
+ id = six.next(self.callbacks[sid][namespace][0])
+ self.callbacks[sid][namespace][id] = callback
+ return id
+
+ def _get_logger(self):
+ """Get the appropriate logger
+
+ Prevents uninitialized servers in write-only mode from failing.
+ """
+
+ if self.logger:
+ return self.logger
+ elif self.server:
+ return self.server.logger
+ else:
+ return default_logger
diff --git a/libs/socketio/client.py b/libs/socketio/client.py
new file mode 100644
index 000000000..e917d634d
--- /dev/null
+++ b/libs/socketio/client.py
@@ -0,0 +1,620 @@
+import itertools
+import logging
+import random
+import signal
+
+import engineio
+import six
+
+from . import exceptions
+from . import namespace
+from . import packet
+
+default_logger = logging.getLogger('socketio.client')
+reconnecting_clients = []
+
+
+def signal_handler(sig, frame): # pragma: no cover
+ """SIGINT handler.
+
+ Notify any clients that are in a reconnect loop to abort. Other
+ disconnection tasks are handled at the engine.io level.
+ """
+ for client in reconnecting_clients[:]:
+ client._reconnect_abort.set()
+ return original_signal_handler(sig, frame)
+
+
+original_signal_handler = signal.signal(signal.SIGINT, signal_handler)
+
+
+class Client(object):
+ """A Socket.IO client.
+
+ This class implements a fully compliant Socket.IO web client with support
+ for websocket and long-polling transports.
+
+ :param reconnection: ``True`` if the client should automatically attempt to
+ reconnect to the server after an interruption, or
+ ``False`` to not reconnect. The default is ``True``.
+ :param reconnection_attempts: How many reconnection attempts to issue
+                                  before giving up, or 0 for unlimited attempts.
+ The default is 0.
+ :param reconnection_delay: How long to wait in seconds before the first
+ reconnection attempt. Each successive attempt
+ doubles this delay.
+ :param reconnection_delay_max: The maximum delay between reconnection
+ attempts.
+ :param randomization_factor: Randomization amount for each delay between
+ reconnection attempts. The default is 0.5,
+ which means that each delay is randomly
+ adjusted by +/- 50%.
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``. The default is
+ ``False``.
+ :param binary: ``True`` to support binary payloads, ``False`` to treat all
+ payloads as text. On Python 2, if this is set to ``True``,
+ ``unicode`` values are treated as text, and ``str`` and
+ ``bytes`` values are treated as binary. This option has no
+ effect on Python 3, where text and binary payloads are
+ always automatically discovered.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+
+ The Engine.IO configuration supports the following settings:
+
+ :param request_timeout: A timeout in seconds for requests. The default is
+ 5 seconds.
+ :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
+ skip SSL certificate verification, allowing
+                       connections to servers with self-signed certificates.
+ The default is ``True``.
+ :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
+ a logger object to use. To disable logging set to
+ ``False``. The default is ``False``.
+ """
+ def __init__(self, reconnection=True, reconnection_attempts=0,
+ reconnection_delay=1, reconnection_delay_max=5,
+ randomization_factor=0.5, logger=False, binary=False,
+ json=None, **kwargs):
+ self.reconnection = reconnection
+ self.reconnection_attempts = reconnection_attempts
+ self.reconnection_delay = reconnection_delay
+ self.reconnection_delay_max = reconnection_delay_max
+ self.randomization_factor = randomization_factor
+ self.binary = binary
+
+ engineio_options = kwargs
+ engineio_logger = engineio_options.pop('engineio_logger', None)
+ if engineio_logger is not None:
+ engineio_options['logger'] = engineio_logger
+ if json is not None:
+ packet.Packet.json = json
+ engineio_options['json'] = json
+
+ self.eio = self._engineio_client_class()(**engineio_options)
+ self.eio.on('connect', self._handle_eio_connect)
+ self.eio.on('message', self._handle_eio_message)
+ self.eio.on('disconnect', self._handle_eio_disconnect)
+
+ if not isinstance(logger, bool):
+ self.logger = logger
+ else:
+ self.logger = default_logger
+ if not logging.root.handlers and \
+ self.logger.level == logging.NOTSET:
+ if logger:
+ self.logger.setLevel(logging.INFO)
+ else:
+ self.logger.setLevel(logging.ERROR)
+ self.logger.addHandler(logging.StreamHandler())
+
+ self.connection_url = None
+ self.connection_headers = None
+ self.connection_transports = None
+ self.connection_namespaces = None
+ self.socketio_path = None
+ self.sid = None
+
+ self.connected = False
+ self.namespaces = []
+ self.handlers = {}
+ self.namespace_handlers = {}
+ self.callbacks = {}
+ self._binary_packet = None
+ self._reconnect_task = None
+ self._reconnect_abort = self.eio.create_event()
+
+ def is_asyncio_based(self):
+ return False
+
+ def on(self, event, handler=None, namespace=None):
+ """Register an event handler.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param handler: The function that should be invoked to handle the
+ event. When this parameter is not given, the method
+ acts as a decorator for the handler function.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the handler is associated with
+ the default namespace.
+
+ Example usage::
+
+ # as a decorator:
+ @sio.on('connect')
+ def connect_handler():
+ print('Connected!')
+
+ # as a method:
+ def message_handler(msg):
+ print('Received message: ', msg)
+                sio.send('response')
+ sio.on('message', message_handler)
+
+ The ``'connect'`` event handler receives no arguments. The
+ ``'message'`` handler and handlers for custom event names receive the
+        message payload as their only argument. Any values returned from a
+        message handler will be passed back to the server's acknowledgement
+        callback function if it exists. The ``'disconnect'`` handler does not
+        take arguments.
+ """
+ namespace = namespace or '/'
+
+ def set_handler(handler):
+ if namespace not in self.handlers:
+ self.handlers[namespace] = {}
+ self.handlers[namespace][event] = handler
+ return handler
+
+ if handler is None:
+ return set_handler
+ set_handler(handler)
+
+ def event(self, *args, **kwargs):
+ """Decorator to register an event handler.
+
+ This is a simplified version of the ``on()`` method that takes the
+ event name from the decorated function.
+
+ Example usage::
+
+ @sio.event
+ def my_event(data):
+ print('Received data: ', data)
+
+ The above example is equivalent to::
+
+ @sio.on('my_event')
+ def my_event(data):
+ print('Received data: ', data)
+
+ A custom namespace can be given as an argument to the decorator::
+
+ @sio.event(namespace='/test')
+ def my_event(data):
+ print('Received data: ', data)
+ """
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+ # the decorator was invoked without arguments
+ # args[0] is the decorated function
+ return self.on(args[0].__name__)(args[0])
+ else:
+ # the decorator was invoked with arguments
+ def set_handler(handler):
+ return self.on(handler.__name__, *args, **kwargs)(handler)
+
+ return set_handler
+
+ def register_namespace(self, namespace_handler):
+ """Register a namespace handler object.
+
+ :param namespace_handler: An instance of a :class:`Namespace`
+ subclass that handles all the event traffic
+ for a namespace.
+ """
+ if not isinstance(namespace_handler, namespace.ClientNamespace):
+ raise ValueError('Not a namespace instance')
+ if self.is_asyncio_based() != namespace_handler.is_asyncio_based():
+ raise ValueError('Not a valid namespace class for this client')
+ namespace_handler._set_client(self)
+ self.namespace_handlers[namespace_handler.namespace] = \
+ namespace_handler
+
+ def connect(self, url, headers={}, transports=None,
+ namespaces=None, socketio_path='socket.io'):
+ """Connect to a Socket.IO server.
+
+ :param url: The URL of the Socket.IO server. It can include custom
+ query string parameters if required by the server.
+ :param headers: A dictionary with custom headers to send with the
+ connection request.
+ :param transports: The list of allowed transports. Valid transports
+ are ``'polling'`` and ``'websocket'``. If not
+ given, the polling transport is connected first,
+ then an upgrade to websocket is attempted.
+ :param namespaces: The list of custom namespaces to connect, in
+ addition to the default namespace. If not given,
+ the namespace list is obtained from the registered
+ event handlers.
+ :param socketio_path: The endpoint where the Socket.IO server is
+ installed. The default value is appropriate for
+ most cases.
+
+ Example usage::
+
+ sio = socketio.Client()
+ sio.connect('http://localhost:5000')
+ """
+ self.connection_url = url
+ self.connection_headers = headers
+ self.connection_transports = transports
+ self.connection_namespaces = namespaces
+ self.socketio_path = socketio_path
+
+ if namespaces is None:
+ namespaces = set(self.handlers.keys()).union(
+ set(self.namespace_handlers.keys()))
+ elif isinstance(namespaces, six.string_types):
+ namespaces = [namespaces]
+ self.connection_namespaces = namespaces
+ self.namespaces = [n for n in namespaces if n != '/']
+ try:
+ self.eio.connect(url, headers=headers, transports=transports,
+ engineio_path=socketio_path)
+ except engineio.exceptions.ConnectionError as exc:
+ six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
+ self.connected = True
+
+ def wait(self):
+ """Wait until the connection with the server ends.
+
+ Client applications can use this function to block the main thread
+ during the life of the connection.
+ """
+ while True:
+ self.eio.wait()
+ self.sleep(1) # give the reconnect task time to start up
+ if not self._reconnect_task:
+ break
+ self._reconnect_task.join()
+ if self.eio.state != 'connected':
+ break
+
+ def emit(self, event, data=None, namespace=None, callback=None):
+ """Emit a custom event to one or more connected clients.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+        :param data: The data to send to the server. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+        :param callback: If given, this function will be called to acknowledge
+                         that the server has received the message. The
+                         arguments that will be passed to the function are
+                         those provided by the server.
+ """
+ namespace = namespace or '/'
+ if namespace != '/' and namespace not in self.namespaces:
+ raise exceptions.BadNamespaceError(
+ namespace + ' is not a connected namespace.')
+ self.logger.info('Emitting event "%s" [%s]', event, namespace)
+ if callback is not None:
+ id = self._generate_ack_id(namespace, callback)
+ else:
+ id = None
+ if six.PY2 and not self.binary:
+ binary = False # pragma: nocover
+ else:
+ binary = None
+ # tuples are expanded to multiple arguments, everything else is sent
+ # as a single argument
+ if isinstance(data, tuple):
+ data = list(data)
+ elif data is not None:
+ data = [data]
+ else:
+ data = []
+ self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
+ data=[event] + data, id=id,
+ binary=binary))
+
+ def send(self, data, namespace=None, callback=None):
+ """Send a message to one or more connected clients.
+
+ This function emits an event with the name ``'message'``. Use
+ :func:`emit` to issue custom event names.
+
+        :param data: The data to send to the server. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+        :param callback: If given, this function will be called to acknowledge
+                         that the server has received the message. The
+                         arguments that will be passed to the function are
+                         those provided by the server.
+ """
+ self.emit('message', data=data, namespace=namespace,
+ callback=callback)
+
+ def call(self, event, data=None, namespace=None, timeout=60):
+ """Emit a custom event to a client and wait for the response.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+        :param data: The data to send to the server. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param timeout: The waiting timeout. If the timeout is reached before
+                        the server acknowledges the event, then a
+ ``TimeoutError`` exception is raised.
+ """
+ callback_event = self.eio.create_event()
+ callback_args = []
+
+ def event_callback(*args):
+ callback_args.append(args)
+ callback_event.set()
+
+ self.emit(event, data=data, namespace=namespace,
+ callback=event_callback)
+ if not callback_event.wait(timeout=timeout):
+ raise exceptions.TimeoutError()
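+        # the ack is unpacked before returning: several arguments come back
+        # as a sequence, one argument comes back bare, none yields None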
+ return callback_args[0] if len(callback_args[0]) > 1 \
+ else callback_args[0][0] if len(callback_args[0]) == 1 \
+ else None
+
+ def disconnect(self):
+ """Disconnect from the server."""
+ # here we just request the disconnection
+ # later in _handle_eio_disconnect we invoke the disconnect handler
+ for n in self.namespaces:
+ self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
+ self._send_packet(packet.Packet(
+ packet.DISCONNECT, namespace='/'))
+ self.connected = False
+ self.eio.disconnect(abort=True)
+
+ def transport(self):
+ """Return the name of the transport used by the client.
+
+ The two possible values returned by this function are ``'polling'``
+ and ``'websocket'``.
+ """
+ return self.eio.transport()
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task using the appropriate async model.
+
+ This is a utility function that applications can use to start a
+ background task using the method that is compatible with the
+ selected async mode.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ This function returns an object compatible with the `Thread` class in
+ the Python standard library. The `start()` method on this object is
+ already called by this function.
+ """
+ return self.eio.start_background_task(target, *args, **kwargs)
+
+ def sleep(self, seconds=0):
+ """Sleep for the requested amount of time using the appropriate async
+ model.
+
+ This is a utility function that applications can use to put a task to
+ sleep without having to worry about using the correct call for the
+ selected async mode.
+ """
+ return self.eio.sleep(seconds)
+
+ def _send_packet(self, pkt):
+ """Send a Socket.IO packet to the server."""
+ encoded_packet = pkt.encode()
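+        # a list from encode() means the packet has binary attachments:
+        # the text header goes out first, then each attachment as binary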
+ if isinstance(encoded_packet, list):
+ binary = False
+ for ep in encoded_packet:
+ self.eio.send(ep, binary=binary)
+ binary = True
+ else:
+ self.eio.send(encoded_packet, binary=False)
+
+ def _generate_ack_id(self, namespace, callback):
+ """Generate a unique identifier for an ACK packet."""
+ namespace = namespace or '/'
+ if namespace not in self.callbacks:
+ self.callbacks[namespace] = {0: itertools.count(1)}
+ id = six.next(self.callbacks[namespace][0])
+ self.callbacks[namespace][id] = callback
+ return id
+
+ def _handle_connect(self, namespace):
+ namespace = namespace or '/'
+ self.logger.info('Namespace {} is connected'.format(namespace))
+ self._trigger_event('connect', namespace=namespace)
+ if namespace == '/':
+ for n in self.namespaces:
+ self._send_packet(packet.Packet(packet.CONNECT, namespace=n))
+ elif namespace not in self.namespaces:
+ self.namespaces.append(namespace)
+
+ def _handle_disconnect(self, namespace):
+ if not self.connected:
+ return
+ namespace = namespace or '/'
+ if namespace == '/':
+ for n in self.namespaces:
+ self._trigger_event('disconnect', namespace=n)
+ self.namespaces = []
+ self._trigger_event('disconnect', namespace=namespace)
+ if namespace in self.namespaces:
+ self.namespaces.remove(namespace)
+ if namespace == '/':
+ self.connected = False
+
+ def _handle_event(self, namespace, id, data):
+ namespace = namespace or '/'
+ self.logger.info('Received event "%s" [%s]', data[0], namespace)
+ r = self._trigger_event(data[0], namespace, *data[1:])
+ if id is not None:
+ # send ACK packet with the response returned by the handler
+ # tuples are expanded as multiple arguments
+ if r is None:
+ data = []
+ elif isinstance(r, tuple):
+ data = list(r)
+ else:
+ data = [r]
+ if six.PY2 and not self.binary:
+ binary = False # pragma: nocover
+ else:
+ binary = None
+ self._send_packet(packet.Packet(packet.ACK, namespace=namespace,
+ id=id, data=data, binary=binary))
+
+ def _handle_ack(self, namespace, id, data):
+ namespace = namespace or '/'
+ self.logger.info('Received ack [%s]', namespace)
+ callback = None
+ try:
+ callback = self.callbacks[namespace][id]
+ except KeyError:
+ # if we get an unknown callback we just ignore it
+ self.logger.warning('Unknown callback received, ignoring.')
+ else:
+ del self.callbacks[namespace][id]
+ if callback is not None:
+ callback(*data)
+
+ def _handle_error(self, namespace, data):
+ namespace = namespace or '/'
+ self.logger.info('Connection to namespace {} was rejected'.format(
+ namespace))
+ if data is None:
+ data = tuple()
+ elif not isinstance(data, (tuple, list)):
+ data = (data,)
+ self._trigger_event('connect_error', namespace, *data)
+ if namespace in self.namespaces:
+ self.namespaces.remove(namespace)
+ if namespace == '/':
+ self.namespaces = []
+ self.connected = False
+
+ def _trigger_event(self, event, namespace, *args):
+ """Invoke an application event handler."""
+ # first see if we have an explicit handler for the event
+ if namespace in self.handlers and event in self.handlers[namespace]:
+ return self.handlers[namespace][event](*args)
+
+ # or else, forward the event to a namespace handler if one exists
+ elif namespace in self.namespace_handlers:
+ return self.namespace_handlers[namespace].trigger_event(
+ event, *args)
+
+ def _handle_reconnect(self):
+ self._reconnect_abort.clear()
+ reconnecting_clients.append(self)
+ attempt_count = 0
+ current_delay = self.reconnection_delay
+ while True:
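+            # exponential backoff: each attempt doubles the previous delay,
+            # capped at reconnection_delay_max, with +/- randomization_factor
+            # of jitter applied to every wait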
+ delay = current_delay
+ current_delay *= 2
+ if delay > self.reconnection_delay_max:
+ delay = self.reconnection_delay_max
+ delay += self.randomization_factor * (2 * random.random() - 1)
+ self.logger.info(
+ 'Connection failed, new attempt in {:.02f} seconds'.format(
+ delay))
+ if self._reconnect_abort.wait(delay):
+ self.logger.info('Reconnect task aborted')
+ break
+ attempt_count += 1
+ try:
+ self.connect(self.connection_url,
+ headers=self.connection_headers,
+ transports=self.connection_transports,
+ namespaces=self.connection_namespaces,
+ socketio_path=self.socketio_path)
+ except (exceptions.ConnectionError, ValueError):
+ pass
+ else:
+ self.logger.info('Reconnection successful')
+ self._reconnect_task = None
+ break
+ if self.reconnection_attempts and \
+ attempt_count >= self.reconnection_attempts:
+ self.logger.info(
+ 'Maximum reconnection attempts reached, giving up')
+ break
+ reconnecting_clients.remove(self)
+
+ def _handle_eio_connect(self):
+ """Handle the Engine.IO connection event."""
+ self.logger.info('Engine.IO connection established')
+ self.sid = self.eio.sid
+
+ def _handle_eio_message(self, data):
+ """Dispatch Engine.IO messages."""
+ if self._binary_packet:
+ pkt = self._binary_packet
+ if pkt.add_attachment(data):
+ self._binary_packet = None
+ if pkt.packet_type == packet.BINARY_EVENT:
+ self._handle_event(pkt.namespace, pkt.id, pkt.data)
+ else:
+ self._handle_ack(pkt.namespace, pkt.id, pkt.data)
+ else:
+ pkt = packet.Packet(encoded_packet=data)
+ if pkt.packet_type == packet.CONNECT:
+ self._handle_connect(pkt.namespace)
+ elif pkt.packet_type == packet.DISCONNECT:
+ self._handle_disconnect(pkt.namespace)
+ elif pkt.packet_type == packet.EVENT:
+ self._handle_event(pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.ACK:
+ self._handle_ack(pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.BINARY_EVENT or \
+ pkt.packet_type == packet.BINARY_ACK:
+ self._binary_packet = pkt
+ elif pkt.packet_type == packet.ERROR:
+ self._handle_error(pkt.namespace, pkt.data)
+ else:
+ raise ValueError('Unknown packet type.')
+
+ def _handle_eio_disconnect(self):
+ """Handle the Engine.IO disconnection event."""
+ self.logger.info('Engine.IO connection dropped')
+ if self.connected:
+ for n in self.namespaces:
+ self._trigger_event('disconnect', namespace=n)
+ self._trigger_event('disconnect', namespace='/')
+ self.namespaces = []
+ self.connected = False
+ self.callbacks = {}
+ self._binary_packet = None
+ self.sid = None
+ if self.eio.state == 'connected' and self.reconnection:
+ self._reconnect_task = self.start_background_task(
+ self._handle_reconnect)
+
+ def _engineio_client_class(self):
+ return engineio.Client
diff --git a/libs/socketio/exceptions.py b/libs/socketio/exceptions.py
new file mode 100644
index 000000000..36dddd9fc
--- /dev/null
+++ b/libs/socketio/exceptions.py
@@ -0,0 +1,30 @@
+class SocketIOError(Exception):
+ pass
+
+
+class ConnectionError(SocketIOError):
+ pass
+
+
+class ConnectionRefusedError(ConnectionError):
+ """Connection refused exception.
+
+ This exception can be raised from a connect handler when the connection
+ is not accepted. The positional arguments provided with the exception are
+ returned with the error packet to the client.
+ """
+ def __init__(self, *args):
+ if len(args) == 0:
+ self.error_args = None
+ elif len(args) == 1 and not isinstance(args[0], list):
+ self.error_args = args[0]
+ else:
+ self.error_args = args
+
+
+class TimeoutError(SocketIOError):
+ pass
+
+
+class BadNamespaceError(SocketIOError):
+ pass
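+
+
+# A minimal illustrative check (not part of the upstream module) of how the
+# ConnectionRefusedError constructor above normalizes positional arguments
+# into ``error_args``:
+if __name__ == '__main__':  # pragma: no cover
+    assert ConnectionRefusedError().error_args is None
+    assert ConnectionRefusedError('denied').error_args == 'denied'
+    assert ConnectionRefusedError('denied', 401).error_args == ('denied', 401)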
diff --git a/libs/socketio/kafka_manager.py b/libs/socketio/kafka_manager.py
new file mode 100644
index 000000000..00a2e7f05
--- /dev/null
+++ b/libs/socketio/kafka_manager.py
@@ -0,0 +1,63 @@
+import logging
+import pickle
+
+try:
+ import kafka
+except ImportError:
+ kafka = None
+
+from .pubsub_manager import PubSubManager
+
+logger = logging.getLogger('socketio')
+
+
+class KafkaManager(PubSubManager): # pragma: no cover
+ """Kafka based client manager.
+
+ This class implements a Kafka backend for event sharing across multiple
+ processes.
+
+ To use a Kafka backend, initialize the :class:`Server` instance as
+ follows::
+
+ url = 'kafka://hostname:port'
+ server = socketio.Server(client_manager=socketio.KafkaManager(url))
+
+ :param url: The connection URL for the Kafka server. For a default Kafka
+ store running on the same host, use ``kafka://``.
+ :param channel: The channel name (topic) on which the server sends and
+ receives notifications. Must be the same in all the
+ servers.
+    :param write_only: If set to ``True``, only initialize to emit events. The
+ default of ``False`` initializes the class for emitting
+ and receiving.
+ """
+ name = 'kafka'
+
+ def __init__(self, url='kafka://localhost:9092', channel='socketio',
+ write_only=False):
+ if kafka is None:
+ raise RuntimeError('kafka-python package is not installed '
+ '(Run "pip install kafka-python" in your '
+ 'virtualenv).')
+
+ super(KafkaManager, self).__init__(channel=channel,
+ write_only=write_only)
+
+ self.kafka_url = url[8:] if url != 'kafka://' else 'localhost:9092'
+ self.producer = kafka.KafkaProducer(bootstrap_servers=self.kafka_url)
+ self.consumer = kafka.KafkaConsumer(self.channel,
+ bootstrap_servers=self.kafka_url)
+
+ def _publish(self, data):
+ self.producer.send(self.channel, value=pickle.dumps(data))
+ self.producer.flush()
+
+ def _kafka_listen(self):
+ for message in self.consumer:
+ yield message
+
+ def _listen(self):
+ for message in self._kafka_listen():
+ if message.topic == self.channel:
+ yield pickle.loads(message.value)
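+
+
+# Illustrative sketch (not part of the upstream module): a process that only
+# needs to publish, such as a background worker, can use a write-only
+# manager; the URL, event name and room below are hypothetical.
+if __name__ == '__main__':  # pragma: no cover
+    mgr = KafkaManager('kafka://localhost:9092', write_only=True)
+    mgr.emit('progress', data={'pct': 50}, room='jobs')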
diff --git a/libs/socketio/kombu_manager.py b/libs/socketio/kombu_manager.py
new file mode 100644
index 000000000..4eb9ee498
--- /dev/null
+++ b/libs/socketio/kombu_manager.py
@@ -0,0 +1,122 @@
+import pickle
+import uuid
+
+try:
+ import kombu
+except ImportError:
+ kombu = None
+
+from .pubsub_manager import PubSubManager
+
+
+class KombuManager(PubSubManager): # pragma: no cover
+ """Client manager that uses kombu for inter-process messaging.
+
+ This class implements a client manager backend for event sharing across
+ multiple processes, using RabbitMQ, Redis or any other messaging mechanism
+ supported by `kombu <http://kombu.readthedocs.org/en/latest/>`_.
+
+ To use a kombu backend, initialize the :class:`Server` instance as
+ follows::
+
+ url = 'amqp://user:password@hostname:port//'
+ server = socketio.Server(client_manager=socketio.KombuManager(url))
+
+ :param url: The connection URL for the backend messaging queue. Example
+ connection URLs are ``'amqp://guest:guest@localhost:5672//'``
+ and ``'redis://localhost:6379/'`` for RabbitMQ and Redis
+ respectively. Consult the `kombu documentation
+ <http://kombu.readthedocs.org/en/latest/userguide\
+ /connections.html#urls>`_ for more on how to construct
+ connection URLs.
+ :param channel: The channel name on which the server sends and receives
+ notifications. Must be the same in all the servers.
+    :param write_only: If set to ``True``, only initialize to emit events. The
+ default of ``False`` initializes the class for emitting
+ and receiving.
+ :param connection_options: additional keyword arguments to be passed to
+ ``kombu.Connection()``.
+ :param exchange_options: additional keyword arguments to be passed to
+ ``kombu.Exchange()``.
+ :param queue_options: additional keyword arguments to be passed to
+ ``kombu.Queue()``.
+ :param producer_options: additional keyword arguments to be passed to
+ ``kombu.Producer()``.
+ """
+ name = 'kombu'
+
+ def __init__(self, url='amqp://guest:guest@localhost:5672//',
+ channel='socketio', write_only=False, logger=None,
+ connection_options=None, exchange_options=None,
+ queue_options=None, producer_options=None):
+ if kombu is None:
+ raise RuntimeError('Kombu package is not installed '
+ '(Run "pip install kombu" in your '
+ 'virtualenv).')
+ super(KombuManager, self).__init__(channel=channel,
+ write_only=write_only,
+ logger=logger)
+ self.url = url
+ self.connection_options = connection_options or {}
+ self.exchange_options = exchange_options or {}
+ self.queue_options = queue_options or {}
+ self.producer_options = producer_options or {}
+ self.producer = self._producer()
+
+ def initialize(self):
+ super(KombuManager, self).initialize()
+
+ monkey_patched = True
+ if self.server.async_mode == 'eventlet':
+ from eventlet.patcher import is_monkey_patched
+ monkey_patched = is_monkey_patched('socket')
+ elif 'gevent' in self.server.async_mode:
+ from gevent.monkey import is_module_patched
+ monkey_patched = is_module_patched('socket')
+ if not monkey_patched:
+ raise RuntimeError(
+ 'Kombu requires a monkey patched socket library to work '
+ 'with ' + self.server.async_mode)
+
+ def _connection(self):
+ return kombu.Connection(self.url, **self.connection_options)
+
+ def _exchange(self):
+ options = {'type': 'fanout', 'durable': False}
+ options.update(self.exchange_options)
+ return kombu.Exchange(self.channel, **options)
+
+ def _queue(self):
+ queue_name = 'flask-socketio.' + str(uuid.uuid4())
+ options = {'durable': False, 'queue_arguments': {'x-expires': 300000}}
+ options.update(self.queue_options)
+ return kombu.Queue(queue_name, self._exchange(), **options)
+
+ def _producer(self):
+ return self._connection().Producer(exchange=self._exchange(),
+ **self.producer_options)
+
+ def __error_callback(self, exception, interval):
+ self._get_logger().exception('Sleeping {}s'.format(interval))
+
+ def _publish(self, data):
+ connection = self._connection()
+ publish = connection.ensure(self.producer, self.producer.publish,
+ errback=self.__error_callback)
+ publish(pickle.dumps(data))
+
+ def _listen(self):
+ reader_queue = self._queue()
+
+ while True:
+ connection = self._connection().ensure_connection(
+ errback=self.__error_callback)
+ try:
+ with connection.SimpleQueue(reader_queue) as queue:
+ while True:
+ message = queue.get(block=True)
+ message.ack()
+ yield message.payload
+ except connection.connection_errors:
+ self._get_logger().exception("Connection error "
+ "while reading from queue")
diff --git a/libs/socketio/middleware.py b/libs/socketio/middleware.py
new file mode 100644
index 000000000..1a6974085
--- /dev/null
+++ b/libs/socketio/middleware.py
@@ -0,0 +1,42 @@
+import engineio
+
+
+class WSGIApp(engineio.WSGIApp):
+ """WSGI middleware for Socket.IO.
+
+ This middleware dispatches traffic to a Socket.IO application. It can also
+ serve a list of static files to the client, or forward unrelated HTTP
+ traffic to another WSGI application.
+
+ :param socketio_app: The Socket.IO server. Must be an instance of the
+ ``socketio.Server`` class.
+ :param wsgi_app: The WSGI app that receives all other traffic.
+ :param static_files: A dictionary with static file mapping rules. See the
+ documentation for details on this argument.
+ :param socketio_path: The endpoint where the Socket.IO application should
+ be installed. The default value is appropriate for
+ most cases.
+
+ Example usage::
+
+ import socketio
+ import eventlet
+ from . import wsgi_app
+
+ sio = socketio.Server()
+ app = socketio.WSGIApp(sio, wsgi_app)
+ eventlet.wsgi.server(eventlet.listen(('', 8000)), app)
+ """
+ def __init__(self, socketio_app, wsgi_app=None, static_files=None,
+ socketio_path='socket.io'):
+ super(WSGIApp, self).__init__(socketio_app, wsgi_app,
+ static_files=static_files,
+ engineio_path=socketio_path)
+
+
+class Middleware(WSGIApp):
+ """This class has been renamed to WSGIApp and is now deprecated."""
+ def __init__(self, socketio_app, wsgi_app=None,
+ socketio_path='socket.io'):
+ super(Middleware, self).__init__(socketio_app, wsgi_app,
+ socketio_path=socketio_path)
diff --git a/libs/socketio/namespace.py b/libs/socketio/namespace.py
new file mode 100644
index 000000000..418615ff8
--- /dev/null
+++ b/libs/socketio/namespace.py
@@ -0,0 +1,191 @@
+class BaseNamespace(object):
+ def __init__(self, namespace=None):
+ self.namespace = namespace or '/'
+
+ def is_asyncio_based(self):
+ return False
+
+ def trigger_event(self, event, *args):
+ """Dispatch an event to the proper handler method.
+
+ In the most common usage, this method is not overloaded by subclasses,
+ as it performs the routing of events to methods. However, this
+        method can be overridden if special dispatching rules are needed, or if
+ having a single method that catches all events is desired.
+ """
+ handler_name = 'on_' + event
+ if hasattr(self, handler_name):
+ return getattr(self, handler_name)(*args)
+
+
+class Namespace(BaseNamespace):
+ """Base class for server-side class-based namespaces.
+
+ A class-based namespace is a class that contains all the event handlers
+ for a Socket.IO namespace. The event handlers are methods of the class
+ with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
+ ``on_message``, ``on_json``, and so on.
+
+ :param namespace: The Socket.IO namespace to be used with all the event
+ handlers defined in this class. If this argument is
+ omitted, the default namespace is used.
+ """
+ def __init__(self, namespace=None):
+ super(Namespace, self).__init__(namespace=namespace)
+ self.server = None
+
+ def _set_server(self, server):
+ self.server = server
+
+ def emit(self, event, data=None, room=None, skip_sid=None, namespace=None,
+ callback=None):
+ """Emit a custom event to one or more connected clients.
+
+ The only difference with the :func:`socketio.Server.emit` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.emit(event, data=data, room=room, skip_sid=skip_sid,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ def send(self, data, room=None, skip_sid=None, namespace=None,
+ callback=None):
+ """Send a message to one or more connected clients.
+
+ The only difference with the :func:`socketio.Server.send` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.send(data, room=room, skip_sid=skip_sid,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ def enter_room(self, sid, room, namespace=None):
+ """Enter a room.
+
+ The only difference with the :func:`socketio.Server.enter_room` method
+ is that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.enter_room(sid, room,
+ namespace=namespace or self.namespace)
+
+ def leave_room(self, sid, room, namespace=None):
+ """Leave a room.
+
+ The only difference with the :func:`socketio.Server.leave_room` method
+ is that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.leave_room(sid, room,
+ namespace=namespace or self.namespace)
+
+ def close_room(self, room, namespace=None):
+ """Close a room.
+
+ The only difference with the :func:`socketio.Server.close_room` method
+ is that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.close_room(room,
+ namespace=namespace or self.namespace)
+
+ def rooms(self, sid, namespace=None):
+ """Return the rooms a client is in.
+
+ The only difference with the :func:`socketio.Server.rooms` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.rooms(sid, namespace=namespace or self.namespace)
+
+ def get_session(self, sid, namespace=None):
+ """Return the user session for a client.
+
+ The only difference with the :func:`socketio.Server.get_session`
+ method is that when the ``namespace`` argument is not given the
+ namespace associated with the class is used.
+ """
+ return self.server.get_session(
+ sid, namespace=namespace or self.namespace)
+
+ def save_session(self, sid, session, namespace=None):
+ """Store the user session for a client.
+
+ The only difference with the :func:`socketio.Server.save_session`
+ method is that when the ``namespace`` argument is not given the
+ namespace associated with the class is used.
+ """
+ return self.server.save_session(
+ sid, session, namespace=namespace or self.namespace)
+
+ def session(self, sid, namespace=None):
+ """Return the user session for a client with context manager syntax.
+
+ The only difference with the :func:`socketio.Server.session` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.session(sid, namespace=namespace or self.namespace)
+
+ def disconnect(self, sid, namespace=None):
+ """Disconnect a client.
+
+ The only difference with the :func:`socketio.Server.disconnect` method
+ is that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.server.disconnect(sid,
+ namespace=namespace or self.namespace)
+
+
+class ClientNamespace(BaseNamespace):
+ """Base class for client-side class-based namespaces.
+
+ A class-based namespace is a class that contains all the event handlers
+ for a Socket.IO namespace. The event handlers are methods of the class
+ with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
+ ``on_message``, ``on_json``, and so on.
+
+ :param namespace: The Socket.IO namespace to be used with all the event
+ handlers defined in this class. If this argument is
+ omitted, the default namespace is used.
+ """
+ def __init__(self, namespace=None):
+ super(ClientNamespace, self).__init__(namespace=namespace)
+ self.client = None
+
+ def _set_client(self, client):
+ self.client = client
+
+ def emit(self, event, data=None, namespace=None, callback=None):
+ """Emit a custom event to the server.
+
+ The only difference with the :func:`socketio.Client.emit` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.client.emit(event, data=data,
+ namespace=namespace or self.namespace,
+ callback=callback)
+
+ def send(self, data, room=None, skip_sid=None, namespace=None,
+ callback=None):
+ """Send a message to the server.
+
+ The only difference with the :func:`socketio.Client.send` method is
+ that when the ``namespace`` argument is not given the namespace
+ associated with the class is used.
+ """
+ return self.client.send(data, namespace=namespace or self.namespace,
+ callback=callback)
+
+ def disconnect(self):
+ """Disconnect from the server.
+
+        This method is a proxy for the :func:`socketio.Client.disconnect`
+        method, which disconnects the client from the server, including all
+        of its namespaces.
+ """
+ return self.client.disconnect()
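+
+
+# A minimal illustrative sketch (not part of the upstream module): handlers
+# in a class-based namespace are methods named after the event with an
+# ``on_`` prefix; the class and event names below are hypothetical.
+class _ExampleChatNamespace(Namespace):  # pragma: no cover
+    def on_connect(self, sid, environ):
+        pass
+
+    def on_my_event(self, sid, data):
+        # echo the payload back to the emitting client only
+        self.emit('my_response', data, room=sid)
+
+# With a server instance:
+#     sio.register_namespace(_ExampleChatNamespace('/chat'))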
diff --git a/libs/socketio/packet.py b/libs/socketio/packet.py
new file mode 100644
index 000000000..73b469d6d
--- /dev/null
+++ b/libs/socketio/packet.py
@@ -0,0 +1,179 @@
+import json as _json
+
+import six
+
+(CONNECT, DISCONNECT, EVENT, ACK, ERROR, BINARY_EVENT, BINARY_ACK) = \
+ (0, 1, 2, 3, 4, 5, 6)
+packet_names = ['CONNECT', 'DISCONNECT', 'EVENT', 'ACK', 'ERROR',
+ 'BINARY_EVENT', 'BINARY_ACK']
+
+
+class Packet(object):
+ """Socket.IO packet."""
+
+ # the format of the Socket.IO packet is as follows:
+ #
+ # packet type: 1 byte, values 0-6
+ # num_attachments: ASCII encoded, only if num_attachments != 0
+ # '-': only if num_attachments != 0
+ # namespace: only if namespace != '/'
+ # ',': only if namespace and one of id and data are defined in this packet
+ # id: ASCII encoded, only if id is not None
+ # data: JSON dump of data payload
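+    #
+    # worked examples (illustrative):
+    #   '2/chat,5["my_event",{"a":1}]'  -> EVENT, namespace /chat, id 5
+    #   '51-["file",{"_placeholder":true,"num":0}]'  -> BINARY_EVENT with
+    #                                                   one attachment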
+
+ json = _json
+
+ def __init__(self, packet_type=EVENT, data=None, namespace=None, id=None,
+ binary=None, encoded_packet=None):
+ self.packet_type = packet_type
+ self.data = data
+ self.namespace = namespace
+ self.id = id
+ if binary or (binary is None and self._data_is_binary(self.data)):
+ if self.packet_type == EVENT:
+ self.packet_type = BINARY_EVENT
+ elif self.packet_type == ACK:
+ self.packet_type = BINARY_ACK
+ else:
+ raise ValueError('Packet does not support binary payload.')
+ self.attachment_count = 0
+ self.attachments = []
+ if encoded_packet:
+ self.attachment_count = self.decode(encoded_packet)
+
+ def encode(self):
+ """Encode the packet for transmission.
+
+ If the packet contains binary elements, this function returns a list
+        of packets where the first is the original packet with placeholders
+        for the binary components and the remaining ones are the binary
+        attachments.
+ """
+ encoded_packet = six.text_type(self.packet_type)
+ if self.packet_type == BINARY_EVENT or self.packet_type == BINARY_ACK:
+ data, attachments = self._deconstruct_binary(self.data)
+ encoded_packet += six.text_type(len(attachments)) + '-'
+ else:
+ data = self.data
+ attachments = None
+ needs_comma = False
+ if self.namespace is not None and self.namespace != '/':
+ encoded_packet += self.namespace
+ needs_comma = True
+ if self.id is not None:
+ if needs_comma:
+ encoded_packet += ','
+ needs_comma = False
+ encoded_packet += six.text_type(self.id)
+ if data is not None:
+ if needs_comma:
+ encoded_packet += ','
+ encoded_packet += self.json.dumps(data, separators=(',', ':'))
+ if attachments is not None:
+ encoded_packet = [encoded_packet] + attachments
+ return encoded_packet
+
+ def decode(self, encoded_packet):
+ """Decode a transmitted package.
+
+ The return value indicates how many binary attachment packets are
+ necessary to fully decode the packet.
+ """
+ ep = encoded_packet
+ try:
+ self.packet_type = int(ep[0:1])
+ except TypeError:
+ self.packet_type = ep
+ ep = ''
+ self.namespace = None
+ self.data = None
+ ep = ep[1:]
+ dash = ep.find('-')
+ attachment_count = 0
+ if dash > 0 and ep[0:dash].isdigit():
+ attachment_count = int(ep[0:dash])
+ ep = ep[dash + 1:]
+ if ep and ep[0:1] == '/':
+ sep = ep.find(',')
+ if sep == -1:
+ self.namespace = ep
+ ep = ''
+ else:
+ self.namespace = ep[0:sep]
+ ep = ep[sep + 1:]
+ q = self.namespace.find('?')
+ if q != -1:
+ self.namespace = self.namespace[0:q]
+ if ep and ep[0].isdigit():
+ self.id = 0
+ while ep and ep[0].isdigit():
+ self.id = self.id * 10 + int(ep[0])
+ ep = ep[1:]
+ if ep:
+ self.data = self.json.loads(ep)
+ return attachment_count
+
+ def add_attachment(self, attachment):
+ if self.attachment_count <= len(self.attachments):
+ raise ValueError('Unexpected binary attachment')
+ self.attachments.append(attachment)
+ if self.attachment_count == len(self.attachments):
+ self.reconstruct_binary(self.attachments)
+ return True
+ return False
+
+ def reconstruct_binary(self, attachments):
+ """Reconstruct a decoded packet using the given list of binary
+ attachments.
+ """
+        self.data = self._reconstruct_binary_internal(self.data,
+                                                      attachments)
+
+ def _reconstruct_binary_internal(self, data, attachments):
+ if isinstance(data, list):
+ return [self._reconstruct_binary_internal(item, attachments)
+ for item in data]
+ elif isinstance(data, dict):
+ if data.get('_placeholder') and 'num' in data:
+ return attachments[data['num']]
+ else:
+ return {key: self._reconstruct_binary_internal(value,
+ attachments)
+ for key, value in six.iteritems(data)}
+ else:
+ return data
+
+ def _deconstruct_binary(self, data):
+ """Extract binary components in the packet."""
+ attachments = []
+ data = self._deconstruct_binary_internal(data, attachments)
+ return data, attachments
+
+ def _deconstruct_binary_internal(self, data, attachments):
+ if isinstance(data, six.binary_type):
+ attachments.append(data)
+ return {'_placeholder': True, 'num': len(attachments) - 1}
+ elif isinstance(data, list):
+ return [self._deconstruct_binary_internal(item, attachments)
+ for item in data]
+ elif isinstance(data, dict):
+ return {key: self._deconstruct_binary_internal(value, attachments)
+ for key, value in six.iteritems(data)}
+ else:
+ return data
+
+    def _data_is_binary(self, data):
+        """Check if the data contains binary components."""
+        if isinstance(data, six.binary_type):
+            return True
+        elif isinstance(data, list):
+            return any(self._data_is_binary(item) for item in data)
+        elif isinstance(data, dict):
+            return any(self._data_is_binary(item)
+                       for item in six.itervalues(data))
+        else:
+            return False
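+
+
+# A minimal illustrative round trip (not part of the upstream module):
+if __name__ == '__main__':  # pragma: no cover
+    pkt = Packet(packet_type=EVENT, data=['greet', {'to': 'world'}],
+                 namespace='/chat', id=7)
+    decoded = Packet(encoded_packet=pkt.encode())
+    assert decoded.packet_type == EVENT
+    assert decoded.namespace == '/chat'
+    assert decoded.id == 7
+    assert decoded.data == ['greet', {'to': 'world'}]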
diff --git a/libs/socketio/pubsub_manager.py b/libs/socketio/pubsub_manager.py
new file mode 100644
index 000000000..2905b2c32
--- /dev/null
+++ b/libs/socketio/pubsub_manager.py
@@ -0,0 +1,154 @@
+from functools import partial
+import json
+import pickle
+import uuid
+
+import six
+
+from .base_manager import BaseManager
+
+
+class PubSubManager(BaseManager):
+ """Manage a client list attached to a pub/sub backend.
+
+ This is a base class that enables multiple servers to share the list of
+ clients, with the servers communicating events through a pub/sub backend.
+ The use of a pub/sub backend also allows any client connected to the
+ backend to emit events addressed to Socket.IO clients.
+
+    The actual backends must be implemented by subclasses; this class only
+    provides a generic pub/sub framework.
+
+ :param channel: The channel name on which the server sends and receives
+ notifications.
+ """
+ name = 'pubsub'
+
+ def __init__(self, channel='socketio', write_only=False, logger=None):
+ super(PubSubManager, self).__init__()
+ self.channel = channel
+ self.write_only = write_only
+ self.host_id = uuid.uuid4().hex
+ self.logger = logger
+
+ def initialize(self):
+ super(PubSubManager, self).initialize()
+ if not self.write_only:
+ self.thread = self.server.start_background_task(self._thread)
+ self._get_logger().info(self.name + ' backend initialized.')
+
+ def emit(self, event, data, namespace=None, room=None, skip_sid=None,
+ callback=None, **kwargs):
+ """Emit a message to a single client, a room, or all the clients
+ connected to the namespace.
+
+        This method takes care of propagating the message to all the servers
+ that are connected through the message queue.
+
+ The parameters are the same as in :meth:`.Server.emit`.
+ """
+ if kwargs.get('ignore_queue'):
+ return super(PubSubManager, self).emit(
+ event, data, namespace=namespace, room=room, skip_sid=skip_sid,
+ callback=callback)
+ namespace = namespace or '/'
+ if callback is not None:
+ if self.server is None:
+ raise RuntimeError('Callbacks can only be issued from the '
+ 'context of a server.')
+ if room is None:
+ raise ValueError('Cannot use callback without a room set.')
+ id = self._generate_ack_id(room, namespace, callback)
+ callback = (room, namespace, id)
+ else:
+ callback = None
+ self._publish({'method': 'emit', 'event': event, 'data': data,
+ 'namespace': namespace, 'room': room,
+ 'skip_sid': skip_sid, 'callback': callback,
+ 'host_id': self.host_id})
+
+ def close_room(self, room, namespace=None):
+ self._publish({'method': 'close_room', 'room': room,
+ 'namespace': namespace or '/'})
+
+ def _publish(self, data):
+ """Publish a message on the Socket.IO channel.
+
+ This method needs to be implemented by the different subclasses that
+ support pub/sub backends.
+ """
+ raise NotImplementedError('This method must be implemented in a '
+ 'subclass.') # pragma: no cover
+
+ def _listen(self):
+ """Return the next message published on the Socket.IO channel,
+ blocking until a message is available.
+
+ This method needs to be implemented by the different subclasses that
+ support pub/sub backends.
+ """
+ raise NotImplementedError('This method must be implemented in a '
+ 'subclass.') # pragma: no cover
+
+ def _handle_emit(self, message):
+        # events with callbacks are tricky to handle across hosts: on the
+        # receiving end, a local callback is set up that preserves the
+        # callback host and id of the sender
+ remote_callback = message.get('callback')
+ remote_host_id = message.get('host_id')
+ if remote_callback is not None and len(remote_callback) == 3:
+ callback = partial(self._return_callback, remote_host_id,
+ *remote_callback)
+ else:
+ callback = None
+ super(PubSubManager, self).emit(message['event'], message['data'],
+ namespace=message.get('namespace'),
+ room=message.get('room'),
+ skip_sid=message.get('skip_sid'),
+ callback=callback)
+
+ def _handle_callback(self, message):
+ if self.host_id == message.get('host_id'):
+ try:
+ sid = message['sid']
+ namespace = message['namespace']
+ id = message['id']
+ args = message['args']
+ except KeyError:
+ return
+ self.trigger_callback(sid, namespace, id, args)
+
+ def _return_callback(self, host_id, sid, namespace, callback_id, *args):
+        # when an event callback is received, the callback is returned back
+        # to the sender, which is identified by the host_id
+ self._publish({'method': 'callback', 'host_id': host_id,
+ 'sid': sid, 'namespace': namespace, 'id': callback_id,
+ 'args': args})
+
+ def _handle_close_room(self, message):
+ super(PubSubManager, self).close_room(
+ room=message.get('room'), namespace=message.get('namespace'))
+
+ def _thread(self):
+ for message in self._listen():
+ data = None
+ if isinstance(message, dict):
+ data = message
+ else:
+ if isinstance(message, six.binary_type): # pragma: no cover
+ try:
+ data = pickle.loads(message)
+                    except Exception:
+ pass
+ if data is None:
+ try:
+ data = json.loads(message)
+                    except Exception:
+ pass
+ if data and 'method' in data:
+ if data['method'] == 'emit':
+ self._handle_emit(data)
+ elif data['method'] == 'callback':
+ self._handle_callback(data)
+ elif data['method'] == 'close_room':
+ self._handle_close_room(data)
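+
+
+# A minimal illustrative backend (not part of the upstream module): a
+# single-process manager built on a plain queue, showing the two hooks a
+# concrete subclass must provide; real backends talk to an external broker.
+from six.moves import queue as _queue  # noqa: E402
+
+
+class _InProcessManager(PubSubManager):  # pragma: no cover
+    name = 'inprocess'
+
+    def __init__(self, channel='socketio', write_only=False, logger=None):
+        super(_InProcessManager, self).__init__(
+            channel=channel, write_only=write_only, logger=logger)
+        self._messages = _queue.Queue()
+
+    def _publish(self, data):
+        # hand the message dict straight to the listening thread
+        self._messages.put(data)
+
+    def _listen(self):
+        while True:
+            yield self._messages.get()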
diff --git a/libs/socketio/redis_manager.py b/libs/socketio/redis_manager.py
new file mode 100644
index 000000000..ad383345e
--- /dev/null
+++ b/libs/socketio/redis_manager.py
@@ -0,0 +1,115 @@
+import logging
+import pickle
+import time
+
+try:
+ import redis
+except ImportError:
+ redis = None
+
+from .pubsub_manager import PubSubManager
+
+logger = logging.getLogger('socketio')
+
+
+class RedisManager(PubSubManager): # pragma: no cover
+ """Redis based client manager.
+
+ This class implements a Redis backend for event sharing across multiple
+    processes. It is kept here as one more example of how to build a custom
+ backend, since the kombu backend is perfectly adequate to support a Redis
+ message queue.
+
+ To use a Redis backend, initialize the :class:`Server` instance as
+ follows::
+
+ url = 'redis://hostname:port/0'
+ server = socketio.Server(client_manager=socketio.RedisManager(url))
+
+ :param url: The connection URL for the Redis server. For a default Redis
+ store running on the same host, use ``redis://``.
+ :param channel: The channel name on which the server sends and receives
+ notifications. Must be the same in all the servers.
+    :param write_only: If set to ``True``, only initialize to emit events. The
+ default of ``False`` initializes the class for emitting
+ and receiving.
+ :param redis_options: additional keyword arguments to be passed to
+ ``Redis.from_url()``.
+ """
+ name = 'redis'
+
+ def __init__(self, url='redis://localhost:6379/0', channel='socketio',
+ write_only=False, logger=None, redis_options=None):
+ if redis is None:
+ raise RuntimeError('Redis package is not installed '
+ '(Run "pip install redis" in your '
+ 'virtualenv).')
+ self.redis_url = url
+ self.redis_options = redis_options or {}
+ self._redis_connect()
+ super(RedisManager, self).__init__(channel=channel,
+ write_only=write_only,
+ logger=logger)
+
+ def initialize(self):
+ super(RedisManager, self).initialize()
+
+ monkey_patched = True
+ if self.server.async_mode == 'eventlet':
+ from eventlet.patcher import is_monkey_patched
+ monkey_patched = is_monkey_patched('socket')
+ elif 'gevent' in self.server.async_mode:
+ from gevent.monkey import is_module_patched
+ monkey_patched = is_module_patched('socket')
+ if not monkey_patched:
+ raise RuntimeError(
+ 'Redis requires a monkey patched socket library to work '
+ 'with ' + self.server.async_mode)
+
+ def _redis_connect(self):
+ self.redis = redis.Redis.from_url(self.redis_url,
+ **self.redis_options)
+ self.pubsub = self.redis.pubsub()
+
+ def _publish(self, data):
+ retry = True
+ while True:
+ try:
+ if not retry:
+ self._redis_connect()
+ return self.redis.publish(self.channel, pickle.dumps(data))
+ except redis.exceptions.ConnectionError:
+ if retry:
+ logger.error('Cannot publish to redis... retrying')
+ retry = False
+ else:
+ logger.error('Cannot publish to redis... giving up')
+ break
+
+ def _redis_listen_with_retries(self):
+ retry_sleep = 1
+ connect = False
+ while True:
+ try:
+ if connect:
+ self._redis_connect()
+ self.pubsub.subscribe(self.channel)
+ for message in self.pubsub.listen():
+ yield message
+ except redis.exceptions.ConnectionError:
+ logger.error('Cannot receive from redis... '
+ 'retrying in {} secs'.format(retry_sleep))
+ connect = True
+ time.sleep(retry_sleep)
+ retry_sleep *= 2
+ if retry_sleep > 60:
+ retry_sleep = 60
+
+ def _listen(self):
+ channel = self.channel.encode('utf-8')
+ self.pubsub.subscribe(self.channel)
+ for message in self._redis_listen_with_retries():
+ if message['channel'] == channel and \
+ message['type'] == 'message' and 'data' in message:
+ yield message['data']
+ self.pubsub.unsubscribe(self.channel)
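+
+
+# Illustrative sketch (not part of the upstream module): a process that only
+# needs to emit, e.g. a task worker, can attach a write-only manager and
+# publish through the shared Redis channel; the URL, event name and room
+# below are hypothetical.
+if __name__ == '__main__':  # pragma: no cover
+    external = RedisManager('redis://localhost:6379/0', write_only=True)
+    external.emit('task_done', data={'result': 'ok'}, room='workers')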
diff --git a/libs/socketio/server.py b/libs/socketio/server.py
new file mode 100644
index 000000000..76b7d2e8f
--- /dev/null
+++ b/libs/socketio/server.py
@@ -0,0 +1,730 @@
+import logging
+
+import engineio
+import six
+
+from . import base_manager
+from . import exceptions
+from . import namespace
+from . import packet
+
+default_logger = logging.getLogger('socketio.server')
+
+
+class Server(object):
+ """A Socket.IO server.
+
+ This class implements a fully compliant Socket.IO web server with support
+ for websocket and long-polling transports.
+
+ :param client_manager: The client manager instance that will manage the
+ client list. When this is omitted, the client list
+ is stored in an in-memory structure, so the use of
+ multiple connected servers is not possible.
+ :param logger: To enable logging set to ``True`` or pass a logger object to
+ use. To disable logging set to ``False``. The default is
+ ``False``.
+ :param binary: ``True`` to support binary payloads, ``False`` to treat all
+ payloads as text. On Python 2, if this is set to ``True``,
+ ``unicode`` values are treated as text, and ``str`` and
+ ``bytes`` values are treated as binary. This option has no
+ effect on Python 3, where text and binary payloads are
+ always automatically discovered.
+ :param json: An alternative json module to use for encoding and decoding
+ packets. Custom json modules must have ``dumps`` and ``loads``
+ functions that are compatible with the standard library
+ versions.
+ :param async_handlers: If set to ``True``, event handlers for a client are
+ executed in separate threads. To run handlers for a
+ client synchronously, set to ``False``. The default
+ is ``True``.
+ :param always_connect: When set to ``False``, new connections are
+                           provisional until the connect handler returns
+ something other than ``False``, at which point they
+ are accepted. When set to ``True``, connections are
+ immediately accepted, and then if the connect
+ handler returns ``False`` a disconnect is issued.
+ Set to ``True`` if you need to emit events from the
+ connect handler and your client is confused when it
+ receives events before the connection acceptance.
+ In any other case use the default of ``False``.
+ :param kwargs: Connection parameters for the underlying Engine.IO server.
+
+ The Engine.IO configuration supports the following settings:
+
+ :param async_mode: The asynchronous model to use. See the Deployment
+ section in the documentation for a description of the
+ available options. Valid async modes are "threading",
+ "eventlet", "gevent" and "gevent_uwsgi". If this
+ argument is not given, "eventlet" is tried first, then
+ "gevent_uwsgi", then "gevent", and finally "threading".
+ The first async mode that has all its dependencies
+                       installed is the one that is chosen.
+ :param ping_timeout: The time in seconds that the client waits for the
+ server to respond before disconnecting. The default
+ is 60 seconds.
+ :param ping_interval: The interval in seconds at which the client pings
+ the server. The default is 25 seconds.
+ :param max_http_buffer_size: The maximum size of a message when using the
+ polling transport. The default is 100,000,000
+ bytes.
+ :param allow_upgrades: Whether to allow transport upgrades or not. The
+ default is ``True``.
+ :param http_compression: Whether to compress packages when using the
+ polling transport. The default is ``True``.
+ :param compression_threshold: Only compress messages when their byte size
+ is greater than this value. The default is
+ 1024 bytes.
+ :param cookie: Name of the HTTP cookie that contains the client session
+ id. If set to ``None``, a cookie is not sent to the client.
+ The default is ``'io'``.
+ :param cors_allowed_origins: Origin or list of origins that are allowed to
+ connect to this server. Only the same origin
+ is allowed by default. Set this argument to
+ ``'*'`` to allow all origins, or to ``[]`` to
+ disable CORS handling.
+ :param cors_credentials: Whether credentials (cookies, authentication) are
+ allowed in requests to this server. The default is
+ ``True``.
+ :param monitor_clients: If set to ``True``, a background task will ensure
+ inactive clients are closed. Set to ``False`` to
+ disable the monitoring task (not recommended). The
+ default is ``True``.
+ :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
+ a logger object to use. To disable logging set to
+ ``False``. The default is ``False``.
+ """
+ def __init__(self, client_manager=None, logger=False, binary=False,
+ json=None, async_handlers=True, always_connect=False,
+ **kwargs):
+ engineio_options = kwargs
+ engineio_logger = engineio_options.pop('engineio_logger', None)
+ if engineio_logger is not None:
+ engineio_options['logger'] = engineio_logger
+ if json is not None:
+ packet.Packet.json = json
+ engineio_options['json'] = json
+ engineio_options['async_handlers'] = False
+ self.eio = self._engineio_server_class()(**engineio_options)
+ self.eio.on('connect', self._handle_eio_connect)
+ self.eio.on('message', self._handle_eio_message)
+ self.eio.on('disconnect', self._handle_eio_disconnect)
+ self.binary = binary
+
+ self.environ = {}
+ self.handlers = {}
+ self.namespace_handlers = {}
+
+ self._binary_packet = {}
+
+ if not isinstance(logger, bool):
+ self.logger = logger
+ else:
+ self.logger = default_logger
+ if not logging.root.handlers and \
+ self.logger.level == logging.NOTSET:
+ if logger:
+ self.logger.setLevel(logging.INFO)
+ else:
+ self.logger.setLevel(logging.ERROR)
+ self.logger.addHandler(logging.StreamHandler())
+
+ if client_manager is None:
+ client_manager = base_manager.BaseManager()
+ self.manager = client_manager
+ self.manager.set_server(self)
+ self.manager_initialized = False
+
+ self.async_handlers = async_handlers
+ self.always_connect = always_connect
+
+ self.async_mode = self.eio.async_mode
+
+ def is_asyncio_based(self):
+ return False
+
+ def on(self, event, handler=None, namespace=None):
+ """Register an event handler.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param handler: The function that should be invoked to handle the
+ event. When this parameter is not given, the method
+ acts as a decorator for the handler function.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the handler is associated with
+ the default namespace.
+
+ Example usage::
+
+ # as a decorator:
+ @socket_io.on('connect', namespace='/chat')
+ def connect_handler(sid, environ):
+ print('Connection request')
+ if environ['REMOTE_ADDR'] in blacklisted:
+ return False # reject
+
+ # as a method:
+ def message_handler(sid, msg):
+ print('Received message: ', msg)
+                socket_io.send('response', to=sid, namespace='/chat')
+            socket_io.on('message', message_handler, namespace='/chat')
+
+ The handler function receives the ``sid`` (session ID) for the
+ client as first argument. The ``'connect'`` event handler receives the
+ WSGI environment as a second argument, and can return ``False`` to
+ reject the connection. The ``'message'`` handler and handlers for
+ custom event names receive the message payload as a second argument.
+ Any values returned from a message handler will be passed to the
+ client's acknowledgement callback function if it exists. The
+ ``'disconnect'`` handler does not take a second argument.
+ """
+ namespace = namespace or '/'
+
+ def set_handler(handler):
+ if namespace not in self.handlers:
+ self.handlers[namespace] = {}
+ self.handlers[namespace][event] = handler
+ return handler
+
+ if handler is None:
+ return set_handler
+ set_handler(handler)
+
+ def event(self, *args, **kwargs):
+ """Decorator to register an event handler.
+
+ This is a simplified version of the ``on()`` method that takes the
+ event name from the decorated function.
+
+ Example usage::
+
+ @sio.event
+ def my_event(data):
+ print('Received data: ', data)
+
+ The above example is equivalent to::
+
+ @sio.on('my_event')
+ def my_event(data):
+ print('Received data: ', data)
+
+ A custom namespace can be given as an argument to the decorator::
+
+ @sio.event(namespace='/test')
+ def my_event(data):
+ print('Received data: ', data)
+ """
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
+ # the decorator was invoked without arguments
+ # args[0] is the decorated function
+ return self.on(args[0].__name__)(args[0])
+ else:
+ # the decorator was invoked with arguments
+ def set_handler(handler):
+ return self.on(handler.__name__, *args, **kwargs)(handler)
+
+ return set_handler
+
+ def register_namespace(self, namespace_handler):
+ """Register a namespace handler object.
+
+ :param namespace_handler: An instance of a :class:`Namespace`
+ subclass that handles all the event traffic
+ for a namespace.
+ """
+ if not isinstance(namespace_handler, namespace.Namespace):
+ raise ValueError('Not a namespace instance')
+ if self.is_asyncio_based() != namespace_handler.is_asyncio_based():
+ raise ValueError('Not a valid namespace class for this server')
+ namespace_handler._set_server(self)
+ self.namespace_handlers[namespace_handler.namespace] = \
+ namespace_handler
+
+ def emit(self, event, data=None, to=None, room=None, skip_sid=None,
+ namespace=None, callback=None, **kwargs):
+ """Emit a custom event to one or more connected clients.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param to: The recipient of the message. This can be set to the
+ session ID of a client to address only that client, or to
+                   any custom room created by the application to address all
+                   the clients in that room. If this argument is omitted the
+                   event is broadcast to all connected clients.
+ :param room: Alias for the ``to`` parameter.
+ :param skip_sid: The session ID of a client to skip when broadcasting
+ to a room or to all clients. This can be used to
+ prevent a message from being sent to the sender. To
+ skip multiple sids, pass a list.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param callback: If given, this function will be called to acknowledge
+                         that the client has received the message. The arguments
+ that will be passed to the function are those provided
+ by the client. Callback functions can only be used
+ when addressing an individual client.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ clients directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used. It is recommended
+ to always leave this parameter with its default
+ value of ``False``.
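+
+        Example usage (illustrative; ``sio`` is a server instance and ``sid``
+        the session ID of a connected client)::
+
+            def ack(*client_args):
+                print('client acknowledged with', client_args)
+
+            sio.emit('status', {'ok': True}, to=sid, callback=ack)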
+ """
+ namespace = namespace or '/'
+ room = to or room
+ self.logger.info('emitting event "%s" to %s [%s]', event,
+ room or 'all', namespace)
+ self.manager.emit(event, data, namespace, room=room,
+ skip_sid=skip_sid, callback=callback, **kwargs)
+
+ def send(self, data, to=None, room=None, skip_sid=None, namespace=None,
+ callback=None, **kwargs):
+ """Send a message to one or more connected clients.
+
+ This function emits an event with the name ``'message'``. Use
+ :func:`emit` to issue custom event names.
+
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param to: The recipient of the message. This can be set to the
+ session ID of a client to address only that client, or to
+                   any custom room created by the application to address all
+                   the clients in that room. If this argument is omitted the
+                   event is broadcast to all connected clients.
+ :param room: Alias for the ``to`` parameter.
+ :param skip_sid: The session ID of a client to skip when broadcasting
+ to a room or to all clients. This can be used to
+ prevent a message from being sent to the sender. To
+ skip multiple sids, pass a list.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param callback: If given, this function will be called to acknowledge
+                         that the client has received the message. The arguments
+ that will be passed to the function are those provided
+ by the client. Callback functions can only be used
+ when addressing an individual client.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ clients directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used. It is recommended
+ to always leave this parameter with its default
+ value of ``False``.
+ """
+ self.emit('message', data=data, to=to, room=room, skip_sid=skip_sid,
+ namespace=namespace, callback=callback, **kwargs)
+
+ def call(self, event, data=None, to=None, sid=None, namespace=None,
+ timeout=60, **kwargs):
+ """Emit a custom event to a client and wait for the response.
+
+ :param event: The event name. It can be any string. The event names
+ ``'connect'``, ``'message'`` and ``'disconnect'`` are
+ reserved and should not be used.
+ :param data: The data to send to the client or clients. Data can be of
+ type ``str``, ``bytes``, ``list`` or ``dict``. If a
+ ``list`` or ``dict``, the data will be serialized as JSON.
+ :param to: The session ID of the recipient client.
+ :param sid: Alias for the ``to`` parameter.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the event is emitted to the
+ default namespace.
+ :param timeout: The waiting timeout. If the timeout is reached before
+ the client acknowledges the event, then a
+ ``TimeoutError`` exception is raised.
+ :param ignore_queue: Only used when a message queue is configured. If
+ set to ``True``, the event is emitted to the
+ client directly, without going through the queue.
+ This is more efficient, but only works when a
+ single server process is used. It is recommended
+ to always leave this parameter with its default
+ value of ``False``.
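+
+        Example usage (illustrative; ``sio`` is a server instance and ``sid``
+        the session ID of a connected client)::
+
+            # blocks until the client acknowledges or the timeout expires
+            status = sio.call('ping', {'seq': 1}, to=sid, timeout=5)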
+ """
+ if not self.async_handlers:
+ raise RuntimeError(
+ 'Cannot use call() when async_handlers is False.')
+ callback_event = self.eio.create_event()
+ callback_args = []
+
+ def event_callback(*args):
+ callback_args.append(args)
+ callback_event.set()
+
+ self.emit(event, data=data, room=to or sid, namespace=namespace,
+ callback=event_callback, **kwargs)
+ if not callback_event.wait(timeout=timeout):
+ raise exceptions.TimeoutError()
+ return callback_args[0] if len(callback_args[0]) > 1 \
+ else callback_args[0][0] if len(callback_args[0]) == 1 \
+ else None
+
+ def enter_room(self, sid, room, namespace=None):
+ """Enter a room.
+
+ This function adds the client to a room. The :func:`emit` and
+ :func:`send` functions can optionally broadcast events to all the
+ clients in a room.
+
+ :param sid: Session ID of the client.
+ :param room: Room name. If the room does not exist it is created.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the default namespace is used.
+ """
+ namespace = namespace or '/'
+ self.logger.info('%s is entering room %s [%s]', sid, room, namespace)
+ self.manager.enter_room(sid, namespace, room)
+
+ def leave_room(self, sid, room, namespace=None):
+ """Leave a room.
+
+ This function removes the client from a room.
+
+ :param sid: Session ID of the client.
+ :param room: Room name.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the default namespace is used.
+ """
+ namespace = namespace or '/'
+ self.logger.info('%s is leaving room %s [%s]', sid, room, namespace)
+ self.manager.leave_room(sid, namespace, room)
+
+ def close_room(self, room, namespace=None):
+ """Close a room.
+
+ This function removes all the clients from the given room.
+
+ :param room: Room name.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the default namespace is used.
+ """
+ namespace = namespace or '/'
+ self.logger.info('room %s is closing [%s]', room, namespace)
+ self.manager.close_room(room, namespace)
+
+ def rooms(self, sid, namespace=None):
+ """Return the rooms a client is in.
+
+ :param sid: Session ID of the client.
+ :param namespace: The Socket.IO namespace for the event. If this
+ argument is omitted the default namespace is used.
+ """
+ namespace = namespace or '/'
+ return self.manager.get_rooms(sid, namespace)
+
+ def get_session(self, sid, namespace=None):
+ """Return the user session for a client.
+
+ :param sid: The session id of the client.
+ :param namespace: The Socket.IO namespace. If this argument is omitted
+ the default namespace is used.
+
+ The return value is a dictionary. Modifications made to this
+ dictionary are not guaranteed to be preserved unless
+ ``save_session()`` is called, or when the ``session`` context manager
+ is used.
+ """
+ namespace = namespace or '/'
+ eio_session = self.eio.get_session(sid)
+ return eio_session.setdefault(namespace, {})
+
+ def save_session(self, sid, session, namespace=None):
+ """Store the user session for a client.
+
+ :param sid: The session id of the client.
+ :param session: The session dictionary.
+ :param namespace: The Socket.IO namespace. If this argument is omitted
+ the default namespace is used.
+ """
+ namespace = namespace or '/'
+ eio_session = self.eio.get_session(sid)
+ eio_session[namespace] = session
+
+ def session(self, sid, namespace=None):
+ """Return the user session for a client with context manager syntax.
+
+ :param sid: The session id of the client.
+
+ This is a context manager that returns the user session dictionary for
+ the client. Any changes that are made to this dictionary inside the
+ context manager block are saved back to the session. Example usage::
+
+ @sio.on('connect')
+ def on_connect(sid, environ):
+ username = authenticate_user(environ)
+ if not username:
+ return False
+ with sio.session(sid) as session:
+ session['username'] = username
+
+ @sio.on('message')
+ def on_message(sid, msg):
+ with sio.session(sid) as session:
+ print('received message from ', session['username'])
+ """
+ class _session_context_manager(object):
+ def __init__(self, server, sid, namespace):
+ self.server = server
+ self.sid = sid
+ self.namespace = namespace
+ self.session = None
+
+ def __enter__(self):
+                self.session = self.server.get_session(
+                    self.sid, namespace=self.namespace)
+ return self.session
+
+ def __exit__(self, *args):
+                self.server.save_session(self.sid, self.session,
+                                         namespace=self.namespace)
+
+ return _session_context_manager(self, sid, namespace)
+
+ def disconnect(self, sid, namespace=None):
+ """Disconnect a client.
+
+ :param sid: Session ID of the client.
+ :param namespace: The Socket.IO namespace to disconnect. If this
+ argument is omitted the default namespace is used.
+ """
+ namespace = namespace or '/'
+ if self.manager.is_connected(sid, namespace=namespace):
+ self.logger.info('Disconnecting %s [%s]', sid, namespace)
+ self.manager.pre_disconnect(sid, namespace=namespace)
+ self._send_packet(sid, packet.Packet(packet.DISCONNECT,
+ namespace=namespace))
+ self._trigger_event('disconnect', namespace, sid)
+ self.manager.disconnect(sid, namespace=namespace)
+ if namespace == '/':
+ self.eio.disconnect(sid)
+
+ def transport(self, sid):
+ """Return the name of the transport used by the client.
+
+ The two possible values returned by this function are ``'polling'``
+ and ``'websocket'``.
+
+ :param sid: The session of the client.
+ """
+ return self.eio.transport(sid)
+
+ def handle_request(self, environ, start_response):
+ """Handle an HTTP request from the client.
+
+ This is the entry point of the Socket.IO application, using the same
+ interface as a WSGI application. For the typical usage, this function
+ is invoked by the :class:`Middleware` instance, but it can be invoked
+ directly when the middleware is not used.
+
+ :param environ: The WSGI environment.
+ :param start_response: The WSGI ``start_response`` function.
+
+ This function returns the HTTP response body to deliver to the client
+ as a byte sequence.
+ """
+ return self.eio.handle_request(environ, start_response)
+
+ def start_background_task(self, target, *args, **kwargs):
+ """Start a background task using the appropriate async model.
+
+ This is a utility function that applications can use to start a
+ background task using the method that is compatible with the
+ selected async mode.
+
+ :param target: the target function to execute.
+ :param args: arguments to pass to the function.
+ :param kwargs: keyword arguments to pass to the function.
+
+ This function returns an object compatible with the `Thread` class in
+ the Python standard library. The `start()` method on this object is
+ already called by this function.
+ """
+ return self.eio.start_background_task(target, *args, **kwargs)
+
+ def sleep(self, seconds=0):
+ """Sleep for the requested amount of time using the appropriate async
+ model.
+
+ This is a utility function that applications can use to put a task to
+ sleep without having to worry about using the correct call for the
+ selected async mode.
+ """
+ return self.eio.sleep(seconds)
+
+ def _emit_internal(self, sid, event, data, namespace=None, id=None):
+ """Send a message to a client."""
+ if six.PY2 and not self.binary:
+ binary = False # pragma: nocover
+ else:
+ binary = None
+ # tuples are expanded to multiple arguments, everything else is sent
+ # as a single argument
+ if isinstance(data, tuple):
+ data = list(data)
+ else:
+ data = [data]
+ self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace,
+ data=[event] + data, id=id,
+ binary=binary))
+
+ def _send_packet(self, sid, pkt):
+ """Send a Socket.IO packet to a client."""
+ encoded_packet = pkt.encode()
+ if isinstance(encoded_packet, list):
+ binary = False
+ for ep in encoded_packet:
+ self.eio.send(sid, ep, binary=binary)
+ binary = True
+ else:
+ self.eio.send(sid, encoded_packet, binary=False)
+
+ def _handle_connect(self, sid, namespace):
+ """Handle a client connection request."""
+ namespace = namespace or '/'
+ self.manager.connect(sid, namespace)
+ if self.always_connect:
+ self._send_packet(sid, packet.Packet(packet.CONNECT,
+ namespace=namespace))
+ fail_reason = None
+ try:
+ success = self._trigger_event('connect', namespace, sid,
+ self.environ[sid])
+ except exceptions.ConnectionRefusedError as exc:
+ fail_reason = exc.error_args
+ success = False
+
+ if success is False:
+ if self.always_connect:
+ self.manager.pre_disconnect(sid, namespace)
+ self._send_packet(sid, packet.Packet(
+ packet.DISCONNECT, data=fail_reason, namespace=namespace))
+ self.manager.disconnect(sid, namespace)
+ if not self.always_connect:
+ self._send_packet(sid, packet.Packet(
+ packet.ERROR, data=fail_reason, namespace=namespace))
+ if sid in self.environ: # pragma: no cover
+ del self.environ[sid]
+ elif not self.always_connect:
+ self._send_packet(sid, packet.Packet(packet.CONNECT,
+ namespace=namespace))
+
+ def _handle_disconnect(self, sid, namespace):
+ """Handle a client disconnect."""
+ namespace = namespace or '/'
+ if namespace == '/':
+ namespace_list = list(self.manager.get_namespaces())
+ else:
+ namespace_list = [namespace]
+ for n in namespace_list:
+ if n != '/' and self.manager.is_connected(sid, n):
+ self._trigger_event('disconnect', n, sid)
+ self.manager.disconnect(sid, n)
+ if namespace == '/' and self.manager.is_connected(sid, namespace):
+ self._trigger_event('disconnect', '/', sid)
+ self.manager.disconnect(sid, '/')
+
+ def _handle_event(self, sid, namespace, id, data):
+ """Handle an incoming client event."""
+ namespace = namespace or '/'
+ self.logger.info('received event "%s" from %s [%s]', data[0], sid,
+ namespace)
+ if not self.manager.is_connected(sid, namespace):
+ self.logger.warning('%s is not connected to namespace %s',
+ sid, namespace)
+ return
+ if self.async_handlers:
+ self.start_background_task(self._handle_event_internal, self, sid,
+ data, namespace, id)
+ else:
+ self._handle_event_internal(self, sid, data, namespace, id)
+
+ def _handle_event_internal(self, server, sid, data, namespace, id):
+ r = server._trigger_event(data[0], namespace, sid, *data[1:])
+ if id is not None:
+ # send ACK packet with the response returned by the handler
+ # tuples are expanded as multiple arguments
+ if r is None:
+ data = []
+ elif isinstance(r, tuple):
+ data = list(r)
+ else:
+ data = [r]
+ if six.PY2 and not self.binary:
+ binary = False # pragma: nocover
+ else:
+ binary = None
+ server._send_packet(sid, packet.Packet(packet.ACK,
+ namespace=namespace,
+ id=id, data=data,
+ binary=binary))
+
+ def _handle_ack(self, sid, namespace, id, data):
+ """Handle ACK packets from the client."""
+ namespace = namespace or '/'
+ self.logger.info('received ack from %s [%s]', sid, namespace)
+ self.manager.trigger_callback(sid, namespace, id, data)
+
+ def _trigger_event(self, event, namespace, *args):
+ """Invoke an application event handler."""
+ # first see if we have an explicit handler for the event
+ if namespace in self.handlers and event in self.handlers[namespace]:
+ return self.handlers[namespace][event](*args)
+
+ # or else, forward the event to a namespace handler if one exists
+ elif namespace in self.namespace_handlers:
+ return self.namespace_handlers[namespace].trigger_event(
+ event, *args)
+
+ def _handle_eio_connect(self, sid, environ):
+ """Handle the Engine.IO connection event."""
+ if not self.manager_initialized:
+ self.manager_initialized = True
+ self.manager.initialize()
+ self.environ[sid] = environ
+ return self._handle_connect(sid, '/')
+
+ def _handle_eio_message(self, sid, data):
+ """Dispatch Engine.IO messages."""
+ if sid in self._binary_packet:
+ pkt = self._binary_packet[sid]
+ if pkt.add_attachment(data):
+ del self._binary_packet[sid]
+ if pkt.packet_type == packet.BINARY_EVENT:
+ self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
+ else:
+ self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
+ else:
+ pkt = packet.Packet(encoded_packet=data)
+ if pkt.packet_type == packet.CONNECT:
+ self._handle_connect(sid, pkt.namespace)
+ elif pkt.packet_type == packet.DISCONNECT:
+ self._handle_disconnect(sid, pkt.namespace)
+ elif pkt.packet_type == packet.EVENT:
+ self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.ACK:
+ self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
+ elif pkt.packet_type == packet.BINARY_EVENT or \
+ pkt.packet_type == packet.BINARY_ACK:
+ self._binary_packet[sid] = pkt
+ elif pkt.packet_type == packet.ERROR:
+ raise ValueError('Unexpected ERROR packet.')
+ else:
+ raise ValueError('Unknown packet type.')
+
+ def _handle_eio_disconnect(self, sid):
+ """Handle Engine.IO disconnect event."""
+ self._handle_disconnect(sid, '/')
+ if sid in self.environ:
+ del self.environ[sid]
+
+ def _engineio_server_class(self):
+ return engineio.Server
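+
+# Illustrative sketch (not part of the upstream module): the dispatch in
+# _trigger_event and the ACK logic in _handle_event_internal are what make
+# handlers like the following work; the event name and payload are made up::
+#
+#     import socketio
+#
+#     sio = socketio.Server(async_handlers=False)
+#
+#     @sio.on('my_event')
+#     def my_event(sid, data):
+#         # a returned tuple is expanded into multiple ACK arguments
+#         return 'ok', data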
diff --git a/libs/socketio/tornado.py b/libs/socketio/tornado.py
new file mode 100644
index 000000000..5b2e6f684
--- /dev/null
+++ b/libs/socketio/tornado.py
@@ -0,0 +1,11 @@
+import sys
+if sys.version_info >= (3, 5):
+    try:
+        from engineio.async_drivers.tornado import get_tornado_handler as \
+            get_engineio_handler
+    except ImportError:  # pragma: no cover
+        get_engineio_handler = None
+else:  # pragma: no cover
+    # define the name on older interpreters too, so that calling
+    # get_tornado_handler() fails cleanly instead of raising a NameError
+    get_engineio_handler = None
+
+
+def get_tornado_handler(socketio_server): # pragma: no cover
+ return get_engineio_handler(socketio_server.eio)
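+
+# Illustrative sketch (not part of the upstream module): under Tornado the
+# handler built above is mounted on the socket.io endpoint; the application
+# layout below is an assumption for the example::
+#
+#     import socketio
+#     import tornado.web
+#
+#     sio = socketio.AsyncServer(async_mode='tornado')
+#     app = tornado.web.Application([
+#         (r'/socket.io/', socketio.get_tornado_handler(sio)),
+#     ])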
diff --git a/libs/socketio/zmq_manager.py b/libs/socketio/zmq_manager.py
new file mode 100644
index 000000000..f2a2ae5dc
--- /dev/null
+++ b/libs/socketio/zmq_manager.py
@@ -0,0 +1,111 @@
+import pickle
+import re
+
+try:
+ import eventlet.green.zmq as zmq
+except ImportError:
+ zmq = None
+import six
+
+from .pubsub_manager import PubSubManager
+
+
+class ZmqManager(PubSubManager): # pragma: no cover
+ """zmq based client manager.
+
+    NOTE: this zmq implementation should be considered experimental at this
+    time, and eventlet is currently required to use it.
+
+ This class implements a zmq backend for event sharing across multiple
+ processes. To use a zmq backend, initialize the :class:`Server` instance as
+ follows::
+
+ url = 'zmq+tcp://hostname:port1+port2'
+ server = socketio.Server(client_manager=socketio.ZmqManager(url))
+
+ :param url: The connection URL for the zmq message broker,
+ which will need to be provided and running.
+ :param channel: The channel name on which the server sends and receives
+ notifications. Must be the same in all the servers.
+ :param write_only: If set to ``True``, only initialize to emit events. The
+ default of ``False`` initializes the class for emitting
+ and receiving.
+
+    A zmq message broker must be running for the zmq_manager to work. You
+    can write your own, or adapt the simple broker below::
+
+ import zmq
+
+ receiver = zmq.Context().socket(zmq.PULL)
+ receiver.bind("tcp://*:5555")
+
+ publisher = zmq.Context().socket(zmq.PUB)
+ publisher.bind("tcp://*:5556")
+
+ while True:
+ publisher.send(receiver.recv())
+ """
+ name = 'zmq'
+
+ def __init__(self, url='zmq+tcp://localhost:5555+5556',
+ channel='socketio',
+ write_only=False,
+ logger=None):
+ if zmq is None:
+ raise RuntimeError('zmq package is not installed '
+ '(Run "pip install pyzmq" in your '
+ 'virtualenv).')
+
+ r = re.compile(r':\d+\+\d+$')
+ if not (url.startswith('zmq+tcp://') and r.search(url)):
+ raise RuntimeError('unexpected connection string: ' + url)
+
+ url = url.replace('zmq+', '')
+ (sink_url, sub_port) = url.split('+')
+ sink_port = sink_url.split(':')[-1]
+ sub_url = sink_url.replace(sink_port, sub_port)
+
+ sink = zmq.Context().socket(zmq.PUSH)
+ sink.connect(sink_url)
+
+ sub = zmq.Context().socket(zmq.SUB)
+ sub.setsockopt_string(zmq.SUBSCRIBE, u'')
+ sub.connect(sub_url)
+
+ self.sink = sink
+ self.sub = sub
+ self.channel = channel
+ super(ZmqManager, self).__init__(channel=channel,
+ write_only=write_only,
+ logger=logger)
+
+ def _publish(self, data):
+ pickled_data = pickle.dumps(
+ {
+ 'type': 'message',
+ 'channel': self.channel,
+ 'data': data
+ }
+ )
+ return self.sink.send(pickled_data)
+
+ def zmq_listen(self):
+ while True:
+ response = self.sub.recv()
+ if response is not None:
+ yield response
+
+ def _listen(self):
+ for message in self.zmq_listen():
+ if isinstance(message, six.binary_type):
+ try:
+ message = pickle.loads(message)
+ except Exception:
+ pass
+ if isinstance(message, dict) and \
+ message['type'] == 'message' and \
+ message['channel'] == self.channel and \
+ 'data' in message:
+ yield message['data']
+ return
diff --git a/libs/yaml/__init__.py b/libs/yaml/__init__.py
new file mode 100644
index 000000000..5df0bb5fd
--- /dev/null
+++ b/libs/yaml/__init__.py
@@ -0,0 +1,402 @@
+
+from .error import *
+
+from .tokens import *
+from .events import *
+from .nodes import *
+
+from .loader import *
+from .dumper import *
+
+__version__ = '5.1'
+try:
+ from .cyaml import *
+ __with_libyaml__ = True
+except ImportError:
+ __with_libyaml__ = False
+
+import io
+
+#------------------------------------------------------------------------------
+# Warnings control
+#------------------------------------------------------------------------------
+
+# 'Global' warnings state:
+_warnings_enabled = {
+ 'YAMLLoadWarning': True,
+}
+
+# Get or set the global warnings state
+def warnings(settings=None):
+ if settings is None:
+ return _warnings_enabled
+
+ if type(settings) is dict:
+ for key in settings:
+ if key in _warnings_enabled:
+ _warnings_enabled[key] = settings[key]
+
+# Warn when load() is called without Loader=...
+class YAMLLoadWarning(RuntimeWarning):
+ pass
+
+def load_warning(method):
+ if _warnings_enabled['YAMLLoadWarning'] is False:
+ return
+
+ import warnings
+
+ message = (
+ "calling yaml.%s() without Loader=... is deprecated, as the "
+ "default Loader is unsafe. Please read "
+ "https://msg.pyyaml.org/load for full details."
+ ) % method
+
+ warnings.warn(message, YAMLLoadWarning, stacklevel=3)
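+
+# Illustrative sketch: the deprecation warning above can be silenced through
+# the warnings() toggle defined in this module::
+#
+#     import yaml
+#
+#     yaml.warnings({'YAMLLoadWarning': False})
+#     yaml.load('a: 1')   # no YAMLLoadWarning is emitted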
+
+#------------------------------------------------------------------------------
+def scan(stream, Loader=Loader):
+ """
+ Scan a YAML stream and produce scanning tokens.
+ """
+ loader = Loader(stream)
+ try:
+ while loader.check_token():
+ yield loader.get_token()
+ finally:
+ loader.dispose()
+
+def parse(stream, Loader=Loader):
+ """
+ Parse a YAML stream and produce parsing events.
+ """
+ loader = Loader(stream)
+ try:
+ while loader.check_event():
+ yield loader.get_event()
+ finally:
+ loader.dispose()
+
+def compose(stream, Loader=Loader):
+ """
+ Parse the first YAML document in a stream
+ and produce the corresponding representation tree.
+ """
+ loader = Loader(stream)
+ try:
+ return loader.get_single_node()
+ finally:
+ loader.dispose()
+
+def compose_all(stream, Loader=Loader):
+ """
+ Parse all YAML documents in a stream
+ and produce corresponding representation trees.
+ """
+ loader = Loader(stream)
+ try:
+ while loader.check_node():
+ yield loader.get_node()
+ finally:
+ loader.dispose()
+
+def load(stream, Loader=None):
+ """
+ Parse the first YAML document in a stream
+ and produce the corresponding Python object.
+ """
+ if Loader is None:
+ load_warning('load')
+ Loader = FullLoader
+
+ loader = Loader(stream)
+ try:
+ return loader.get_single_data()
+ finally:
+ loader.dispose()
+
+def load_all(stream, Loader=None):
+ """
+ Parse all YAML documents in a stream
+ and produce corresponding Python objects.
+ """
+ if Loader is None:
+ load_warning('load_all')
+ Loader = FullLoader
+
+ loader = Loader(stream)
+ try:
+ while loader.check_data():
+ yield loader.get_data()
+ finally:
+ loader.dispose()
+
+def full_load(stream):
+ """
+ Parse the first YAML document in a stream
+ and produce the corresponding Python object.
+
+ Resolve all tags except those known to be
+ unsafe on untrusted input.
+ """
+ return load(stream, FullLoader)
+
+def full_load_all(stream):
+ """
+ Parse all YAML documents in a stream
+ and produce corresponding Python objects.
+
+ Resolve all tags except those known to be
+ unsafe on untrusted input.
+ """
+ return load_all(stream, FullLoader)
+
+def safe_load(stream):
+ """
+ Parse the first YAML document in a stream
+ and produce the corresponding Python object.
+
+ Resolve only basic YAML tags. This is known
+ to be safe for untrusted input.
+ """
+ return load(stream, SafeLoader)
+
+def safe_load_all(stream):
+ """
+ Parse all YAML documents in a stream
+ and produce corresponding Python objects.
+
+ Resolve only basic YAML tags. This is known
+ to be safe for untrusted input.
+ """
+ return load_all(stream, SafeLoader)
+
+def unsafe_load(stream):
+ """
+ Parse the first YAML document in a stream
+ and produce the corresponding Python object.
+
+ Resolve all tags, even those known to be
+ unsafe on untrusted input.
+ """
+ return load(stream, UnsafeLoader)
+
+def unsafe_load_all(stream):
+ """
+ Parse all YAML documents in a stream
+ and produce corresponding Python objects.
+
+ Resolve all tags, even those known to be
+ unsafe on untrusted input.
+ """
+ return load_all(stream, UnsafeLoader)
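+
+# Illustrative sketch of the difference between the loader families above,
+# using throwaway documents::
+#
+#     import yaml
+#
+#     yaml.safe_load('a: 1')                       # {'a': 1}
+#     yaml.safe_load('!!python/tuple [1, 2]')      # raises ConstructorError
+#     yaml.unsafe_load('!!python/tuple [1, 2]')    # (1, 2)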
+
+def emit(events, stream=None, Dumper=Dumper,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None):
+ """
+ Emit YAML parsing events into a stream.
+ If stream is None, return the produced string instead.
+ """
+ getvalue = None
+ if stream is None:
+ stream = io.StringIO()
+ getvalue = stream.getvalue
+ dumper = Dumper(stream, canonical=canonical, indent=indent, width=width,
+ allow_unicode=allow_unicode, line_break=line_break)
+ try:
+ for event in events:
+ dumper.emit(event)
+ finally:
+ dumper.dispose()
+ if getvalue:
+ return getvalue()
+
+def serialize_all(nodes, stream=None, Dumper=Dumper,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None):
+ """
+ Serialize a sequence of representation trees into a YAML stream.
+ If stream is None, return the produced string instead.
+ """
+ getvalue = None
+ if stream is None:
+ if encoding is None:
+ stream = io.StringIO()
+ else:
+ stream = io.BytesIO()
+ getvalue = stream.getvalue
+ dumper = Dumper(stream, canonical=canonical, indent=indent, width=width,
+ allow_unicode=allow_unicode, line_break=line_break,
+ encoding=encoding, version=version, tags=tags,
+ explicit_start=explicit_start, explicit_end=explicit_end)
+ try:
+ dumper.open()
+ for node in nodes:
+ dumper.serialize(node)
+ dumper.close()
+ finally:
+ dumper.dispose()
+ if getvalue:
+ return getvalue()
+
+def serialize(node, stream=None, Dumper=Dumper, **kwds):
+ """
+ Serialize a representation tree into a YAML stream.
+ If stream is None, return the produced string instead.
+ """
+ return serialize_all([node], stream, Dumper=Dumper, **kwds)
+
+def dump_all(documents, stream=None, Dumper=Dumper,
+ default_style=None, default_flow_style=False,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None, sort_keys=True):
+ """
+ Serialize a sequence of Python objects into a YAML stream.
+ If stream is None, return the produced string instead.
+ """
+ getvalue = None
+ if stream is None:
+ if encoding is None:
+ stream = io.StringIO()
+ else:
+ stream = io.BytesIO()
+ getvalue = stream.getvalue
+ dumper = Dumper(stream, default_style=default_style,
+ default_flow_style=default_flow_style,
+ canonical=canonical, indent=indent, width=width,
+ allow_unicode=allow_unicode, line_break=line_break,
+ encoding=encoding, version=version, tags=tags,
+ explicit_start=explicit_start, explicit_end=explicit_end, sort_keys=sort_keys)
+ try:
+ dumper.open()
+ for data in documents:
+ dumper.represent(data)
+ dumper.close()
+ finally:
+ dumper.dispose()
+ if getvalue:
+ return getvalue()
+
+def dump(data, stream=None, Dumper=Dumper, **kwds):
+ """
+ Serialize a Python object into a YAML stream.
+ If stream is None, return the produced string instead.
+ """
+ return dump_all([data], stream, Dumper=Dumper, **kwds)
+
+def safe_dump_all(documents, stream=None, **kwds):
+ """
+ Serialize a sequence of Python objects into a YAML stream.
+ Produce only basic YAML tags.
+ If stream is None, return the produced string instead.
+ """
+ return dump_all(documents, stream, Dumper=SafeDumper, **kwds)
+
+def safe_dump(data, stream=None, **kwds):
+ """
+ Serialize a Python object into a YAML stream.
+ Produce only basic YAML tags.
+ If stream is None, return the produced string instead.
+ """
+ return dump_all([data], stream, Dumper=SafeDumper, **kwds)
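+
+# Illustrative sketch: dump() and safe_dump() return the produced string when
+# no stream is given::
+#
+#     import yaml
+#
+#     yaml.safe_dump({'a': 1, 'b': [2, 3]})   # 'a: 1\nb:\n- 2\n- 3\n'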
+
+def add_implicit_resolver(tag, regexp, first=None,
+ Loader=Loader, Dumper=Dumper):
+ """
+ Add an implicit scalar detector.
+ If an implicit scalar value matches the given regexp,
+ the corresponding tag is assigned to the scalar.
+ first is a sequence of possible initial characters or None.
+ """
+ Loader.add_implicit_resolver(tag, regexp, first)
+ Dumper.add_implicit_resolver(tag, regexp, first)
+
+def add_path_resolver(tag, path, kind=None, Loader=Loader, Dumper=Dumper):
+ """
+ Add a path based resolver for the given tag.
+ A path is a list of keys that forms a path
+ to a node in the representation tree.
+ Keys can be string values, integers, or None.
+ """
+ Loader.add_path_resolver(tag, path, kind)
+ Dumper.add_path_resolver(tag, path, kind)
+
+def add_constructor(tag, constructor, Loader=Loader):
+ """
+ Add a constructor for the given tag.
+ Constructor is a function that accepts a Loader instance
+ and a node object and produces the corresponding Python object.
+ """
+ Loader.add_constructor(tag, constructor)
+
+def add_multi_constructor(tag_prefix, multi_constructor, Loader=Loader):
+ """
+ Add a multi-constructor for the given tag prefix.
+ Multi-constructor is called for a node if its tag starts with tag_prefix.
+ Multi-constructor accepts a Loader instance, a tag suffix,
+ and a node object and produces the corresponding Python object.
+ """
+ Loader.add_multi_constructor(tag_prefix, multi_constructor)
+
+def add_representer(data_type, representer, Dumper=Dumper):
+ """
+ Add a representer for the given type.
+ Representer is a function accepting a Dumper instance
+ and an instance of the given data type
+ and producing the corresponding representation node.
+ """
+ Dumper.add_representer(data_type, representer)
+
+def add_multi_representer(data_type, multi_representer, Dumper=Dumper):
+ """
+ Add a representer for the given type.
+ Multi-representer is a function accepting a Dumper instance
+ and an instance of the given data type or subtype
+ and producing the corresponding representation node.
+ """
+ Dumper.add_multi_representer(data_type, multi_representer)
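+
+# Illustrative sketch: a matched constructor/representer pair for a
+# hypothetical '!point' tag (the tag and class are made up)::
+#
+#     import yaml
+#
+#     class Point:
+#         def __init__(self, x, y):
+#             self.x, self.y = x, y
+#
+#     yaml.add_constructor(
+#         '!point',
+#         lambda loader, node: Point(*loader.construct_sequence(node)))
+#     yaml.add_representer(
+#         Point,
+#         lambda dumper, p: dumper.represent_sequence('!point', [p.x, p.y]))
+#
+#     yaml.load('!point [1, 2]', Loader=yaml.Loader)   # a Point instance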
+
+class YAMLObjectMetaclass(type):
+ """
+ The metaclass for YAMLObject.
+ """
+ def __init__(cls, name, bases, kwds):
+ super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds)
+ if 'yaml_tag' in kwds and kwds['yaml_tag'] is not None:
+ cls.yaml_loader.add_constructor(cls.yaml_tag, cls.from_yaml)
+ cls.yaml_dumper.add_representer(cls, cls.to_yaml)
+
+class YAMLObject(metaclass=YAMLObjectMetaclass):
+ """
+ An object that can dump itself to a YAML stream
+ and load itself from a YAML stream.
+ """
+
+ __slots__ = () # no direct instantiation, so allow immutable subclasses
+
+ yaml_loader = Loader
+ yaml_dumper = Dumper
+
+ yaml_tag = None
+ yaml_flow_style = None
+
+ @classmethod
+ def from_yaml(cls, loader, node):
+ """
+ Convert a representation node to a Python object.
+ """
+ return loader.construct_yaml_object(node, cls)
+
+ @classmethod
+ def to_yaml(cls, dumper, data):
+ """
+ Convert a Python object to a representation node.
+ """
+ return dumper.represent_yaml_object(cls.yaml_tag, data, cls,
+ flow_style=cls.yaml_flow_style)
+
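+# Illustrative sketch: a minimal YAMLObject subclass (the class name and tag
+# are made up); the metaclass above registers the tag automatically::
+#
+#     import yaml
+#
+#     class Monster(yaml.YAMLObject):
+#         yaml_tag = '!Monster'
+#
+#         def __init__(self, name, hp):
+#             self.name = name
+#             self.hp = hp
+#
+#     yaml.load('!Monster {name: dragon, hp: 16}', Loader=yaml.Loader)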
diff --git a/libs/yaml/composer.py b/libs/yaml/composer.py
new file mode 100644
index 000000000..6d15cb40e
--- /dev/null
+++ b/libs/yaml/composer.py
@@ -0,0 +1,139 @@
+
+__all__ = ['Composer', 'ComposerError']
+
+from .error import MarkedYAMLError
+from .events import *
+from .nodes import *
+
+class ComposerError(MarkedYAMLError):
+ pass
+
+class Composer:
+
+ def __init__(self):
+ self.anchors = {}
+
+ def check_node(self):
+ # Drop the STREAM-START event.
+ if self.check_event(StreamStartEvent):
+ self.get_event()
+
+        # Check whether more documents are available.
+ return not self.check_event(StreamEndEvent)
+
+ def get_node(self):
+ # Get the root node of the next document.
+ if not self.check_event(StreamEndEvent):
+ return self.compose_document()
+
+ def get_single_node(self):
+ # Drop the STREAM-START event.
+ self.get_event()
+
+ # Compose a document if the stream is not empty.
+ document = None
+ if not self.check_event(StreamEndEvent):
+ document = self.compose_document()
+
+ # Ensure that the stream contains no more documents.
+ if not self.check_event(StreamEndEvent):
+ event = self.get_event()
+ raise ComposerError("expected a single document in the stream",
+ document.start_mark, "but found another document",
+ event.start_mark)
+
+ # Drop the STREAM-END event.
+ self.get_event()
+
+ return document
+
+ def compose_document(self):
+ # Drop the DOCUMENT-START event.
+ self.get_event()
+
+ # Compose the root node.
+ node = self.compose_node(None, None)
+
+ # Drop the DOCUMENT-END event.
+ self.get_event()
+
+ self.anchors = {}
+ return node
+
+ def compose_node(self, parent, index):
+ if self.check_event(AliasEvent):
+ event = self.get_event()
+ anchor = event.anchor
+ if anchor not in self.anchors:
+ raise ComposerError(None, None, "found undefined alias %r"
+ % anchor, event.start_mark)
+ return self.anchors[anchor]
+ event = self.peek_event()
+ anchor = event.anchor
+ if anchor is not None:
+ if anchor in self.anchors:
+ raise ComposerError("found duplicate anchor %r; first occurrence"
+ % anchor, self.anchors[anchor].start_mark,
+ "second occurrence", event.start_mark)
+ self.descend_resolver(parent, index)
+ if self.check_event(ScalarEvent):
+ node = self.compose_scalar_node(anchor)
+ elif self.check_event(SequenceStartEvent):
+ node = self.compose_sequence_node(anchor)
+ elif self.check_event(MappingStartEvent):
+ node = self.compose_mapping_node(anchor)
+ self.ascend_resolver()
+ return node
+
+ def compose_scalar_node(self, anchor):
+ event = self.get_event()
+ tag = event.tag
+ if tag is None or tag == '!':
+ tag = self.resolve(ScalarNode, event.value, event.implicit)
+ node = ScalarNode(tag, event.value,
+ event.start_mark, event.end_mark, style=event.style)
+ if anchor is not None:
+ self.anchors[anchor] = node
+ return node
+
+ def compose_sequence_node(self, anchor):
+ start_event = self.get_event()
+ tag = start_event.tag
+ if tag is None or tag == '!':
+ tag = self.resolve(SequenceNode, None, start_event.implicit)
+ node = SequenceNode(tag, [],
+ start_event.start_mark, None,
+ flow_style=start_event.flow_style)
+ if anchor is not None:
+ self.anchors[anchor] = node
+ index = 0
+ while not self.check_event(SequenceEndEvent):
+ node.value.append(self.compose_node(node, index))
+ index += 1
+ end_event = self.get_event()
+ node.end_mark = end_event.end_mark
+ return node
+
+ def compose_mapping_node(self, anchor):
+ start_event = self.get_event()
+ tag = start_event.tag
+ if tag is None or tag == '!':
+ tag = self.resolve(MappingNode, None, start_event.implicit)
+ node = MappingNode(tag, [],
+ start_event.start_mark, None,
+ flow_style=start_event.flow_style)
+ if anchor is not None:
+ self.anchors[anchor] = node
+ while not self.check_event(MappingEndEvent):
+ #key_event = self.peek_event()
+ item_key = self.compose_node(node, None)
+ #if item_key in node.value:
+ # raise ComposerError("while composing a mapping", start_event.start_mark,
+ # "found duplicate key", key_event.start_mark)
+ item_value = self.compose_node(node, item_key)
+ #node.value[item_key] = item_value
+ node.value.append((item_key, item_value))
+ end_event = self.get_event()
+ node.end_mark = end_event.end_mark
+ return node
+
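+# Illustrative sketch: the composer is what yaml.compose() drives; it stops
+# at a node tree instead of native Python objects::
+#
+#     import yaml
+#
+#     node = yaml.compose('a: [1, 2]')
+#     node.tag     # 'tag:yaml.org,2002:map'
+#     node.value   # [(ScalarNode for 'a', SequenceNode for [1, 2])]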
diff --git a/libs/yaml/constructor.py b/libs/yaml/constructor.py
new file mode 100644
index 000000000..34fc1ae92
--- /dev/null
+++ b/libs/yaml/constructor.py
@@ -0,0 +1,720 @@
+
+__all__ = [
+ 'BaseConstructor',
+ 'SafeConstructor',
+ 'FullConstructor',
+ 'UnsafeConstructor',
+ 'Constructor',
+ 'ConstructorError'
+]
+
+from .error import *
+from .nodes import *
+
+import collections.abc, datetime, base64, binascii, re, sys, types
+
+class ConstructorError(MarkedYAMLError):
+ pass
+
+class BaseConstructor:
+
+ yaml_constructors = {}
+ yaml_multi_constructors = {}
+
+ def __init__(self):
+ self.constructed_objects = {}
+ self.recursive_objects = {}
+ self.state_generators = []
+ self.deep_construct = False
+
+ def check_data(self):
+        # Check whether more documents are available.
+ return self.check_node()
+
+ def get_data(self):
+ # Construct and return the next document.
+ if self.check_node():
+ return self.construct_document(self.get_node())
+
+ def get_single_data(self):
+ # Ensure that the stream contains a single document and construct it.
+ node = self.get_single_node()
+ if node is not None:
+ return self.construct_document(node)
+ return None
+
+ def construct_document(self, node):
+ data = self.construct_object(node)
+ while self.state_generators:
+ state_generators = self.state_generators
+ self.state_generators = []
+ for generator in state_generators:
+ for dummy in generator:
+ pass
+ self.constructed_objects = {}
+ self.recursive_objects = {}
+ self.deep_construct = False
+ return data
+
+ def construct_object(self, node, deep=False):
+ if node in self.constructed_objects:
+ return self.constructed_objects[node]
+ if deep:
+ old_deep = self.deep_construct
+ self.deep_construct = True
+ if node in self.recursive_objects:
+ raise ConstructorError(None, None,
+ "found unconstructable recursive node", node.start_mark)
+ self.recursive_objects[node] = None
+ constructor = None
+ tag_suffix = None
+ if node.tag in self.yaml_constructors:
+ constructor = self.yaml_constructors[node.tag]
+ else:
+ for tag_prefix in self.yaml_multi_constructors:
+ if node.tag.startswith(tag_prefix):
+ tag_suffix = node.tag[len(tag_prefix):]
+ constructor = self.yaml_multi_constructors[tag_prefix]
+ break
+ else:
+ if None in self.yaml_multi_constructors:
+ tag_suffix = node.tag
+ constructor = self.yaml_multi_constructors[None]
+ elif None in self.yaml_constructors:
+ constructor = self.yaml_constructors[None]
+ elif isinstance(node, ScalarNode):
+ constructor = self.__class__.construct_scalar
+ elif isinstance(node, SequenceNode):
+ constructor = self.__class__.construct_sequence
+ elif isinstance(node, MappingNode):
+ constructor = self.__class__.construct_mapping
+ if tag_suffix is None:
+ data = constructor(self, node)
+ else:
+ data = constructor(self, tag_suffix, node)
+ if isinstance(data, types.GeneratorType):
+ generator = data
+ data = next(generator)
+ if self.deep_construct:
+ for dummy in generator:
+ pass
+ else:
+ self.state_generators.append(generator)
+ self.constructed_objects[node] = data
+ del self.recursive_objects[node]
+ if deep:
+ self.deep_construct = old_deep
+ return data
+
+ def construct_scalar(self, node):
+ if not isinstance(node, ScalarNode):
+ raise ConstructorError(None, None,
+ "expected a scalar node, but found %s" % node.id,
+ node.start_mark)
+ return node.value
+
+ def construct_sequence(self, node, deep=False):
+ if not isinstance(node, SequenceNode):
+ raise ConstructorError(None, None,
+ "expected a sequence node, but found %s" % node.id,
+ node.start_mark)
+ return [self.construct_object(child, deep=deep)
+ for child in node.value]
+
+ def construct_mapping(self, node, deep=False):
+ if not isinstance(node, MappingNode):
+ raise ConstructorError(None, None,
+ "expected a mapping node, but found %s" % node.id,
+ node.start_mark)
+ mapping = {}
+ for key_node, value_node in node.value:
+ key = self.construct_object(key_node, deep=deep)
+ if not isinstance(key, collections.abc.Hashable):
+ raise ConstructorError("while constructing a mapping", node.start_mark,
+ "found unhashable key", key_node.start_mark)
+ value = self.construct_object(value_node, deep=deep)
+ mapping[key] = value
+ return mapping
+
+ def construct_pairs(self, node, deep=False):
+ if not isinstance(node, MappingNode):
+ raise ConstructorError(None, None,
+ "expected a mapping node, but found %s" % node.id,
+ node.start_mark)
+ pairs = []
+ for key_node, value_node in node.value:
+ key = self.construct_object(key_node, deep=deep)
+ value = self.construct_object(value_node, deep=deep)
+ pairs.append((key, value))
+ return pairs
+
+ @classmethod
+ def add_constructor(cls, tag, constructor):
+        if 'yaml_constructors' not in cls.__dict__:
+ cls.yaml_constructors = cls.yaml_constructors.copy()
+ cls.yaml_constructors[tag] = constructor
+
+ @classmethod
+ def add_multi_constructor(cls, tag_prefix, multi_constructor):
+        if 'yaml_multi_constructors' not in cls.__dict__:
+ cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
+ cls.yaml_multi_constructors[tag_prefix] = multi_constructor
+
+class SafeConstructor(BaseConstructor):
+
+ def construct_scalar(self, node):
+ if isinstance(node, MappingNode):
+ for key_node, value_node in node.value:
+ if key_node.tag == 'tag:yaml.org,2002:value':
+ return self.construct_scalar(value_node)
+ return super().construct_scalar(node)
+
+ def flatten_mapping(self, node):
+ merge = []
+ index = 0
+ while index < len(node.value):
+ key_node, value_node = node.value[index]
+ if key_node.tag == 'tag:yaml.org,2002:merge':
+ del node.value[index]
+ if isinstance(value_node, MappingNode):
+ self.flatten_mapping(value_node)
+ merge.extend(value_node.value)
+ elif isinstance(value_node, SequenceNode):
+ submerge = []
+ for subnode in value_node.value:
+ if not isinstance(subnode, MappingNode):
+ raise ConstructorError("while constructing a mapping",
+ node.start_mark,
+ "expected a mapping for merging, but found %s"
+ % subnode.id, subnode.start_mark)
+ self.flatten_mapping(subnode)
+ submerge.append(subnode.value)
+ submerge.reverse()
+ for value in submerge:
+ merge.extend(value)
+ else:
+ raise ConstructorError("while constructing a mapping", node.start_mark,
+ "expected a mapping or list of mappings for merging, but found %s"
+ % value_node.id, value_node.start_mark)
+ elif key_node.tag == 'tag:yaml.org,2002:value':
+ key_node.tag = 'tag:yaml.org,2002:str'
+ index += 1
+ else:
+ index += 1
+ if merge:
+ node.value = merge + node.value
+
+ def construct_mapping(self, node, deep=False):
+ if isinstance(node, MappingNode):
+ self.flatten_mapping(node)
+ return super().construct_mapping(node, deep=deep)
+
+ def construct_yaml_null(self, node):
+ self.construct_scalar(node)
+ return None
+
+ bool_values = {
+ 'yes': True,
+ 'no': False,
+ 'true': True,
+ 'false': False,
+ 'on': True,
+ 'off': False,
+ }
+
+ def construct_yaml_bool(self, node):
+ value = self.construct_scalar(node)
+ return self.bool_values[value.lower()]
+
+ def construct_yaml_int(self, node):
+ value = self.construct_scalar(node)
+ value = value.replace('_', '')
+ sign = +1
+ if value[0] == '-':
+ sign = -1
+ if value[0] in '+-':
+ value = value[1:]
+ if value == '0':
+ return 0
+ elif value.startswith('0b'):
+ return sign*int(value[2:], 2)
+ elif value.startswith('0x'):
+ return sign*int(value[2:], 16)
+ elif value[0] == '0':
+ return sign*int(value, 8)
+ elif ':' in value:
+ digits = [int(part) for part in value.split(':')]
+ digits.reverse()
+ base = 1
+ value = 0
+ for digit in digits:
+ value += digit*base
+ base *= 60
+ return sign*value
+ else:
+ return sign*int(value)
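+
+    # The ':'-separated branch above implements YAML 1.1 sexagesimal
+    # integers, e.g. '1:30' parses as 1*60 + 30 == 90.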
+
+ inf_value = 1e300
+ while inf_value != inf_value*inf_value:
+ inf_value *= inf_value
+ nan_value = -inf_value/inf_value # Trying to make a quiet NaN (like C99).
+
+ def construct_yaml_float(self, node):
+ value = self.construct_scalar(node)
+ value = value.replace('_', '').lower()
+ sign = +1
+ if value[0] == '-':
+ sign = -1
+ if value[0] in '+-':
+ value = value[1:]
+ if value == '.inf':
+ return sign*self.inf_value
+ elif value == '.nan':
+ return self.nan_value
+ elif ':' in value:
+ digits = [float(part) for part in value.split(':')]
+ digits.reverse()
+ base = 1
+ value = 0.0
+ for digit in digits:
+ value += digit*base
+ base *= 60
+ return sign*value
+ else:
+ return sign*float(value)
+
+ def construct_yaml_binary(self, node):
+ try:
+ value = self.construct_scalar(node).encode('ascii')
+ except UnicodeEncodeError as exc:
+ raise ConstructorError(None, None,
+ "failed to convert base64 data into ascii: %s" % exc,
+ node.start_mark)
+ try:
+ if hasattr(base64, 'decodebytes'):
+ return base64.decodebytes(value)
+ else:
+ return base64.decodestring(value)
+ except binascii.Error as exc:
+ raise ConstructorError(None, None,
+ "failed to decode base64 data: %s" % exc, node.start_mark)
+
+ timestamp_regexp = re.compile(
+ r'''^(?P<year>[0-9][0-9][0-9][0-9])
+ -(?P<month>[0-9][0-9]?)
+ -(?P<day>[0-9][0-9]?)
+ (?:(?:[Tt]|[ \t]+)
+ (?P<hour>[0-9][0-9]?)
+ :(?P<minute>[0-9][0-9])
+ :(?P<second>[0-9][0-9])
+ (?:\.(?P<fraction>[0-9]*))?
+ (?:[ \t]*(?P<tz>Z|(?P<tz_sign>[-+])(?P<tz_hour>[0-9][0-9]?)
+ (?::(?P<tz_minute>[0-9][0-9]))?))?)?$''', re.X)
+
+ def construct_yaml_timestamp(self, node):
+        value = self.construct_scalar(node)
+        match = self.timestamp_regexp.match(value)
+ values = match.groupdict()
+ year = int(values['year'])
+ month = int(values['month'])
+ day = int(values['day'])
+ if not values['hour']:
+ return datetime.date(year, month, day)
+ hour = int(values['hour'])
+ minute = int(values['minute'])
+ second = int(values['second'])
+ fraction = 0
+ if values['fraction']:
+ fraction = values['fraction'][:6]
+ while len(fraction) < 6:
+ fraction += '0'
+ fraction = int(fraction)
+ delta = None
+ if values['tz_sign']:
+ tz_hour = int(values['tz_hour'])
+ tz_minute = int(values['tz_minute'] or 0)
+ delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
+ if values['tz_sign'] == '-':
+ delta = -delta
+ data = datetime.datetime(year, month, day, hour, minute, second, fraction)
+ if delta:
+ data -= delta
+ return data
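+
+    # For illustration: a zone-aware stamp such as '2001-12-14 21:59:43.10 -5'
+    # yields the naive UTC datetime(2001, 12, 15, 2, 59, 43, 100000).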
+
+ def construct_yaml_omap(self, node):
+ # Note: we do not check for duplicate keys, because it's too
+ # CPU-expensive.
+ omap = []
+ yield omap
+ if not isinstance(node, SequenceNode):
+ raise ConstructorError("while constructing an ordered map", node.start_mark,
+ "expected a sequence, but found %s" % node.id, node.start_mark)
+ for subnode in node.value:
+ if not isinstance(subnode, MappingNode):
+ raise ConstructorError("while constructing an ordered map", node.start_mark,
+ "expected a mapping of length 1, but found %s" % subnode.id,
+ subnode.start_mark)
+ if len(subnode.value) != 1:
+ raise ConstructorError("while constructing an ordered map", node.start_mark,
+ "expected a single mapping item, but found %d items" % len(subnode.value),
+ subnode.start_mark)
+ key_node, value_node = subnode.value[0]
+ key = self.construct_object(key_node)
+ value = self.construct_object(value_node)
+ omap.append((key, value))
+
+ def construct_yaml_pairs(self, node):
+ # Note: the same code as `construct_yaml_omap`.
+ pairs = []
+ yield pairs
+ if not isinstance(node, SequenceNode):
+ raise ConstructorError("while constructing pairs", node.start_mark,
+ "expected a sequence, but found %s" % node.id, node.start_mark)
+ for subnode in node.value:
+ if not isinstance(subnode, MappingNode):
+ raise ConstructorError("while constructing pairs", node.start_mark,
+ "expected a mapping of length 1, but found %s" % subnode.id,
+ subnode.start_mark)
+ if len(subnode.value) != 1:
+ raise ConstructorError("while constructing pairs", node.start_mark,
+ "expected a single mapping item, but found %d items" % len(subnode.value),
+ subnode.start_mark)
+ key_node, value_node = subnode.value[0]
+ key = self.construct_object(key_node)
+ value = self.construct_object(value_node)
+ pairs.append((key, value))
+
+ def construct_yaml_set(self, node):
+ data = set()
+ yield data
+ value = self.construct_mapping(node)
+ data.update(value)
+
+ def construct_yaml_str(self, node):
+ return self.construct_scalar(node)
+
+ def construct_yaml_seq(self, node):
+ data = []
+ yield data
+ data.extend(self.construct_sequence(node))
+
+ def construct_yaml_map(self, node):
+ data = {}
+ yield data
+ value = self.construct_mapping(node)
+ data.update(value)
+
+ def construct_yaml_object(self, node, cls):
+ data = cls.__new__(cls)
+ yield data
+ if hasattr(data, '__setstate__'):
+ state = self.construct_mapping(node, deep=True)
+ data.__setstate__(state)
+ else:
+ state = self.construct_mapping(node)
+ data.__dict__.update(state)
+
+ def construct_undefined(self, node):
+ raise ConstructorError(None, None,
+ "could not determine a constructor for the tag %r" % node.tag,
+ node.start_mark)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:null',
+ SafeConstructor.construct_yaml_null)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:bool',
+ SafeConstructor.construct_yaml_bool)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:int',
+ SafeConstructor.construct_yaml_int)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:float',
+ SafeConstructor.construct_yaml_float)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:binary',
+ SafeConstructor.construct_yaml_binary)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:timestamp',
+ SafeConstructor.construct_yaml_timestamp)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:omap',
+ SafeConstructor.construct_yaml_omap)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:pairs',
+ SafeConstructor.construct_yaml_pairs)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:set',
+ SafeConstructor.construct_yaml_set)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:str',
+ SafeConstructor.construct_yaml_str)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:seq',
+ SafeConstructor.construct_yaml_seq)
+
+SafeConstructor.add_constructor(
+ 'tag:yaml.org,2002:map',
+ SafeConstructor.construct_yaml_map)
+
+SafeConstructor.add_constructor(None,
+ SafeConstructor.construct_undefined)
+
+class FullConstructor(SafeConstructor):
+
+ def construct_python_str(self, node):
+ return self.construct_scalar(node)
+
+ def construct_python_unicode(self, node):
+ return self.construct_scalar(node)
+
+ def construct_python_bytes(self, node):
+ try:
+ value = self.construct_scalar(node).encode('ascii')
+ except UnicodeEncodeError as exc:
+ raise ConstructorError(None, None,
+ "failed to convert base64 data into ascii: %s" % exc,
+ node.start_mark)
+ try:
+ if hasattr(base64, 'decodebytes'):
+ return base64.decodebytes(value)
+ else:
+ return base64.decodestring(value)
+ except binascii.Error as exc:
+ raise ConstructorError(None, None,
+ "failed to decode base64 data: %s" % exc, node.start_mark)
+
+ def construct_python_long(self, node):
+ return self.construct_yaml_int(node)
+
+ def construct_python_complex(self, node):
+ return complex(self.construct_scalar(node))
+
+ def construct_python_tuple(self, node):
+ return tuple(self.construct_sequence(node))
+
+ def find_python_module(self, name, mark, unsafe=False):
+ if not name:
+ raise ConstructorError("while constructing a Python module", mark,
+ "expected non-empty name appended to the tag", mark)
+ if unsafe:
+ try:
+ __import__(name)
+ except ImportError as exc:
+ raise ConstructorError("while constructing a Python module", mark,
+ "cannot find module %r (%s)" % (name, exc), mark)
+            if name not in sys.modules:
+ raise ConstructorError("while constructing a Python module", mark,
+ "module %r is not imported" % name, mark)
+ return sys.modules[name]
+
+ def find_python_name(self, name, mark, unsafe=False):
+ if not name:
+ raise ConstructorError("while constructing a Python object", mark,
+ "expected non-empty name appended to the tag", mark)
+ if '.' in name:
+ module_name, object_name = name.rsplit('.', 1)
+ else:
+ module_name = 'builtins'
+ object_name = name
+ if unsafe:
+ try:
+ __import__(module_name)
+ except ImportError as exc:
+ raise ConstructorError("while constructing a Python object", mark,
+ "cannot find module %r (%s)" % (module_name, exc), mark)
+            if module_name not in sys.modules:
+ raise ConstructorError("while constructing a Python object", mark,
+ "module %r is not imported" % module_name, mark)
+ module = sys.modules[module_name]
+ if not hasattr(module, object_name):
+ raise ConstructorError("while constructing a Python object", mark,
+ "cannot find %r in the module %r"
+ % (object_name, module.__name__), mark)
+ return getattr(module, object_name)
+
+ def construct_python_name(self, suffix, node):
+ value = self.construct_scalar(node)
+ if value:
+ raise ConstructorError("while constructing a Python name", node.start_mark,
+ "expected the empty value, but found %r" % value, node.start_mark)
+ return self.find_python_name(suffix, node.start_mark)
+
+ def construct_python_module(self, suffix, node):
+ value = self.construct_scalar(node)
+ if value:
+ raise ConstructorError("while constructing a Python module", node.start_mark,
+ "expected the empty value, but found %r" % value, node.start_mark)
+ return self.find_python_module(suffix, node.start_mark)
+
+ def make_python_instance(self, suffix, node,
+ args=None, kwds=None, newobj=False, unsafe=False):
+ if not args:
+ args = []
+ if not kwds:
+ kwds = {}
+ cls = self.find_python_name(suffix, node.start_mark)
+ if not (unsafe or isinstance(cls, type)):
+ raise ConstructorError("while constructing a Python instance", node.start_mark,
+ "expected a class, but found %r" % type(cls),
+ node.start_mark)
+ if newobj and isinstance(cls, type):
+ return cls.__new__(cls, *args, **kwds)
+ else:
+ return cls(*args, **kwds)
+
+ def set_python_instance_state(self, instance, state):
+ if hasattr(instance, '__setstate__'):
+ instance.__setstate__(state)
+ else:
+ slotstate = {}
+ if isinstance(state, tuple) and len(state) == 2:
+ state, slotstate = state
+ if hasattr(instance, '__dict__'):
+ instance.__dict__.update(state)
+ elif state:
+ slotstate.update(state)
+ for key, value in slotstate.items():
+                setattr(instance, key, value)
+
+ def construct_python_object(self, suffix, node):
+ # Format:
+ # !!python/object:module.name { ... state ... }
+ instance = self.make_python_instance(suffix, node, newobj=True)
+ yield instance
+ deep = hasattr(instance, '__setstate__')
+ state = self.construct_mapping(node, deep=deep)
+ self.set_python_instance_state(instance, state)
+
+ def construct_python_object_apply(self, suffix, node, newobj=False):
+ # Format:
+ # !!python/object/apply # (or !!python/object/new)
+ # args: [ ... arguments ... ]
+ # kwds: { ... keywords ... }
+ # state: ... state ...
+ # listitems: [ ... listitems ... ]
+ # dictitems: { ... dictitems ... }
+ # or short format:
+ # !!python/object/apply [ ... arguments ... ]
+ # The difference between !!python/object/apply and !!python/object/new
+        # is how an object is created; see make_python_instance for details.
+ if isinstance(node, SequenceNode):
+ args = self.construct_sequence(node, deep=True)
+ kwds = {}
+ state = {}
+ listitems = []
+ dictitems = {}
+ else:
+ value = self.construct_mapping(node, deep=True)
+ args = value.get('args', [])
+ kwds = value.get('kwds', {})
+ state = value.get('state', {})
+ listitems = value.get('listitems', [])
+ dictitems = value.get('dictitems', {})
+ instance = self.make_python_instance(suffix, node, args, kwds, newobj)
+ if state:
+ self.set_python_instance_state(instance, state)
+ if listitems:
+ instance.extend(listitems)
+ if dictitems:
+ for key in dictitems:
+ instance[key] = dictitems[key]
+ return instance
+
+ def construct_python_object_new(self, suffix, node):
+ return self.construct_python_object_apply(suffix, node, newobj=True)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/none',
+ FullConstructor.construct_yaml_null)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/bool',
+ FullConstructor.construct_yaml_bool)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/str',
+ FullConstructor.construct_python_str)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/unicode',
+ FullConstructor.construct_python_unicode)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/bytes',
+ FullConstructor.construct_python_bytes)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/int',
+ FullConstructor.construct_yaml_int)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/long',
+ FullConstructor.construct_python_long)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/float',
+ FullConstructor.construct_yaml_float)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/complex',
+ FullConstructor.construct_python_complex)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/list',
+ FullConstructor.construct_yaml_seq)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/tuple',
+ FullConstructor.construct_python_tuple)
+
+FullConstructor.add_constructor(
+ 'tag:yaml.org,2002:python/dict',
+ FullConstructor.construct_yaml_map)
+
+FullConstructor.add_multi_constructor(
+ 'tag:yaml.org,2002:python/name:',
+ FullConstructor.construct_python_name)
+
+FullConstructor.add_multi_constructor(
+ 'tag:yaml.org,2002:python/module:',
+ FullConstructor.construct_python_module)
+
+FullConstructor.add_multi_constructor(
+ 'tag:yaml.org,2002:python/object:',
+ FullConstructor.construct_python_object)
+
+FullConstructor.add_multi_constructor(
+ 'tag:yaml.org,2002:python/object/apply:',
+ FullConstructor.construct_python_object_apply)
+
+FullConstructor.add_multi_constructor(
+ 'tag:yaml.org,2002:python/object/new:',
+ FullConstructor.construct_python_object_new)
+
+class UnsafeConstructor(FullConstructor):
+
+ def find_python_module(self, name, mark):
+ return super(UnsafeConstructor, self).find_python_module(name, mark, unsafe=True)
+
+ def find_python_name(self, name, mark):
+ return super(UnsafeConstructor, self).find_python_name(name, mark, unsafe=True)
+
+ def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False):
+ return super(UnsafeConstructor, self).make_python_instance(
+ suffix, node, args, kwds, newobj, unsafe=True)
+
+# Constructor is the same as UnsafeConstructor. It is kept in place in case
+# people have extended it directly.
+class Constructor(UnsafeConstructor):
+ pass
diff --git a/libs/yaml/cyaml.py b/libs/yaml/cyaml.py
new file mode 100644
index 000000000..1e606c74b
--- /dev/null
+++ b/libs/yaml/cyaml.py
@@ -0,0 +1,101 @@
+
+__all__ = [
+ 'CBaseLoader', 'CSafeLoader', 'CFullLoader', 'CUnsafeLoader', 'CLoader',
+ 'CBaseDumper', 'CSafeDumper', 'CDumper'
+]
+
+from _yaml import CParser, CEmitter
+
+from .constructor import *
+
+from .serializer import *
+from .representer import *
+
+from .resolver import *
+
+class CBaseLoader(CParser, BaseConstructor, BaseResolver):
+
+ def __init__(self, stream):
+ CParser.__init__(self, stream)
+ BaseConstructor.__init__(self)
+ BaseResolver.__init__(self)
+
+class CSafeLoader(CParser, SafeConstructor, Resolver):
+
+ def __init__(self, stream):
+ CParser.__init__(self, stream)
+ SafeConstructor.__init__(self)
+ Resolver.__init__(self)
+
+class CFullLoader(CParser, FullConstructor, Resolver):
+
+ def __init__(self, stream):
+ CParser.__init__(self, stream)
+ FullConstructor.__init__(self)
+ Resolver.__init__(self)
+
+class CUnsafeLoader(CParser, UnsafeConstructor, Resolver):
+
+ def __init__(self, stream):
+ CParser.__init__(self, stream)
+ UnsafeConstructor.__init__(self)
+ Resolver.__init__(self)
+
+class CLoader(CParser, Constructor, Resolver):
+
+ def __init__(self, stream):
+ CParser.__init__(self, stream)
+ Constructor.__init__(self)
+ Resolver.__init__(self)
+
+class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver):
+
+ def __init__(self, stream,
+ default_style=None, default_flow_style=False,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None, sort_keys=True):
+ CEmitter.__init__(self, stream, canonical=canonical,
+ indent=indent, width=width, encoding=encoding,
+ allow_unicode=allow_unicode, line_break=line_break,
+ explicit_start=explicit_start, explicit_end=explicit_end,
+ version=version, tags=tags)
+ Representer.__init__(self, default_style=default_style,
+ default_flow_style=default_flow_style, sort_keys=sort_keys)
+ Resolver.__init__(self)
+
+class CSafeDumper(CEmitter, SafeRepresenter, Resolver):
+
+ def __init__(self, stream,
+ default_style=None, default_flow_style=False,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None, sort_keys=True):
+ CEmitter.__init__(self, stream, canonical=canonical,
+ indent=indent, width=width, encoding=encoding,
+ allow_unicode=allow_unicode, line_break=line_break,
+ explicit_start=explicit_start, explicit_end=explicit_end,
+ version=version, tags=tags)
+ SafeRepresenter.__init__(self, default_style=default_style,
+ default_flow_style=default_flow_style, sort_keys=sort_keys)
+ Resolver.__init__(self)
+
+class CDumper(CEmitter, Serializer, Representer, Resolver):
+
+ def __init__(self, stream,
+ default_style=None, default_flow_style=False,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None, sort_keys=True):
+ CEmitter.__init__(self, stream, canonical=canonical,
+ indent=indent, width=width, encoding=encoding,
+ allow_unicode=allow_unicode, line_break=line_break,
+ explicit_start=explicit_start, explicit_end=explicit_end,
+ version=version, tags=tags)
+ Representer.__init__(self, default_style=default_style,
+ default_flow_style=default_flow_style, sort_keys=sort_keys)
+ Resolver.__init__(self)
+
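+# Illustrative sketch: the C-accelerated classes are drop-in Loader/Dumper
+# arguments, guarded by the __with_libyaml__ flag set in __init__.py::
+#
+#     import yaml
+#
+#     loader = yaml.CSafeLoader if yaml.__with_libyaml__ else yaml.SafeLoader
+#     data = yaml.load('a: 1', Loader=loader)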
diff --git a/libs/yaml/dumper.py b/libs/yaml/dumper.py
new file mode 100644
index 000000000..6aadba551
--- /dev/null
+++ b/libs/yaml/dumper.py
@@ -0,0 +1,62 @@
+
+__all__ = ['BaseDumper', 'SafeDumper', 'Dumper']
+
+from .emitter import *
+from .serializer import *
+from .representer import *
+from .resolver import *
+
+class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver):
+
+ def __init__(self, stream,
+ default_style=None, default_flow_style=False,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None, sort_keys=True):
+ Emitter.__init__(self, stream, canonical=canonical,
+ indent=indent, width=width,
+ allow_unicode=allow_unicode, line_break=line_break)
+ Serializer.__init__(self, encoding=encoding,
+ explicit_start=explicit_start, explicit_end=explicit_end,
+ version=version, tags=tags)
+ Representer.__init__(self, default_style=default_style,
+ default_flow_style=default_flow_style, sort_keys=sort_keys)
+ Resolver.__init__(self)
+
+class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver):
+
+ def __init__(self, stream,
+ default_style=None, default_flow_style=False,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None, sort_keys=True):
+ Emitter.__init__(self, stream, canonical=canonical,
+ indent=indent, width=width,
+ allow_unicode=allow_unicode, line_break=line_break)
+ Serializer.__init__(self, encoding=encoding,
+ explicit_start=explicit_start, explicit_end=explicit_end,
+ version=version, tags=tags)
+ SafeRepresenter.__init__(self, default_style=default_style,
+ default_flow_style=default_flow_style, sort_keys=sort_keys)
+ Resolver.__init__(self)
+
+class Dumper(Emitter, Serializer, Representer, Resolver):
+
+ def __init__(self, stream,
+ default_style=None, default_flow_style=False,
+ canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None,
+ encoding=None, explicit_start=None, explicit_end=None,
+ version=None, tags=None, sort_keys=True):
+ Emitter.__init__(self, stream, canonical=canonical,
+ indent=indent, width=width,
+ allow_unicode=allow_unicode, line_break=line_break)
+ Serializer.__init__(self, encoding=encoding,
+ explicit_start=explicit_start, explicit_end=explicit_end,
+ version=version, tags=tags)
+ Representer.__init__(self, default_style=default_style,
+ default_flow_style=default_flow_style, sort_keys=sort_keys)
+ Resolver.__init__(self)
+
diff --git a/libs/yaml/emitter.py b/libs/yaml/emitter.py
new file mode 100644
index 000000000..a664d0111
--- /dev/null
+++ b/libs/yaml/emitter.py
@@ -0,0 +1,1137 @@
+
+# Emitter expects events obeying the following grammar:
+# stream ::= STREAM-START document* STREAM-END
+# document ::= DOCUMENT-START node DOCUMENT-END
+# node ::= SCALAR | sequence | mapping
+# sequence ::= SEQUENCE-START node* SEQUENCE-END
+# mapping ::= MAPPING-START (node node)* MAPPING-END
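+#
+# For example, the document "a: 1" corresponds to the event stream
+# StreamStart, DocumentStart, MappingStart, Scalar('a'), Scalar('1'),
+# MappingEnd, DocumentEnd, StreamEnd, which yaml.emit() can turn back
+# into text.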
+
+__all__ = ['Emitter', 'EmitterError']
+
+from .error import YAMLError
+from .events import *
+
+class EmitterError(YAMLError):
+ pass
+
+class ScalarAnalysis:
+ def __init__(self, scalar, empty, multiline,
+ allow_flow_plain, allow_block_plain,
+ allow_single_quoted, allow_double_quoted,
+ allow_block):
+ self.scalar = scalar
+ self.empty = empty
+ self.multiline = multiline
+ self.allow_flow_plain = allow_flow_plain
+ self.allow_block_plain = allow_block_plain
+ self.allow_single_quoted = allow_single_quoted
+ self.allow_double_quoted = allow_double_quoted
+ self.allow_block = allow_block
+
+class Emitter:
+
+ DEFAULT_TAG_PREFIXES = {
+ '!' : '!',
+ 'tag:yaml.org,2002:' : '!!',
+ }
+
+ def __init__(self, stream, canonical=None, indent=None, width=None,
+ allow_unicode=None, line_break=None):
+
+ # The stream should have the methods `write` and possibly `flush`.
+ self.stream = stream
+
+ # Encoding can be overridden by STREAM-START.
+ self.encoding = None
+
+ # Emitter is a state machine with a stack of states to handle nested
+ # structures.
+ self.states = []
+ self.state = self.expect_stream_start
+
+ # Current event and the event queue.
+ self.events = []
+ self.event = None
+
+ # The current indentation level and the stack of previous indents.
+ self.indents = []
+ self.indent = None
+
+ # Flow level.
+ self.flow_level = 0
+
+ # Contexts.
+ self.root_context = False
+ self.sequence_context = False
+ self.mapping_context = False
+ self.simple_key_context = False
+
+ # Characteristics of the last emitted character:
+ # - current position.
+ # - is it a whitespace?
+ # - is it an indention character
+ # (indentation space, '-', '?', or ':')?
+ self.line = 0
+ self.column = 0
+ self.whitespace = True
+ self.indention = True
+
+ # Whether the document requires an explicit document indicator
+ self.open_ended = False
+
+ # Formatting details.
+ self.canonical = canonical
+ self.allow_unicode = allow_unicode
+ self.best_indent = 2
+ if indent and 1 < indent < 10:
+ self.best_indent = indent
+ self.best_width = 80
+ if width and width > self.best_indent*2:
+ self.best_width = width
+ self.best_line_break = '\n'
+ if line_break in ['\r', '\n', '\r\n']:
+ self.best_line_break = line_break
+
+ # Tag prefixes.
+ self.tag_prefixes = None
+
+ # Prepared anchor and tag.
+ self.prepared_anchor = None
+ self.prepared_tag = None
+
+ # Scalar analysis and style.
+ self.analysis = None
+ self.style = None
+
+ def dispose(self):
+ # Reset the state attributes (to clear self-references)
+ self.states = []
+ self.state = None
+
+ def emit(self, event):
+ self.events.append(event)
+ while not self.need_more_events():
+ self.event = self.events.pop(0)
+ self.state()
+ self.event = None
+
+ # In some cases, we wait for a few next events before emitting.
+
+ def need_more_events(self):
+ if not self.events:
+ return True
+ event = self.events[0]
+ if isinstance(event, DocumentStartEvent):
+ return self.need_events(1)
+ elif isinstance(event, SequenceStartEvent):
+ return self.need_events(2)
+ elif isinstance(event, MappingStartEvent):
+ return self.need_events(3)
+ else:
+ return False
+
+ def need_events(self, count):
+ level = 0
+ for event in self.events[1:]:
+ if isinstance(event, (DocumentStartEvent, CollectionStartEvent)):
+ level += 1
+ elif isinstance(event, (DocumentEndEvent, CollectionEndEvent)):
+ level -= 1
+ elif isinstance(event, StreamEndEvent):
+ level = -1
+ if level < 0:
+ return False
+ return (len(self.events) < count+1)
+
+ def increase_indent(self, flow=False, indentless=False):
+ self.indents.append(self.indent)
+ if self.indent is None:
+ if flow:
+ self.indent = self.best_indent
+ else:
+ self.indent = 0
+ elif not indentless:
+ self.indent += self.best_indent
+
+ # States.
+
+ # Stream handlers.
+
+ def expect_stream_start(self):
+ if isinstance(self.event, StreamStartEvent):
+ if self.event.encoding and not hasattr(self.stream, 'encoding'):
+ self.encoding = self.event.encoding
+ self.write_stream_start()
+ self.state = self.expect_first_document_start
+ else:
+ raise EmitterError("expected StreamStartEvent, but got %s"
+ % self.event)
+
+ def expect_nothing(self):
+ raise EmitterError("expected nothing, but got %s" % self.event)
+
+ # Document handlers.
+
+ def expect_first_document_start(self):
+ return self.expect_document_start(first=True)
+
+ def expect_document_start(self, first=False):
+ if isinstance(self.event, DocumentStartEvent):
+ if (self.event.version or self.event.tags) and self.open_ended:
+ self.write_indicator('...', True)
+ self.write_indent()
+ if self.event.version:
+ version_text = self.prepare_version(self.event.version)
+ self.write_version_directive(version_text)
+ self.tag_prefixes = self.DEFAULT_TAG_PREFIXES.copy()
+ if self.event.tags:
+ handles = sorted(self.event.tags.keys())
+ for handle in handles:
+ prefix = self.event.tags[handle]
+ self.tag_prefixes[prefix] = handle
+ handle_text = self.prepare_tag_handle(handle)
+ prefix_text = self.prepare_tag_prefix(prefix)
+ self.write_tag_directive(handle_text, prefix_text)
+ implicit = (first and not self.event.explicit and not self.canonical
+ and not self.event.version and not self.event.tags
+ and not self.check_empty_document())
+ if not implicit:
+ self.write_indent()
+ self.write_indicator('---', True)
+ if self.canonical:
+ self.write_indent()
+ self.state = self.expect_document_root
+ elif isinstance(self.event, StreamEndEvent):
+ if self.open_ended:
+ self.write_indicator('...', True)
+ self.write_indent()
+ self.write_stream_end()
+ self.state = self.expect_nothing
+ else:
+ raise EmitterError("expected DocumentStartEvent, but got %s"
+ % self.event)
+
+ def expect_document_end(self):
+ if isinstance(self.event, DocumentEndEvent):
+ self.write_indent()
+ if self.event.explicit:
+ self.write_indicator('...', True)
+ self.write_indent()
+ self.flush_stream()
+ self.state = self.expect_document_start
+ else:
+ raise EmitterError("expected DocumentEndEvent, but got %s"
+ % self.event)
+
+ def expect_document_root(self):
+ self.states.append(self.expect_document_end)
+ self.expect_node(root=True)
+
+ # Node handlers.
+
+ def expect_node(self, root=False, sequence=False, mapping=False,
+ simple_key=False):
+ self.root_context = root
+ self.sequence_context = sequence
+ self.mapping_context = mapping
+ self.simple_key_context = simple_key
+ if isinstance(self.event, AliasEvent):
+ self.expect_alias()
+ elif isinstance(self.event, (ScalarEvent, CollectionStartEvent)):
+ self.process_anchor('&')
+ self.process_tag()
+ if isinstance(self.event, ScalarEvent):
+ self.expect_scalar()
+ elif isinstance(self.event, SequenceStartEvent):
+ if self.flow_level or self.canonical or self.event.flow_style \
+ or self.check_empty_sequence():
+ self.expect_flow_sequence()
+ else:
+ self.expect_block_sequence()
+ elif isinstance(self.event, MappingStartEvent):
+ if self.flow_level or self.canonical or self.event.flow_style \
+ or self.check_empty_mapping():
+ self.expect_flow_mapping()
+ else:
+ self.expect_block_mapping()
+ else:
+ raise EmitterError("expected NodeEvent, but got %s" % self.event)
+
+ def expect_alias(self):
+ if self.event.anchor is None:
+ raise EmitterError("anchor is not specified for alias")
+ self.process_anchor('*')
+ self.state = self.states.pop()
+
+ def expect_scalar(self):
+ self.increase_indent(flow=True)
+ self.process_scalar()
+ self.indent = self.indents.pop()
+ self.state = self.states.pop()
+
+ # Flow sequence handlers.
+
+ def expect_flow_sequence(self):
+ self.write_indicator('[', True, whitespace=True)
+ self.flow_level += 1
+ self.increase_indent(flow=True)
+ self.state = self.expect_first_flow_sequence_item
+
+ def expect_first_flow_sequence_item(self):
+ if isinstance(self.event, SequenceEndEvent):
+ self.indent = self.indents.pop()
+ self.flow_level -= 1
+ self.write_indicator(']', False)
+ self.state = self.states.pop()
+ else:
+ if self.canonical or self.column > self.best_width:
+ self.write_indent()
+ self.states.append(self.expect_flow_sequence_item)
+ self.expect_node(sequence=True)
+
+ def expect_flow_sequence_item(self):
+ if isinstance(self.event, SequenceEndEvent):
+ self.indent = self.indents.pop()
+ self.flow_level -= 1
+ if self.canonical:
+ self.write_indicator(',', False)
+ self.write_indent()
+ self.write_indicator(']', False)
+ self.state = self.states.pop()
+ else:
+ self.write_indicator(',', False)
+ if self.canonical or self.column > self.best_width:
+ self.write_indent()
+ self.states.append(self.expect_flow_sequence_item)
+ self.expect_node(sequence=True)
+
+ # Flow mapping handlers.
+
+ def expect_flow_mapping(self):
+ self.write_indicator('{', True, whitespace=True)
+ self.flow_level += 1
+ self.increase_indent(flow=True)
+ self.state = self.expect_first_flow_mapping_key
+
+ def expect_first_flow_mapping_key(self):
+ if isinstance(self.event, MappingEndEvent):
+ self.indent = self.indents.pop()
+ self.flow_level -= 1
+ self.write_indicator('}', False)
+ self.state = self.states.pop()
+ else:
+ if self.canonical or self.column > self.best_width:
+ self.write_indent()
+ if not self.canonical and self.check_simple_key():
+ self.states.append(self.expect_flow_mapping_simple_value)
+ self.expect_node(mapping=True, simple_key=True)
+ else:
+ self.write_indicator('?', True)
+ self.states.append(self.expect_flow_mapping_value)
+ self.expect_node(mapping=True)
+
+ def expect_flow_mapping_key(self):
+ if isinstance(self.event, MappingEndEvent):
+ self.indent = self.indents.pop()
+ self.flow_level -= 1
+ if self.canonical:
+ self.write_indicator(',', False)
+ self.write_indent()
+ self.write_indicator('}', False)
+ self.state = self.states.pop()
+ else:
+ self.write_indicator(',', False)
+ if self.canonical or self.column > self.best_width:
+ self.write_indent()
+ if not self.canonical and self.check_simple_key():
+ self.states.append(self.expect_flow_mapping_simple_value)
+ self.expect_node(mapping=True, simple_key=True)
+ else:
+ self.write_indicator('?', True)
+ self.states.append(self.expect_flow_mapping_value)
+ self.expect_node(mapping=True)
+
+ def expect_flow_mapping_simple_value(self):
+ self.write_indicator(':', False)
+ self.states.append(self.expect_flow_mapping_key)
+ self.expect_node(mapping=True)
+
+ def expect_flow_mapping_value(self):
+ if self.canonical or self.column > self.best_width:
+ self.write_indent()
+ self.write_indicator(':', True)
+ self.states.append(self.expect_flow_mapping_key)
+ self.expect_node(mapping=True)
+
+ # Block sequence handlers.
+
+ def expect_block_sequence(self):
+ indentless = (self.mapping_context and not self.indention)
+ self.increase_indent(flow=False, indentless=indentless)
+ self.state = self.expect_first_block_sequence_item
+
+ def expect_first_block_sequence_item(self):
+ return self.expect_block_sequence_item(first=True)
+
+ def expect_block_sequence_item(self, first=False):
+ if not first and isinstance(self.event, SequenceEndEvent):
+ self.indent = self.indents.pop()
+ self.state = self.states.pop()
+ else:
+ self.write_indent()
+ self.write_indicator('-', True, indention=True)
+ self.states.append(self.expect_block_sequence_item)
+ self.expect_node(sequence=True)
+
+ # Block mapping handlers.
+
+ def expect_block_mapping(self):
+ self.increase_indent(flow=False)
+ self.state = self.expect_first_block_mapping_key
+
+ def expect_first_block_mapping_key(self):
+ return self.expect_block_mapping_key(first=True)
+
+ def expect_block_mapping_key(self, first=False):
+ if not first and isinstance(self.event, MappingEndEvent):
+ self.indent = self.indents.pop()
+ self.state = self.states.pop()
+ else:
+ self.write_indent()
+ if self.check_simple_key():
+ self.states.append(self.expect_block_mapping_simple_value)
+ self.expect_node(mapping=True, simple_key=True)
+ else:
+ self.write_indicator('?', True, indention=True)
+ self.states.append(self.expect_block_mapping_value)
+ self.expect_node(mapping=True)
+
+ def expect_block_mapping_simple_value(self):
+ self.write_indicator(':', False)
+ self.states.append(self.expect_block_mapping_key)
+ self.expect_node(mapping=True)
+
+ def expect_block_mapping_value(self):
+ self.write_indent()
+ self.write_indicator(':', True, indention=True)
+ self.states.append(self.expect_block_mapping_key)
+ self.expect_node(mapping=True)
+
+ # Checkers.
+
+ def check_empty_sequence(self):
+ return (isinstance(self.event, SequenceStartEvent) and self.events
+ and isinstance(self.events[0], SequenceEndEvent))
+
+ def check_empty_mapping(self):
+ return (isinstance(self.event, MappingStartEvent) and self.events
+ and isinstance(self.events[0], MappingEndEvent))
+
+ def check_empty_document(self):
+ if not isinstance(self.event, DocumentStartEvent) or not self.events:
+ return False
+ event = self.events[0]
+ return (isinstance(event, ScalarEvent) and event.anchor is None
+ and event.tag is None and event.implicit and event.value == '')
+
+ def check_simple_key(self):
+ length = 0
+ if isinstance(self.event, NodeEvent) and self.event.anchor is not None:
+ if self.prepared_anchor is None:
+ self.prepared_anchor = self.prepare_anchor(self.event.anchor)
+ length += len(self.prepared_anchor)
+ if isinstance(self.event, (ScalarEvent, CollectionStartEvent)) \
+ and self.event.tag is not None:
+ if self.prepared_tag is None:
+ self.prepared_tag = self.prepare_tag(self.event.tag)
+ length += len(self.prepared_tag)
+ if isinstance(self.event, ScalarEvent):
+ if self.analysis is None:
+ self.analysis = self.analyze_scalar(self.event.value)
+ length += len(self.analysis.scalar)
+ return (length < 128 and (isinstance(self.event, AliasEvent)
+ or (isinstance(self.event, ScalarEvent)
+ and not self.analysis.empty and not self.analysis.multiline)
+ or self.check_empty_sequence() or self.check_empty_mapping()))
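+
+ # Illustrative note (not part of the original source): a key qualifies
+ # as "simple" -- written as 'key: value' without the explicit '?'
+ # indicator -- only when its rendered form is shorter than 128
+ # characters and it is an alias, a non-empty single-line scalar, or an
+ # empty collection.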
+
+ # Anchor, Tag, and Scalar processors.
+
+ def process_anchor(self, indicator):
+ if self.event.anchor is None:
+ self.prepared_anchor = None
+ return
+ if self.prepared_anchor is None:
+ self.prepared_anchor = self.prepare_anchor(self.event.anchor)
+ if self.prepared_anchor:
+ self.write_indicator(indicator+self.prepared_anchor, True)
+ self.prepared_anchor = None
+
+ def process_tag(self):
+ tag = self.event.tag
+ if isinstance(self.event, ScalarEvent):
+ if self.style is None:
+ self.style = self.choose_scalar_style()
+ if ((not self.canonical or tag is None) and
+ ((self.style == '' and self.event.implicit[0])
+ or (self.style != '' and self.event.implicit[1]))):
+ self.prepared_tag = None
+ return
+ if self.event.implicit[0] and tag is None:
+ tag = '!'
+ self.prepared_tag = None
+ else:
+ if (not self.canonical or tag is None) and self.event.implicit:
+ self.prepared_tag = None
+ return
+ if tag is None:
+ raise EmitterError("tag is not specified")
+ if self.prepared_tag is None:
+ self.prepared_tag = self.prepare_tag(tag)
+ if self.prepared_tag:
+ self.write_indicator(self.prepared_tag, True)
+ self.prepared_tag = None
+
+ def choose_scalar_style(self):
+ if self.analysis is None:
+ self.analysis = self.analyze_scalar(self.event.value)
+ if self.event.style == '"' or self.canonical:
+ return '"'
+ if not self.event.style and self.event.implicit[0]:
+ if (not (self.simple_key_context and
+ (self.analysis.empty or self.analysis.multiline))
+ and (self.flow_level and self.analysis.allow_flow_plain
+ or (not self.flow_level and self.analysis.allow_block_plain))):
+ return ''
+ if self.event.style and self.event.style in '|>':
+ if (not self.flow_level and not self.simple_key_context
+ and self.analysis.allow_block):
+ return self.event.style
+ if not self.event.style or self.event.style == '\'':
+ if (self.analysis.allow_single_quoted and
+ not (self.simple_key_context and self.analysis.multiline)):
+ return '\''
+ return '"'
+
+ def process_scalar(self):
+ if self.analysis is None:
+ self.analysis = self.analyze_scalar(self.event.value)
+ if self.style is None:
+ self.style = self.choose_scalar_style()
+ split = (not self.simple_key_context)
+ #if self.analysis.multiline and split \
+ # and (not self.style or self.style in '\'\"'):
+ # self.write_indent()
+ if self.style == '"':
+ self.write_double_quoted(self.analysis.scalar, split)
+ elif self.style == '\'':
+ self.write_single_quoted(self.analysis.scalar, split)
+ elif self.style == '>':
+ self.write_folded(self.analysis.scalar)
+ elif self.style == '|':
+ self.write_literal(self.analysis.scalar)
+ else:
+ self.write_plain(self.analysis.scalar, split)
+ self.analysis = None
+ self.style = None
+
+ # Analyzers.
+
+ def prepare_version(self, version):
+ major, minor = version
+ if major != 1:
+ raise EmitterError("unsupported YAML version: %d.%d" % (major, minor))
+ return '%d.%d' % (major, minor)
+
+ def prepare_tag_handle(self, handle):
+ if not handle:
+ raise EmitterError("tag handle must not be empty")
+ if handle[0] != '!' or handle[-1] != '!':
+ raise EmitterError("tag handle must start and end with '!': %r" % handle)
+ for ch in handle[1:-1]:
+ if not ('0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-_'):
+ raise EmitterError("invalid character %r in the tag handle: %r"
+ % (ch, handle))
+ return handle
+
+ def prepare_tag_prefix(self, prefix):
+ if not prefix:
+ raise EmitterError("tag prefix must not be empty")
+ chunks = []
+ start = end = 0
+ if prefix[0] == '!':
+ end = 1
+ while end < len(prefix):
+ ch = prefix[end]
+ if '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-;/?!:@&=+$,_.~*\'()[]':
+ end += 1
+ else:
+ if start < end:
+ chunks.append(prefix[start:end])
+ start = end = end+1
+ data = ch.encode('utf-8')
+ for ch in data:
+ # iterating over bytes yields ints on Python 3, so no ord() here
+ chunks.append('%%%02X' % ch)
+ if start < end:
+ chunks.append(prefix[start:end])
+ return ''.join(chunks)
+
+ def prepare_tag(self, tag):
+ if not tag:
+ raise EmitterError("tag must not be empty")
+ if tag == '!':
+ return tag
+ handle = None
+ suffix = tag
+ prefixes = sorted(self.tag_prefixes.keys())
+ for prefix in prefixes:
+ if tag.startswith(prefix) \
+ and (prefix == '!' or len(prefix) < len(tag)):
+ handle = self.tag_prefixes[prefix]
+ suffix = tag[len(prefix):]
+ chunks = []
+ start = end = 0
+ while end < len(suffix):
+ ch = suffix[end]
+ if '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-;/?:@&=+$,_.~*\'()[]' \
+ or (ch == '!' and handle != '!'):
+ end += 1
+ else:
+ if start < end:
+ chunks.append(suffix[start:end])
+ start = end = end+1
+ data = ch.encode('utf-8')
+ for ch in data:
+ chunks.append('%%%02X' % ch)
+ if start < end:
+ chunks.append(suffix[start:end])
+ suffix_text = ''.join(chunks)
+ if handle:
+ return '%s%s' % (handle, suffix_text)
+ else:
+ return '!<%s>' % suffix_text
+
+ def prepare_anchor(self, anchor):
+ if not anchor:
+ raise EmitterError("anchor must not be empty")
+ for ch in anchor:
+ if not ('0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-_'):
+ raise EmitterError("invalid character %r in the anchor: %r"
+ % (ch, anchor))
+ return anchor
+
+ def analyze_scalar(self, scalar):
+
+ # Empty scalar is a special case.
+ if not scalar:
+ return ScalarAnalysis(scalar=scalar, empty=True, multiline=False,
+ allow_flow_plain=False, allow_block_plain=True,
+ allow_single_quoted=True, allow_double_quoted=True,
+ allow_block=False)
+
+ # Indicators and special characters.
+ block_indicators = False
+ flow_indicators = False
+ line_breaks = False
+ special_characters = False
+
+ # Important whitespace combinations.
+ leading_space = False
+ leading_break = False
+ trailing_space = False
+ trailing_break = False
+ break_space = False
+ space_break = False
+
+ # Check document indicators.
+ if scalar.startswith('---') or scalar.startswith('...'):
+ block_indicators = True
+ flow_indicators = True
+
+ # First character or preceded by a whitespace.
+ preceded_by_whitespace = True
+
+ # Last character or followed by a whitespace.
+ followed_by_whitespace = (len(scalar) == 1 or
+ scalar[1] in '\0 \t\r\n\x85\u2028\u2029')
+
+ # The previous character is a space.
+ previous_space = False
+
+ # The previous character is a break.
+ previous_break = False
+
+ index = 0
+ while index < len(scalar):
+ ch = scalar[index]
+
+ # Check for indicators.
+ if index == 0:
+ # Leading indicators are special characters.
+ if ch in '#,[]{}&*!|>\'\"%@`':
+ flow_indicators = True
+ block_indicators = True
+ if ch in '?:':
+ flow_indicators = True
+ if followed_by_whitespace:
+ block_indicators = True
+ if ch == '-' and followed_by_whitespace:
+ flow_indicators = True
+ block_indicators = True
+ else:
+ # Some indicators cannot appear within a scalar either.
+ if ch in ',?[]{}':
+ flow_indicators = True
+ if ch == ':':
+ flow_indicators = True
+ if followed_by_whitespace:
+ block_indicators = True
+ if ch == '#' and preceded_by_whitespace:
+ flow_indicators = True
+ block_indicators = True
+
+ # Check for line breaks, special, and unicode characters.
+ if ch in '\n\x85\u2028\u2029':
+ line_breaks = True
+ if not (ch == '\n' or '\x20' <= ch <= '\x7E'):
+ if (ch == '\x85' or '\xA0' <= ch <= '\uD7FF'
+ or '\uE000' <= ch <= '\uFFFD'
+ or '\U00010000' <= ch < '\U0010ffff') and ch != '\uFEFF':
+ unicode_characters = True
+ if not self.allow_unicode:
+ special_characters = True
+ else:
+ special_characters = True
+
+ # Detect important whitespace combinations.
+ if ch == ' ':
+ if index == 0:
+ leading_space = True
+ if index == len(scalar)-1:
+ trailing_space = True
+ if previous_break:
+ break_space = True
+ previous_space = True
+ previous_break = False
+ elif ch in '\n\x85\u2028\u2029':
+ if index == 0:
+ leading_break = True
+ if index == len(scalar)-1:
+ trailing_break = True
+ if previous_space:
+ space_break = True
+ previous_space = False
+ previous_break = True
+ else:
+ previous_space = False
+ previous_break = False
+
+ # Prepare for the next character.
+ index += 1
+ preceded_by_whitespace = (ch in '\0 \t\r\n\x85\u2028\u2029')
+ followed_by_whitespace = (index+1 >= len(scalar) or
+ scalar[index+1] in '\0 \t\r\n\x85\u2028\u2029')
+
+ # Let's decide what styles are allowed.
+ allow_flow_plain = True
+ allow_block_plain = True
+ allow_single_quoted = True
+ allow_double_quoted = True
+ allow_block = True
+
+ # Leading and trailing whitespaces are bad for plain scalars.
+ if (leading_space or leading_break
+ or trailing_space or trailing_break):
+ allow_flow_plain = allow_block_plain = False
+
+ # We do not permit trailing spaces for block scalars.
+ if trailing_space:
+ allow_block = False
+
+ # Spaces at the beginning of a new line are only acceptable for block
+ # scalars.
+ if break_space:
+ allow_flow_plain = allow_block_plain = allow_single_quoted = False
+
+ # Spaces followed by breaks, as well as special characters, are only
+ # allowed for double-quoted scalars.
+ if space_break or special_characters:
+ allow_flow_plain = allow_block_plain = \
+ allow_single_quoted = allow_block = False
+
+ # Although the plain scalar writer supports breaks, we never emit
+ # multiline plain scalars.
+ if line_breaks:
+ allow_flow_plain = allow_block_plain = False
+
+ # Flow indicators are forbidden for flow plain scalars.
+ if flow_indicators:
+ allow_flow_plain = False
+
+ # Block indicators are forbidden for block plain scalars.
+ if block_indicators:
+ allow_block_plain = False
+
+ return ScalarAnalysis(scalar=scalar,
+ empty=False, multiline=line_breaks,
+ allow_flow_plain=allow_flow_plain,
+ allow_block_plain=allow_block_plain,
+ allow_single_quoted=allow_single_quoted,
+ allow_double_quoted=allow_double_quoted,
+ allow_block=allow_block)
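+
+ # Worked examples of the rules above (illustrative, not part of the
+ # original source):
+ # 'hello world' -> plain style allowed in both flow and block context
+ # ' padded' -> a leading space forbids the plain styles
+ # 'a: b' -> ': ' is an indicator, so the plain styles are forbidden
+ # 'one\ntwo' -> line breaks make the scalar multiline, and plain
+ # styles are never used for multiline scalars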
+
+ # Writers.
+
+ def flush_stream(self):
+ if hasattr(self.stream, 'flush'):
+ self.stream.flush()
+
+ def write_stream_start(self):
+ # Write BOM if needed.
+ if self.encoding and self.encoding.startswith('utf-16'):
+ self.stream.write('\uFEFF'.encode(self.encoding))
+
+ def write_stream_end(self):
+ self.flush_stream()
+
+ def write_indicator(self, indicator, need_whitespace,
+ whitespace=False, indention=False):
+ if self.whitespace or not need_whitespace:
+ data = indicator
+ else:
+ data = ' '+indicator
+ self.whitespace = whitespace
+ self.indention = self.indention and indention
+ self.column += len(data)
+ self.open_ended = False
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+
+ def write_indent(self):
+ indent = self.indent or 0
+ if not self.indention or self.column > indent \
+ or (self.column == indent and not self.whitespace):
+ self.write_line_break()
+ if self.column < indent:
+ self.whitespace = True
+ data = ' '*(indent-self.column)
+ self.column = indent
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+
+ def write_line_break(self, data=None):
+ if data is None:
+ data = self.best_line_break
+ self.whitespace = True
+ self.indention = True
+ self.line += 1
+ self.column = 0
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+
+ def write_version_directive(self, version_text):
+ data = '%%YAML %s' % version_text
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ self.write_line_break()
+
+ def write_tag_directive(self, handle_text, prefix_text):
+ data = '%%TAG %s %s' % (handle_text, prefix_text)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ self.write_line_break()
+
+ # Scalar streams.
+
+ def write_single_quoted(self, text, split=True):
+ self.write_indicator('\'', True)
+ spaces = False
+ breaks = False
+ start = end = 0
+ while end <= len(text):
+ ch = None
+ if end < len(text):
+ ch = text[end]
+ if spaces:
+ if ch is None or ch != ' ':
+ if start+1 == end and self.column > self.best_width and split \
+ and start != 0 and end != len(text):
+ self.write_indent()
+ else:
+ data = text[start:end]
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end
+ elif breaks:
+ if ch is None or ch not in '\n\x85\u2028\u2029':
+ if text[start] == '\n':
+ self.write_line_break()
+ for br in text[start:end]:
+ if br == '\n':
+ self.write_line_break()
+ else:
+ self.write_line_break(br)
+ self.write_indent()
+ start = end
+ else:
+ if ch is None or ch in ' \n\x85\u2028\u2029' or ch == '\'':
+ if start < end:
+ data = text[start:end]
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end
+ if ch == '\'':
+ data = '\'\''
+ self.column += 2
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end + 1
+ if ch is not None:
+ spaces = (ch == ' ')
+ breaks = (ch in '\n\x85\u2028\u2029')
+ end += 1
+ self.write_indicator('\'', False)
+
+ ESCAPE_REPLACEMENTS = {
+ '\0': '0',
+ '\x07': 'a',
+ '\x08': 'b',
+ '\x09': 't',
+ '\x0A': 'n',
+ '\x0B': 'v',
+ '\x0C': 'f',
+ '\x0D': 'r',
+ '\x1B': 'e',
+ '\"': '\"',
+ '\\': '\\',
+ '\x85': 'N',
+ '\xA0': '_',
+ '\u2028': 'L',
+ '\u2029': 'P',
+ }
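+
+ # Illustrative examples (not part of the original source): with the
+ # table above, write_double_quoted renders '\t' as the escape '\\t',
+ # and a non-ASCII character such as '\xe9' as '\\xE9' when
+ # allow_unicode is off.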
+
+ def write_double_quoted(self, text, split=True):
+ self.write_indicator('"', True)
+ start = end = 0
+ while end <= len(text):
+ ch = None
+ if end < len(text):
+ ch = text[end]
+ if ch is None or ch in '"\\\x85\u2028\u2029\uFEFF' \
+ or not ('\x20' <= ch <= '\x7E'
+ or (self.allow_unicode
+ and ('\xA0' <= ch <= '\uD7FF'
+ or '\uE000' <= ch <= '\uFFFD'))):
+ if start < end:
+ data = text[start:end]
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end
+ if ch is not None:
+ if ch in self.ESCAPE_REPLACEMENTS:
+ data = '\\'+self.ESCAPE_REPLACEMENTS[ch]
+ elif ch <= '\xFF':
+ data = '\\x%02X' % ord(ch)
+ elif ch <= '\uFFFF':
+ data = '\\u%04X' % ord(ch)
+ else:
+ data = '\\U%08X' % ord(ch)
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end+1
+ if 0 < end < len(text)-1 and (ch == ' ' or start >= end) \
+ and self.column+(end-start) > self.best_width and split:
+ data = text[start:end]+'\\'
+ if start < end:
+ start = end
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ self.write_indent()
+ self.whitespace = False
+ self.indention = False
+ if text[start] == ' ':
+ data = '\\'
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ end += 1
+ self.write_indicator('"', False)
+
+ def determine_block_hints(self, text):
+ hints = ''
+ if text:
+ if text[0] in ' \n\x85\u2028\u2029':
+ hints += str(self.best_indent)
+ if text[-1] not in '\n\x85\u2028\u2029':
+ hints += '-'
+ elif len(text) == 1 or text[-2] in '\n\x85\u2028\u2029':
+ hints += '+'
+ return hints
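+
+ # Illustrative examples (not part of the original source):
+ # determine_block_hints('text') == '-' (strip the missing final break),
+ # determine_block_hints('text\n\n') == '+' (keep trailing breaks), and
+ # determine_block_hints(' text\n') == str(self.best_indent) because a
+ # leading space forces an explicit indentation indicator.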
+
+ def write_folded(self, text):
+ hints = self.determine_block_hints(text)
+ self.write_indicator('>'+hints, True)
+ if hints[-1:] == '+':
+ self.open_ended = True
+ self.write_line_break()
+ leading_space = True
+ spaces = False
+ breaks = True
+ start = end = 0
+ while end <= len(text):
+ ch = None
+ if end < len(text):
+ ch = text[end]
+ if breaks:
+ if ch is None or ch not in '\n\x85\u2028\u2029':
+ if not leading_space and ch is not None and ch != ' ' \
+ and text[start] == '\n':
+ self.write_line_break()
+ leading_space = (ch == ' ')
+ for br in text[start:end]:
+ if br == '\n':
+ self.write_line_break()
+ else:
+ self.write_line_break(br)
+ if ch is not None:
+ self.write_indent()
+ start = end
+ elif spaces:
+ if ch != ' ':
+ if start+1 == end and self.column > self.best_width:
+ self.write_indent()
+ else:
+ data = text[start:end]
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end
+ else:
+ if ch is None or ch in ' \n\x85\u2028\u2029':
+ data = text[start:end]
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ if ch is None:
+ self.write_line_break()
+ start = end
+ if ch is not None:
+ breaks = (ch in '\n\x85\u2028\u2029')
+ spaces = (ch == ' ')
+ end += 1
+
+ def write_literal(self, text):
+ hints = self.determine_block_hints(text)
+ self.write_indicator('|'+hints, True)
+ if hints[-1:] == '+':
+ self.open_ended = True
+ self.write_line_break()
+ breaks = True
+ start = end = 0
+ while end <= len(text):
+ ch = None
+ if end < len(text):
+ ch = text[end]
+ if breaks:
+ if ch is None or ch not in '\n\x85\u2028\u2029':
+ for br in text[start:end]:
+ if br == '\n':
+ self.write_line_break()
+ else:
+ self.write_line_break(br)
+ if ch is not None:
+ self.write_indent()
+ start = end
+ else:
+ if ch is None or ch in '\n\x85\u2028\u2029':
+ data = text[start:end]
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ if ch is None:
+ self.write_line_break()
+ start = end
+ if ch is not None:
+ breaks = (ch in '\n\x85\u2028\u2029')
+ end += 1
+
+ def write_plain(self, text, split=True):
+ if self.root_context:
+ self.open_ended = True
+ if not text:
+ return
+ if not self.whitespace:
+ data = ' '
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ self.whitespace = False
+ self.indention = False
+ spaces = False
+ breaks = False
+ start = end = 0
+ while end <= len(text):
+ ch = None
+ if end < len(text):
+ ch = text[end]
+ if spaces:
+ if ch != ' ':
+ if start+1 == end and self.column > self.best_width and split:
+ self.write_indent()
+ self.whitespace = False
+ self.indention = False
+ else:
+ data = text[start:end]
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end
+ elif breaks:
+ if ch not in '\n\x85\u2028\u2029':
+ if text[start] == '\n':
+ self.write_line_break()
+ for br in text[start:end]:
+ if br == '\n':
+ self.write_line_break()
+ else:
+ self.write_line_break(br)
+ self.write_indent()
+ self.whitespace = False
+ self.indention = False
+ start = end
+ else:
+ if ch is None or ch in ' \n\x85\u2028\u2029':
+ data = text[start:end]
+ self.column += len(data)
+ if self.encoding:
+ data = data.encode(self.encoding)
+ self.stream.write(data)
+ start = end
+ if ch is not None:
+ spaces = (ch == ' ')
+ breaks = (ch in '\n\x85\u2028\u2029')
+ end += 1
diff --git a/libs/yaml/error.py b/libs/yaml/error.py
new file mode 100644
index 000000000..b796b4dc5
--- /dev/null
+++ b/libs/yaml/error.py
@@ -0,0 +1,75 @@
+
+__all__ = ['Mark', 'YAMLError', 'MarkedYAMLError']
+
+class Mark:
+
+ def __init__(self, name, index, line, column, buffer, pointer):
+ self.name = name
+ self.index = index
+ self.line = line
+ self.column = column
+ self.buffer = buffer
+ self.pointer = pointer
+
+ def get_snippet(self, indent=4, max_length=75):
+ if self.buffer is None:
+ return None
+ head = ''
+ start = self.pointer
+ while start > 0 and self.buffer[start-1] not in '\0\r\n\x85\u2028\u2029':
+ start -= 1
+ if self.pointer-start > max_length/2-1:
+ head = ' ... '
+ start += 5
+ break
+ tail = ''
+ end = self.pointer
+ while end < len(self.buffer) and self.buffer[end] not in '\0\r\n\x85\u2028\u2029':
+ end += 1
+ if end-self.pointer > max_length/2-1:
+ tail = ' ... '
+ end -= 5
+ break
+ snippet = self.buffer[start:end]
+ return ' '*indent + head + snippet + tail + '\n' \
+ + ' '*(indent+self.pointer-start+len(head)) + '^'
+
+ def __str__(self):
+ snippet = self.get_snippet()
+ where = " in \"%s\", line %d, column %d" \
+ % (self.name, self.line+1, self.column+1)
+ if snippet is not None:
+ where += ":\n"+snippet
+ return where
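+
+ # Example (illustrative only): for name '<input>', line 0, column 5
+ # and buffer 'key: @bad\n' with pointer 5, str(mark) yields:
+ # in "<input>", line 1, column 6:
+ # key: @bad
+ # ^ (caret aligned under the offending character)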
+
+class YAMLError(Exception):
+ pass
+
+class MarkedYAMLError(YAMLError):
+
+ def __init__(self, context=None, context_mark=None,
+ problem=None, problem_mark=None, note=None):
+ self.context = context
+ self.context_mark = context_mark
+ self.problem = problem
+ self.problem_mark = problem_mark
+ self.note = note
+
+ def __str__(self):
+ lines = []
+ if self.context is not None:
+ lines.append(self.context)
+ if self.context_mark is not None \
+ and (self.problem is None or self.problem_mark is None
+ or self.context_mark.name != self.problem_mark.name
+ or self.context_mark.line != self.problem_mark.line
+ or self.context_mark.column != self.problem_mark.column):
+ lines.append(str(self.context_mark))
+ if self.problem is not None:
+ lines.append(self.problem)
+ if self.problem_mark is not None:
+ lines.append(str(self.problem_mark))
+ if self.note is not None:
+ lines.append(self.note)
+ return '\n'.join(lines)
+
diff --git a/libs/yaml/events.py b/libs/yaml/events.py
new file mode 100644
index 000000000..f79ad389c
--- /dev/null
+++ b/libs/yaml/events.py
@@ -0,0 +1,86 @@
+
+# Abstract classes.
+
+class Event(object):
+ def __init__(self, start_mark=None, end_mark=None):
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ def __repr__(self):
+ attributes = [key for key in ['anchor', 'tag', 'implicit', 'value']
+ if hasattr(self, key)]
+ arguments = ', '.join(['%s=%r' % (key, getattr(self, key))
+ for key in attributes])
+ return '%s(%s)' % (self.__class__.__name__, arguments)
+
+class NodeEvent(Event):
+ def __init__(self, anchor, start_mark=None, end_mark=None):
+ self.anchor = anchor
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+
+class CollectionStartEvent(NodeEvent):
+ def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None,
+ flow_style=None):
+ self.anchor = anchor
+ self.tag = tag
+ self.implicit = implicit
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.flow_style = flow_style
+
+class CollectionEndEvent(Event):
+ pass
+
+# Implementations.
+
+class StreamStartEvent(Event):
+ def __init__(self, start_mark=None, end_mark=None, encoding=None):
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.encoding = encoding
+
+class StreamEndEvent(Event):
+ pass
+
+class DocumentStartEvent(Event):
+ def __init__(self, start_mark=None, end_mark=None,
+ explicit=None, version=None, tags=None):
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.explicit = explicit
+ self.version = version
+ self.tags = tags
+
+class DocumentEndEvent(Event):
+ def __init__(self, start_mark=None, end_mark=None,
+ explicit=None):
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.explicit = explicit
+
+class AliasEvent(NodeEvent):
+ pass
+
+class ScalarEvent(NodeEvent):
+ def __init__(self, anchor, tag, implicit, value,
+ start_mark=None, end_mark=None, style=None):
+ self.anchor = anchor
+ self.tag = tag
+ self.implicit = implicit
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.style = style
+
+class SequenceStartEvent(CollectionStartEvent):
+ pass
+
+class SequenceEndEvent(CollectionEndEvent):
+ pass
+
+class MappingStartEvent(CollectionStartEvent):
+ pass
+
+class MappingEndEvent(CollectionEndEvent):
+ pass
+
diff --git a/libs/yaml/loader.py b/libs/yaml/loader.py
new file mode 100644
index 000000000..414cb2c15
--- /dev/null
+++ b/libs/yaml/loader.py
@@ -0,0 +1,63 @@
+
+__all__ = ['BaseLoader', 'FullLoader', 'SafeLoader', 'Loader', 'UnsafeLoader']
+
+from .reader import *
+from .scanner import *
+from .parser import *
+from .composer import *
+from .constructor import *
+from .resolver import *
+
+class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, BaseResolver):
+
+ def __init__(self, stream):
+ Reader.__init__(self, stream)
+ Scanner.__init__(self)
+ Parser.__init__(self)
+ Composer.__init__(self)
+ BaseConstructor.__init__(self)
+ BaseResolver.__init__(self)
+
+class FullLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver):
+
+ def __init__(self, stream):
+ Reader.__init__(self, stream)
+ Scanner.__init__(self)
+ Parser.__init__(self)
+ Composer.__init__(self)
+ FullConstructor.__init__(self)
+ Resolver.__init__(self)
+
+class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, Resolver):
+
+ def __init__(self, stream):
+ Reader.__init__(self, stream)
+ Scanner.__init__(self)
+ Parser.__init__(self)
+ Composer.__init__(self)
+ SafeConstructor.__init__(self)
+ Resolver.__init__(self)
+
+class Loader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
+
+ def __init__(self, stream):
+ Reader.__init__(self, stream)
+ Scanner.__init__(self)
+ Parser.__init__(self)
+ Composer.__init__(self)
+ Constructor.__init__(self)
+ Resolver.__init__(self)
+
+# UnsafeLoader is the same as Loader (which is, and always was, unsafe on
+# untrusted input). Use of either Loader or UnsafeLoader should be rare, since
+# FullLoader should be able to load almost all YAML safely. Loader is left
+# intact to ensure backwards compatibility.
+class UnsafeLoader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
+
+ def __init__(self, stream):
+ Reader.__init__(self, stream)
+ Scanner.__init__(self)
+ Parser.__init__(self)
+ Composer.__init__(self)
+ Constructor.__init__(self)
+ Resolver.__init__(self)
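+
+# Typical loader selection (an illustrative sketch, assuming the
+# package-level yaml.load API):
+#   yaml.load(stream, Loader=SafeLoader)  # safe for untrusted input
+#   yaml.load(stream, Loader=FullLoader)  # full YAML, without arbitrary
+#   Python object construction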
diff --git a/libs/yaml/nodes.py b/libs/yaml/nodes.py
new file mode 100644
index 000000000..c4f070c41
--- /dev/null
+++ b/libs/yaml/nodes.py
@@ -0,0 +1,49 @@
+
+class Node(object):
+ def __init__(self, tag, value, start_mark, end_mark):
+ self.tag = tag
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ def __repr__(self):
+ value = self.value
+ #if isinstance(value, list):
+ # if len(value) == 0:
+ # value = '<empty>'
+ # elif len(value) == 1:
+ # value = '<1 item>'
+ # else:
+ # value = '<%d items>' % len(value)
+ #else:
+ # if len(value) > 75:
+ # value = repr(value[:70]+u' ... ')
+ # else:
+ # value = repr(value)
+ value = repr(value)
+ return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value)
+
+class ScalarNode(Node):
+ id = 'scalar'
+ def __init__(self, tag, value,
+ start_mark=None, end_mark=None, style=None):
+ self.tag = tag
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.style = style
+
+class CollectionNode(Node):
+ def __init__(self, tag, value,
+ start_mark=None, end_mark=None, flow_style=None):
+ self.tag = tag
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.flow_style = flow_style
+
+class SequenceNode(CollectionNode):
+ id = 'sequence'
+
+class MappingNode(CollectionNode):
+ id = 'mapping'
+
diff --git a/libs/yaml/parser.py b/libs/yaml/parser.py
new file mode 100644
index 000000000..13a5995d2
--- /dev/null
+++ b/libs/yaml/parser.py
@@ -0,0 +1,589 @@
+
+# The following YAML grammar is LL(1) and is parsed by a recursive descent
+# parser.
+#
+# stream ::= STREAM-START implicit_document? explicit_document* STREAM-END
+# implicit_document ::= block_node DOCUMENT-END*
+# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
+# block_node_or_indentless_sequence ::=
+# ALIAS
+# | properties (block_content | indentless_block_sequence)?
+# | block_content
+# | indentless_block_sequence
+# block_node ::= ALIAS
+# | properties block_content?
+# | block_content
+# flow_node ::= ALIAS
+# | properties flow_content?
+# | flow_content
+# properties ::= TAG ANCHOR? | ANCHOR TAG?
+# block_content ::= block_collection | flow_collection | SCALAR
+# flow_content ::= flow_collection | SCALAR
+# block_collection ::= block_sequence | block_mapping
+# flow_collection ::= flow_sequence | flow_mapping
+# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END
+# indentless_sequence ::= (BLOCK-ENTRY block_node?)+
+# block_mapping ::= BLOCK-MAPPING_START
+# ((KEY block_node_or_indentless_sequence?)?
+# (VALUE block_node_or_indentless_sequence?)?)*
+# BLOCK-END
+# flow_sequence ::= FLOW-SEQUENCE-START
+# (flow_sequence_entry FLOW-ENTRY)*
+# flow_sequence_entry?
+# FLOW-SEQUENCE-END
+# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
+# flow_mapping ::= FLOW-MAPPING-START
+# (flow_mapping_entry FLOW-ENTRY)*
+# flow_mapping_entry?
+# FLOW-MAPPING-END
+# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
+#
+# FIRST sets:
+#
+# stream: { STREAM-START }
+# explicit_document: { DIRECTIVE DOCUMENT-START }
+# implicit_document: FIRST(block_node)
+# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START }
+# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START }
+# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
+# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
+# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START }
+# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
+# block_sequence: { BLOCK-SEQUENCE-START }
+# block_mapping: { BLOCK-MAPPING-START }
+# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START BLOCK-ENTRY }
+# indentless_sequence: { ENTRY }
+# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
+# flow_sequence: { FLOW-SEQUENCE-START }
+# flow_mapping: { FLOW-MAPPING-START }
+# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY }
+# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY }
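+#
+# As an illustration (not part of the original module): the two-item
+# sequence "- a\n- b\n" forms an implicit document and parses into the
+# event stream
+#   StreamStartEvent, DocumentStartEvent(explicit=False),
+#   SequenceStartEvent, ScalarEvent('a'), ScalarEvent('b'),
+#   SequenceEndEvent, DocumentEndEvent(explicit=False), StreamEndEvent.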
+
+__all__ = ['Parser', 'ParserError']
+
+from .error import MarkedYAMLError
+from .tokens import *
+from .events import *
+from .scanner import *
+
+class ParserError(MarkedYAMLError):
+ pass
+
+class Parser:
+ # Since writing a recursive descent parser is a straightforward task, we
+ # do not give many comments here.
+
+ DEFAULT_TAGS = {
+ '!': '!',
+ '!!': 'tag:yaml.org,2002:',
+ }
+
+ def __init__(self):
+ self.current_event = None
+ self.yaml_version = None
+ self.tag_handles = {}
+ self.states = []
+ self.marks = []
+ self.state = self.parse_stream_start
+
+ def dispose(self):
+ # Reset the state attributes (to clear self-references)
+ self.states = []
+ self.state = None
+
+ def check_event(self, *choices):
+ # Check the type of the next event.
+ if self.current_event is None:
+ if self.state:
+ self.current_event = self.state()
+ if self.current_event is not None:
+ if not choices:
+ return True
+ for choice in choices:
+ if isinstance(self.current_event, choice):
+ return True
+ return False
+
+ def peek_event(self):
+ # Get the next event.
+ if self.current_event is None:
+ if self.state:
+ self.current_event = self.state()
+ return self.current_event
+
+ def get_event(self):
+ # Get the next event and proceed further.
+ if self.current_event is None:
+ if self.state:
+ self.current_event = self.state()
+ value = self.current_event
+ self.current_event = None
+ return value
+
+ # stream ::= STREAM-START implicit_document? explicit_document* STREAM-END
+ # implicit_document ::= block_node DOCUMENT-END*
+ # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
+
+ def parse_stream_start(self):
+
+ # Parse the stream start.
+ token = self.get_token()
+ event = StreamStartEvent(token.start_mark, token.end_mark,
+ encoding=token.encoding)
+
+ # Prepare the next state.
+ self.state = self.parse_implicit_document_start
+
+ return event
+
+ def parse_implicit_document_start(self):
+
+ # Parse an implicit document.
+ if not self.check_token(DirectiveToken, DocumentStartToken,
+ StreamEndToken):
+ self.tag_handles = self.DEFAULT_TAGS
+ token = self.peek_token()
+ start_mark = end_mark = token.start_mark
+ event = DocumentStartEvent(start_mark, end_mark,
+ explicit=False)
+
+ # Prepare the next state.
+ self.states.append(self.parse_document_end)
+ self.state = self.parse_block_node
+
+ return event
+
+ else:
+ return self.parse_document_start()
+
+ def parse_document_start(self):
+
+ # Parse any extra document end indicators.
+ while self.check_token(DocumentEndToken):
+ self.get_token()
+
+ # Parse an explicit document.
+ if not self.check_token(StreamEndToken):
+ token = self.peek_token()
+ start_mark = token.start_mark
+ version, tags = self.process_directives()
+ if not self.check_token(DocumentStartToken):
+ raise ParserError(None, None,
+ "expected '<document start>', but found %r"
+ % self.peek_token().id,
+ self.peek_token().start_mark)
+ token = self.get_token()
+ end_mark = token.end_mark
+ event = DocumentStartEvent(start_mark, end_mark,
+ explicit=True, version=version, tags=tags)
+ self.states.append(self.parse_document_end)
+ self.state = self.parse_document_content
+ else:
+ # Parse the end of the stream.
+ token = self.get_token()
+ event = StreamEndEvent(token.start_mark, token.end_mark)
+ assert not self.states
+ assert not self.marks
+ self.state = None
+ return event
+
+ def parse_document_end(self):
+
+ # Parse the document end.
+ token = self.peek_token()
+ start_mark = end_mark = token.start_mark
+ explicit = False
+ if self.check_token(DocumentEndToken):
+ token = self.get_token()
+ end_mark = token.end_mark
+ explicit = True
+ event = DocumentEndEvent(start_mark, end_mark,
+ explicit=explicit)
+
+ # Prepare the next state.
+ self.state = self.parse_document_start
+
+ return event
+
+ def parse_document_content(self):
+ if self.check_token(DirectiveToken,
+ DocumentStartToken, DocumentEndToken, StreamEndToken):
+ event = self.process_empty_scalar(self.peek_token().start_mark)
+ self.state = self.states.pop()
+ return event
+ else:
+ return self.parse_block_node()
+
+ def process_directives(self):
+ self.yaml_version = None
+ self.tag_handles = {}
+ while self.check_token(DirectiveToken):
+ token = self.get_token()
+ if token.name == 'YAML':
+ if self.yaml_version is not None:
+ raise ParserError(None, None,
+ "found duplicate YAML directive", token.start_mark)
+ major, minor = token.value
+ if major != 1:
+ raise ParserError(None, None,
+ "found incompatible YAML document (version 1.* is required)",
+ token.start_mark)
+ self.yaml_version = token.value
+ elif token.name == 'TAG':
+ handle, prefix = token.value
+ if handle in self.tag_handles:
+ raise ParserError(None, None,
+ "duplicate tag handle %r" % handle,
+ token.start_mark)
+ self.tag_handles[handle] = prefix
+ if self.tag_handles:
+ value = self.yaml_version, self.tag_handles.copy()
+ else:
+ value = self.yaml_version, None
+ for key in self.DEFAULT_TAGS:
+ if key not in self.tag_handles:
+ self.tag_handles[key] = self.DEFAULT_TAGS[key]
+ return value
+
+ # block_node_or_indentless_sequence ::= ALIAS
+ # | properties (block_content | indentless_block_sequence)?
+ # | block_content
+ # | indentless_block_sequence
+ # block_node ::= ALIAS
+ # | properties block_content?
+ # | block_content
+ # flow_node ::= ALIAS
+ # | properties flow_content?
+ # | flow_content
+ # properties ::= TAG ANCHOR? | ANCHOR TAG?
+ # block_content ::= block_collection | flow_collection | SCALAR
+ # flow_content ::= flow_collection | SCALAR
+ # block_collection ::= block_sequence | block_mapping
+ # flow_collection ::= flow_sequence | flow_mapping
+
+ def parse_block_node(self):
+ return self.parse_node(block=True)
+
+ def parse_flow_node(self):
+ return self.parse_node()
+
+ def parse_block_node_or_indentless_sequence(self):
+ return self.parse_node(block=True, indentless_sequence=True)
+
+ def parse_node(self, block=False, indentless_sequence=False):
+ if self.check_token(AliasToken):
+ token = self.get_token()
+ event = AliasEvent(token.value, token.start_mark, token.end_mark)
+ self.state = self.states.pop()
+ else:
+ anchor = None
+ tag = None
+ start_mark = end_mark = tag_mark = None
+ if self.check_token(AnchorToken):
+ token = self.get_token()
+ start_mark = token.start_mark
+ end_mark = token.end_mark
+ anchor = token.value
+ if self.check_token(TagToken):
+ token = self.get_token()
+ tag_mark = token.start_mark
+ end_mark = token.end_mark
+ tag = token.value
+ elif self.check_token(TagToken):
+ token = self.get_token()
+ start_mark = tag_mark = token.start_mark
+ end_mark = token.end_mark
+ tag = token.value
+ if self.check_token(AnchorToken):
+ token = self.get_token()
+ end_mark = token.end_mark
+ anchor = token.value
+ if tag is not None:
+ handle, suffix = tag
+ if handle is not None:
+ if handle not in self.tag_handles:
+ raise ParserError("while parsing a node", start_mark,
+ "found undefined tag handle %r" % handle,
+ tag_mark)
+ tag = self.tag_handles[handle]+suffix
+ else:
+ tag = suffix
+ #if tag == '!':
+ # raise ParserError("while parsing a node", start_mark,
+ # "found non-specific tag '!'", tag_mark,
+ # "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' and share your opinion.")
+ if start_mark is None:
+ start_mark = end_mark = self.peek_token().start_mark
+ event = None
+ implicit = (tag is None or tag == '!')
+ if indentless_sequence and self.check_token(BlockEntryToken):
+ end_mark = self.peek_token().end_mark
+ event = SequenceStartEvent(anchor, tag, implicit,
+ start_mark, end_mark)
+ self.state = self.parse_indentless_sequence_entry
+ else:
+ if self.check_token(ScalarToken):
+ token = self.get_token()
+ end_mark = token.end_mark
+ if (token.plain and tag is None) or tag == '!':
+ implicit = (True, False)
+ elif tag is None:
+ implicit = (False, True)
+ else:
+ implicit = (False, False)
+ event = ScalarEvent(anchor, tag, implicit, token.value,
+ start_mark, end_mark, style=token.style)
+ self.state = self.states.pop()
+ elif self.check_token(FlowSequenceStartToken):
+ end_mark = self.peek_token().end_mark
+ event = SequenceStartEvent(anchor, tag, implicit,
+ start_mark, end_mark, flow_style=True)
+ self.state = self.parse_flow_sequence_first_entry
+ elif self.check_token(FlowMappingStartToken):
+ end_mark = self.peek_token().end_mark
+ event = MappingStartEvent(anchor, tag, implicit,
+ start_mark, end_mark, flow_style=True)
+ self.state = self.parse_flow_mapping_first_key
+ elif block and self.check_token(BlockSequenceStartToken):
+ end_mark = self.peek_token().start_mark
+ event = SequenceStartEvent(anchor, tag, implicit,
+ start_mark, end_mark, flow_style=False)
+ self.state = self.parse_block_sequence_first_entry
+ elif block and self.check_token(BlockMappingStartToken):
+ end_mark = self.peek_token().start_mark
+ event = MappingStartEvent(anchor, tag, implicit,
+ start_mark, end_mark, flow_style=False)
+ self.state = self.parse_block_mapping_first_key
+ elif anchor is not None or tag is not None:
+ # Empty scalars are allowed even if a tag or an anchor is
+ # specified.
+ event = ScalarEvent(anchor, tag, (implicit, False), '',
+ start_mark, end_mark)
+ self.state = self.states.pop()
+ else:
+ if block:
+ node = 'block'
+ else:
+ node = 'flow'
+ token = self.peek_token()
+ raise ParserError("while parsing a %s node" % node, start_mark,
+ "expected the node content, but found %r" % token.id,
+ token.start_mark)
+ return event
+
+ # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END
+
+ def parse_block_sequence_first_entry(self):
+ token = self.get_token()
+ self.marks.append(token.start_mark)
+ return self.parse_block_sequence_entry()
+
+ def parse_block_sequence_entry(self):
+ if self.check_token(BlockEntryToken):
+ token = self.get_token()
+ if not self.check_token(BlockEntryToken, BlockEndToken):
+ self.states.append(self.parse_block_sequence_entry)
+ return self.parse_block_node()
+ else:
+ self.state = self.parse_block_sequence_entry
+ return self.process_empty_scalar(token.end_mark)
+ if not self.check_token(BlockEndToken):
+ token = self.peek_token()
+ raise ParserError("while parsing a block collection", self.marks[-1],
+ "expected <block end>, but found %r" % token.id, token.start_mark)
+ token = self.get_token()
+ event = SequenceEndEvent(token.start_mark, token.end_mark)
+ self.state = self.states.pop()
+ self.marks.pop()
+ return event
+
+ # indentless_sequence ::= (BLOCK-ENTRY block_node?)+
+
+ def parse_indentless_sequence_entry(self):
+ if self.check_token(BlockEntryToken):
+ token = self.get_token()
+ if not self.check_token(BlockEntryToken,
+ KeyToken, ValueToken, BlockEndToken):
+ self.states.append(self.parse_indentless_sequence_entry)
+ return self.parse_block_node()
+ else:
+ self.state = self.parse_indentless_sequence_entry
+ return self.process_empty_scalar(token.end_mark)
+ token = self.peek_token()
+ event = SequenceEndEvent(token.start_mark, token.start_mark)
+ self.state = self.states.pop()
+ return event
+
+ # block_mapping ::= BLOCK-MAPPING_START
+ # ((KEY block_node_or_indentless_sequence?)?
+ # (VALUE block_node_or_indentless_sequence?)?)*
+ # BLOCK-END
+
+ def parse_block_mapping_first_key(self):
+ token = self.get_token()
+ self.marks.append(token.start_mark)
+ return self.parse_block_mapping_key()
+
+ def parse_block_mapping_key(self):
+ if self.check_token(KeyToken):
+ token = self.get_token()
+ if not self.check_token(KeyToken, ValueToken, BlockEndToken):
+ self.states.append(self.parse_block_mapping_value)
+ return self.parse_block_node_or_indentless_sequence()
+ else:
+ self.state = self.parse_block_mapping_value
+ return self.process_empty_scalar(token.end_mark)
+ if not self.check_token(BlockEndToken):
+ token = self.peek_token()
+ raise ParserError("while parsing a block mapping", self.marks[-1],
+ "expected <block end>, but found %r" % token.id, token.start_mark)
+ token = self.get_token()
+ event = MappingEndEvent(token.start_mark, token.end_mark)
+ self.state = self.states.pop()
+ self.marks.pop()
+ return event
+
+ def parse_block_mapping_value(self):
+ if self.check_token(ValueToken):
+ token = self.get_token()
+ if not self.check_token(KeyToken, ValueToken, BlockEndToken):
+ self.states.append(self.parse_block_mapping_key)
+ return self.parse_block_node_or_indentless_sequence()
+ else:
+ self.state = self.parse_block_mapping_key
+ return self.process_empty_scalar(token.end_mark)
+ else:
+ self.state = self.parse_block_mapping_key
+ token = self.peek_token()
+ return self.process_empty_scalar(token.start_mark)
+
+ # flow_sequence ::= FLOW-SEQUENCE-START
+ # (flow_sequence_entry FLOW-ENTRY)*
+ # flow_sequence_entry?
+ # FLOW-SEQUENCE-END
+ # flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
+ #
+ # Note that although the production rules for flow_sequence_entry and
+ # flow_mapping_entry are identical, their interpretations differ.
+ # For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?`
+ # generates an inline mapping (set syntax).
+
+ def parse_flow_sequence_first_entry(self):
+ token = self.get_token()
+ self.marks.append(token.start_mark)
+ return self.parse_flow_sequence_entry(first=True)
+
+ def parse_flow_sequence_entry(self, first=False):
+ if not self.check_token(FlowSequenceEndToken):
+ if not first:
+ if self.check_token(FlowEntryToken):
+ self.get_token()
+ else:
+ token = self.peek_token()
+ raise ParserError("while parsing a flow sequence", self.marks[-1],
+ "expected ',' or ']', but got %r" % token.id, token.start_mark)
+
+ if self.check_token(KeyToken):
+ token = self.peek_token()
+ event = MappingStartEvent(None, None, True,
+ token.start_mark, token.end_mark,
+ flow_style=True)
+ self.state = self.parse_flow_sequence_entry_mapping_key
+ return event
+ elif not self.check_token(FlowSequenceEndToken):
+ self.states.append(self.parse_flow_sequence_entry)
+ return self.parse_flow_node()
+ token = self.get_token()
+ event = SequenceEndEvent(token.start_mark, token.end_mark)
+ self.state = self.states.pop()
+ self.marks.pop()
+ return event
+
+ def parse_flow_sequence_entry_mapping_key(self):
+ token = self.get_token()
+ if not self.check_token(ValueToken,
+ FlowEntryToken, FlowSequenceEndToken):
+ self.states.append(self.parse_flow_sequence_entry_mapping_value)
+ return self.parse_flow_node()
+ else:
+ self.state = self.parse_flow_sequence_entry_mapping_value
+ return self.process_empty_scalar(token.end_mark)
+
+ def parse_flow_sequence_entry_mapping_value(self):
+ if self.check_token(ValueToken):
+ token = self.get_token()
+ if not self.check_token(FlowEntryToken, FlowSequenceEndToken):
+ self.states.append(self.parse_flow_sequence_entry_mapping_end)
+ return self.parse_flow_node()
+ else:
+ self.state = self.parse_flow_sequence_entry_mapping_end
+ return self.process_empty_scalar(token.end_mark)
+ else:
+ self.state = self.parse_flow_sequence_entry_mapping_end
+ token = self.peek_token()
+ return self.process_empty_scalar(token.start_mark)
+
+ def parse_flow_sequence_entry_mapping_end(self):
+ self.state = self.parse_flow_sequence_entry
+ token = self.peek_token()
+ return MappingEndEvent(token.start_mark, token.start_mark)
+
+ # flow_mapping ::= FLOW-MAPPING-START
+ # (flow_mapping_entry FLOW-ENTRY)*
+ # flow_mapping_entry?
+ # FLOW-MAPPING-END
+ # flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
+
+ def parse_flow_mapping_first_key(self):
+ token = self.get_token()
+ self.marks.append(token.start_mark)
+ return self.parse_flow_mapping_key(first=True)
+
+ def parse_flow_mapping_key(self, first=False):
+ if not self.check_token(FlowMappingEndToken):
+ if not first:
+ if self.check_token(FlowEntryToken):
+ self.get_token()
+ else:
+ token = self.peek_token()
+ raise ParserError("while parsing a flow mapping", self.marks[-1],
+ "expected ',' or '}', but got %r" % token.id, token.start_mark)
+ if self.check_token(KeyToken):
+ token = self.get_token()
+ if not self.check_token(ValueToken,
+ FlowEntryToken, FlowMappingEndToken):
+ self.states.append(self.parse_flow_mapping_value)
+ return self.parse_flow_node()
+ else:
+ self.state = self.parse_flow_mapping_value
+ return self.process_empty_scalar(token.end_mark)
+ elif not self.check_token(FlowMappingEndToken):
+ self.states.append(self.parse_flow_mapping_empty_value)
+ return self.parse_flow_node()
+ token = self.get_token()
+ event = MappingEndEvent(token.start_mark, token.end_mark)
+ self.state = self.states.pop()
+ self.marks.pop()
+ return event
+
+ def parse_flow_mapping_value(self):
+ if self.check_token(ValueToken):
+ token = self.get_token()
+ if not self.check_token(FlowEntryToken, FlowMappingEndToken):
+ self.states.append(self.parse_flow_mapping_key)
+ return self.parse_flow_node()
+ else:
+ self.state = self.parse_flow_mapping_key
+ return self.process_empty_scalar(token.end_mark)
+ else:
+ self.state = self.parse_flow_mapping_key
+ token = self.peek_token()
+ return self.process_empty_scalar(token.start_mark)
+
+ def parse_flow_mapping_empty_value(self):
+ self.state = self.parse_flow_mapping_key
+ return self.process_empty_scalar(self.peek_token().start_mark)
+
+ def process_empty_scalar(self, mark):
+ return ScalarEvent(None, None, (True, False), '', mark, mark)
+
diff --git a/libs/yaml/reader.py b/libs/yaml/reader.py
new file mode 100644
index 000000000..774b0219b
--- /dev/null
+++ b/libs/yaml/reader.py
@@ -0,0 +1,185 @@
+# This module contains abstractions for the input stream. You don't have to
+# look further; there is no pretty code here.
+#
+# We define two classes here.
+#
+# Mark(source, line, column)
+# It's just a record and its only use is producing nice error messages.
+# Parser does not use it for any other purposes.
+#
+# Reader(source, data)
+# Reader determines the encoding of `data` and converts it to unicode.
+# Reader provides the following methods and attributes:
+# reader.peek(index=0) - return the character at `index` positions ahead.
+# reader.prefix(length=1) - return the next `length` characters.
+# reader.forward(length=1) - move the current position `length` characters forward.
+# reader.index - the number of the current character.
+# reader.line, reader.column - the line and the column of the current character.
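+#
+# A minimal usage sketch (illustrative only):
+#   reader = Reader('a: 1\n')
+#   reader.peek()    # 'a'
+#   reader.prefix(4) # 'a: 1'
+#   reader.forward(3)
+#   reader.peek()    # '1'; reader.column is now 3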
+
+__all__ = ['Reader', 'ReaderError']
+
+from .error import YAMLError, Mark
+
+import codecs, re
+
+class ReaderError(YAMLError):
+
+ def __init__(self, name, position, character, encoding, reason):
+ self.name = name
+ self.character = character
+ self.position = position
+ self.encoding = encoding
+ self.reason = reason
+
+ def __str__(self):
+ if isinstance(self.character, bytes):
+ return "'%s' codec can't decode byte #x%02x: %s\n" \
+ " in \"%s\", position %d" \
+ % (self.encoding, ord(self.character), self.reason,
+ self.name, self.position)
+ else:
+ return "unacceptable character #x%04x: %s\n" \
+ " in \"%s\", position %d" \
+ % (self.character, self.reason,
+ self.name, self.position)
+
+class Reader(object):
+ # Reader:
+ # - determines the data encoding and converts it to a unicode string,
+ # - checks if characters are in allowed range,
+ # - adds '\0' to the end.
+
+ # Reader accepts
+ # - a `bytes` object,
+ # - a `str` object,
+ # - a file-like object with its `read` method returning `str`,
+ # - a file-like object with its `read` method returning `bytes`.
+
+ # Yeah, it's ugly and slow.
+
+ def __init__(self, stream):
+ self.name = None
+ self.stream = None
+ self.stream_pointer = 0
+ self.eof = True
+ self.buffer = ''
+ self.pointer = 0
+ self.raw_buffer = None
+ self.raw_decode = None
+ self.encoding = None
+ self.index = 0
+ self.line = 0
+ self.column = 0
+ if isinstance(stream, str):
+ self.name = "<unicode string>"
+ self.check_printable(stream)
+ self.buffer = stream+'\0'
+ elif isinstance(stream, bytes):
+ self.name = "<byte string>"
+ self.raw_buffer = stream
+ self.determine_encoding()
+ else:
+ self.stream = stream
+ self.name = getattr(stream, 'name', "<file>")
+ self.eof = False
+ self.raw_buffer = None
+ self.determine_encoding()
+
+ def peek(self, index=0):
+ try:
+ return self.buffer[self.pointer+index]
+ except IndexError:
+ self.update(index+1)
+ return self.buffer[self.pointer+index]
+
+ def prefix(self, length=1):
+ if self.pointer+length >= len(self.buffer):
+ self.update(length)
+ return self.buffer[self.pointer:self.pointer+length]
+
+ def forward(self, length=1):
+ if self.pointer+length+1 >= len(self.buffer):
+ self.update(length+1)
+ while length:
+ ch = self.buffer[self.pointer]
+ self.pointer += 1
+ self.index += 1
+ if ch in '\n\x85\u2028\u2029' \
+ or (ch == '\r' and self.buffer[self.pointer] != '\n'):
+ self.line += 1
+ self.column = 0
+ elif ch != '\uFEFF':
+ self.column += 1
+ length -= 1
+
+ def get_mark(self):
+ if self.stream is None:
+ return Mark(self.name, self.index, self.line, self.column,
+ self.buffer, self.pointer)
+ else:
+ return Mark(self.name, self.index, self.line, self.column,
+ None, None)
+
+ def determine_encoding(self):
+ while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2):
+ self.update_raw()
+ if isinstance(self.raw_buffer, bytes):
+ if self.raw_buffer.startswith(codecs.BOM_UTF16_LE):
+ self.raw_decode = codecs.utf_16_le_decode
+ self.encoding = 'utf-16-le'
+ elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE):
+ self.raw_decode = codecs.utf_16_be_decode
+ self.encoding = 'utf-16-be'
+ else:
+ self.raw_decode = codecs.utf_8_decode
+ self.encoding = 'utf-8'
+ self.update(1)
+
+ NON_PRINTABLE = re.compile('[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]')
+ def check_printable(self, data):
+ match = self.NON_PRINTABLE.search(data)
+ if match:
+ character = match.group()
+ position = self.index+(len(self.buffer)-self.pointer)+match.start()
+ raise ReaderError(self.name, position, ord(character),
+ 'unicode', "special characters are not allowed")
+
+ def update(self, length):
+ if self.raw_buffer is None:
+ return
+ self.buffer = self.buffer[self.pointer:]
+ self.pointer = 0
+ while len(self.buffer) < length:
+ if not self.eof:
+ self.update_raw()
+ if self.raw_decode is not None:
+ try:
+ data, converted = self.raw_decode(self.raw_buffer,
+ 'strict', self.eof)
+ except UnicodeDecodeError as exc:
+ character = self.raw_buffer[exc.start]
+ if self.stream is not None:
+ position = self.stream_pointer-len(self.raw_buffer)+exc.start
+ else:
+ position = exc.start
+ raise ReaderError(self.name, position, character,
+ exc.encoding, exc.reason)
+ else:
+ data = self.raw_buffer
+ converted = len(data)
+ self.check_printable(data)
+ self.buffer += data
+ self.raw_buffer = self.raw_buffer[converted:]
+ if self.eof:
+ self.buffer += '\0'
+ self.raw_buffer = None
+ break
+
+ def update_raw(self, size=4096):
+ data = self.stream.read(size)
+ if self.raw_buffer is None:
+ self.raw_buffer = data
+ else:
+ self.raw_buffer += data
+ self.stream_pointer += len(data)
+ if not data:
+ self.eof = True
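+
+# Note: update_raw() reads in 4096-byte chunks; an empty read() flags EOF,
+# after which update() appends the terminating '\0' sentinel.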
diff --git a/libs/yaml/representer.py b/libs/yaml/representer.py
new file mode 100644
index 000000000..dd144017b
--- /dev/null
+++ b/libs/yaml/representer.py
@@ -0,0 +1,389 @@
+
+__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer',
+ 'RepresenterError']
+
+from .error import *
+from .nodes import *
+
+import datetime, sys, copyreg, types, base64, collections
+
+class RepresenterError(YAMLError):
+ pass
+
+class BaseRepresenter:
+
+ yaml_representers = {}
+ yaml_multi_representers = {}
+
+ def __init__(self, default_style=None, default_flow_style=False, sort_keys=True):
+ self.default_style = default_style
+ self.sort_keys = sort_keys
+ self.default_flow_style = default_flow_style
+ self.represented_objects = {}
+ self.object_keeper = []
+ self.alias_key = None
+
+ def represent(self, data):
+ node = self.represent_data(data)
+ self.serialize(node)
+ self.represented_objects = {}
+ self.object_keeper = []
+ self.alias_key = None
+
+ def represent_data(self, data):
+ if self.ignore_aliases(data):
+ self.alias_key = None
+ else:
+ self.alias_key = id(data)
+ if self.alias_key is not None:
+ if self.alias_key in self.represented_objects:
+ node = self.represented_objects[self.alias_key]
+ #if node is None:
+ # raise RepresenterError("recursive objects are not allowed: %r" % data)
+ return node
+ #self.represented_objects[alias_key] = None
+ self.object_keeper.append(data)
+ data_types = type(data).__mro__
+ if data_types[0] in self.yaml_representers:
+ node = self.yaml_representers[data_types[0]](self, data)
+ else:
+ for data_type in data_types:
+ if data_type in self.yaml_multi_representers:
+ node = self.yaml_multi_representers[data_type](self, data)
+ break
+ else:
+ if None in self.yaml_multi_representers:
+ node = self.yaml_multi_representers[None](self, data)
+ elif None in self.yaml_representers:
+ node = self.yaml_representers[None](self, data)
+ else:
+ node = ScalarNode(None, str(data))
+ #if alias_key is not None:
+ # self.represented_objects[alias_key] = node
+ return node
+
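+    # Dispatch sketch (illustrative): the exact type is tried first against
+    # yaml_representers, then each class in the MRO against the multi-
+    # representers, then the None fallbacks. With the registrations below:
+    #   represent_data(True)      # bool hit directly, never falls to int
+    #   represent_data(object())  # caught by Representer's multi-representer
+    #                             # for `object` (see end of this module)
+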
+ @classmethod
+ def add_representer(cls, data_type, representer):
+ if not 'yaml_representers' in cls.__dict__:
+ cls.yaml_representers = cls.yaml_representers.copy()
+ cls.yaml_representers[data_type] = representer
+
+ @classmethod
+ def add_multi_representer(cls, data_type, representer):
+ if not 'yaml_multi_representers' in cls.__dict__:
+ cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
+ cls.yaml_multi_representers[data_type] = representer
+
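+    # Both registries are copy-on-write (note the __dict__ checks above), so
+    # registering on a subclass never mutates the parent class. A sketch with
+    # a hypothetical subclass:
+    #   class MyRepresenter(SafeRepresenter):
+    #       pass
+    #   MyRepresenter.add_representer(frozenset,
+    #                                 SafeRepresenter.represent_set)
+    #   # SafeRepresenter.yaml_representers is left untouched.
+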
+ def represent_scalar(self, tag, value, style=None):
+ if style is None:
+ style = self.default_style
+ node = ScalarNode(tag, value, style=style)
+ if self.alias_key is not None:
+ self.represented_objects[self.alias_key] = node
+ return node
+
+ def represent_sequence(self, tag, sequence, flow_style=None):
+ value = []
+ node = SequenceNode(tag, value, flow_style=flow_style)
+ if self.alias_key is not None:
+ self.represented_objects[self.alias_key] = node
+ best_style = True
+ for item in sequence:
+ node_item = self.represent_data(item)
+ if not (isinstance(node_item, ScalarNode) and not node_item.style):
+ best_style = False
+ value.append(node_item)
+ if flow_style is None:
+ if self.default_flow_style is not None:
+ node.flow_style = self.default_flow_style
+ else:
+ node.flow_style = best_style
+ return node
+
+ def represent_mapping(self, tag, mapping, flow_style=None):
+ value = []
+ node = MappingNode(tag, value, flow_style=flow_style)
+ if self.alias_key is not None:
+ self.represented_objects[self.alias_key] = node
+ best_style = True
+ if hasattr(mapping, 'items'):
+ mapping = list(mapping.items())
+ if self.sort_keys:
+ try:
+ mapping = sorted(mapping)
+ except TypeError:
+ pass
+ for item_key, item_value in mapping:
+ node_key = self.represent_data(item_key)
+ node_value = self.represent_data(item_value)
+ if not (isinstance(node_key, ScalarNode) and not node_key.style):
+ best_style = False
+ if not (isinstance(node_value, ScalarNode) and not node_value.style):
+ best_style = False
+ value.append((node_key, node_value))
+ if flow_style is None:
+ if self.default_flow_style is not None:
+ node.flow_style = self.default_flow_style
+ else:
+ node.flow_style = best_style
+ return node
+
+ def ignore_aliases(self, data):
+ return False
+
+class SafeRepresenter(BaseRepresenter):
+
+ def ignore_aliases(self, data):
+ if data is None:
+ return True
+ if isinstance(data, tuple) and data == ():
+ return True
+ if isinstance(data, (str, bytes, bool, int, float)):
+ return True
+
+ def represent_none(self, data):
+ return self.represent_scalar('tag:yaml.org,2002:null', 'null')
+
+ def represent_str(self, data):
+ return self.represent_scalar('tag:yaml.org,2002:str', data)
+
+ def represent_binary(self, data):
+ if hasattr(base64, 'encodebytes'):
+ data = base64.encodebytes(data).decode('ascii')
+ else:
+ data = base64.encodestring(data).decode('ascii')
+ return self.represent_scalar('tag:yaml.org,2002:binary', data, style='|')
+
+ def represent_bool(self, data):
+ if data:
+ value = 'true'
+ else:
+ value = 'false'
+ return self.represent_scalar('tag:yaml.org,2002:bool', value)
+
+ def represent_int(self, data):
+ return self.represent_scalar('tag:yaml.org,2002:int', str(data))
+
+ inf_value = 1e300
+ while repr(inf_value) != repr(inf_value*inf_value):
+ inf_value *= inf_value
+
+ def represent_float(self, data):
+ if data != data or (data == 0.0 and data == 1.0):
+ value = '.nan'
+ elif data == self.inf_value:
+ value = '.inf'
+ elif data == -self.inf_value:
+ value = '-.inf'
+ else:
+ value = repr(data).lower()
+ # Note that in some cases `repr(data)` represents a float number
+        # without the decimal part. For instance:
+ # >>> repr(1e17)
+ # '1e17'
+ # Unfortunately, this is not a valid float representation according
+ # to the definition of the `!!float` tag. We fix this by adding
+ # '.0' before the 'e' symbol.
+ if '.' not in value and 'e' in value:
+ value = value.replace('e', '.0e', 1)
+ return self.represent_scalar('tag:yaml.org,2002:float', value)
+
+ def represent_list(self, data):
+ #pairs = (len(data) > 0 and isinstance(data, list))
+ #if pairs:
+ # for item in data:
+ # if not isinstance(item, tuple) or len(item) != 2:
+ # pairs = False
+ # break
+ #if not pairs:
+ return self.represent_sequence('tag:yaml.org,2002:seq', data)
+ #value = []
+ #for item_key, item_value in data:
+ # value.append(self.represent_mapping(u'tag:yaml.org,2002:map',
+ # [(item_key, item_value)]))
+ #return SequenceNode(u'tag:yaml.org,2002:pairs', value)
+
+ def represent_dict(self, data):
+ return self.represent_mapping('tag:yaml.org,2002:map', data)
+
+ def represent_set(self, data):
+ value = {}
+ for key in data:
+ value[key] = None
+ return self.represent_mapping('tag:yaml.org,2002:set', value)
+
+ def represent_date(self, data):
+ value = data.isoformat()
+ return self.represent_scalar('tag:yaml.org,2002:timestamp', value)
+
+ def represent_datetime(self, data):
+ value = data.isoformat(' ')
+ return self.represent_scalar('tag:yaml.org,2002:timestamp', value)
+
+ def represent_yaml_object(self, tag, data, cls, flow_style=None):
+ if hasattr(data, '__getstate__'):
+ state = data.__getstate__()
+ else:
+ state = data.__dict__.copy()
+ return self.represent_mapping(tag, state, flow_style=flow_style)
+
+ def represent_undefined(self, data):
+ raise RepresenterError("cannot represent an object", data)
+
+SafeRepresenter.add_representer(type(None),
+ SafeRepresenter.represent_none)
+
+SafeRepresenter.add_representer(str,
+ SafeRepresenter.represent_str)
+
+SafeRepresenter.add_representer(bytes,
+ SafeRepresenter.represent_binary)
+
+SafeRepresenter.add_representer(bool,
+ SafeRepresenter.represent_bool)
+
+SafeRepresenter.add_representer(int,
+ SafeRepresenter.represent_int)
+
+SafeRepresenter.add_representer(float,
+ SafeRepresenter.represent_float)
+
+SafeRepresenter.add_representer(list,
+ SafeRepresenter.represent_list)
+
+SafeRepresenter.add_representer(tuple,
+ SafeRepresenter.represent_list)
+
+SafeRepresenter.add_representer(dict,
+ SafeRepresenter.represent_dict)
+
+SafeRepresenter.add_representer(set,
+ SafeRepresenter.represent_set)
+
+SafeRepresenter.add_representer(datetime.date,
+ SafeRepresenter.represent_date)
+
+SafeRepresenter.add_representer(datetime.datetime,
+ SafeRepresenter.represent_datetime)
+
+SafeRepresenter.add_representer(None,
+ SafeRepresenter.represent_undefined)
+
+class Representer(SafeRepresenter):
+
+ def represent_complex(self, data):
+ if data.imag == 0.0:
+ data = '%r' % data.real
+ elif data.real == 0.0:
+ data = '%rj' % data.imag
+ elif data.imag > 0:
+ data = '%r+%rj' % (data.real, data.imag)
+ else:
+ data = '%r%rj' % (data.real, data.imag)
+ return self.represent_scalar('tag:yaml.org,2002:python/complex', data)
+
+ def represent_tuple(self, data):
+ return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
+
+ def represent_name(self, data):
+ name = '%s.%s' % (data.__module__, data.__name__)
+ return self.represent_scalar('tag:yaml.org,2002:python/name:'+name, '')
+
+ def represent_module(self, data):
+ return self.represent_scalar(
+ 'tag:yaml.org,2002:python/module:'+data.__name__, '')
+
+ def represent_object(self, data):
+        # We use the __reduce__ API to save the data. data.__reduce__
+        # returns a tuple of length 2-5:
+        #   (function, args, state, listitems, dictitems)
+
+        # For reconstructing, we call function(*args), then set its state,
+        # listitems, and dictitems if they are not None.
+
+ # A special case is when function.__name__ == '__newobj__'. In this
+ # case we create the object with args[0].__new__(*args).
+
+ # Another special case is when __reduce__ returns a string - we don't
+ # support it.
+
+ # We produce a !!python/object, !!python/object/new or
+ # !!python/object/apply node.
+
+ cls = type(data)
+ if cls in copyreg.dispatch_table:
+ reduce = copyreg.dispatch_table[cls](data)
+ elif hasattr(data, '__reduce_ex__'):
+ reduce = data.__reduce_ex__(2)
+ elif hasattr(data, '__reduce__'):
+ reduce = data.__reduce__()
+ else:
+ raise RepresenterError("cannot represent an object", data)
+ reduce = (list(reduce)+[None]*5)[:5]
+ function, args, state, listitems, dictitems = reduce
+ args = list(args)
+ if state is None:
+ state = {}
+ if listitems is not None:
+ listitems = list(listitems)
+ if dictitems is not None:
+ dictitems = dict(dictitems)
+ if function.__name__ == '__newobj__':
+ function = args[0]
+ args = args[1:]
+ tag = 'tag:yaml.org,2002:python/object/new:'
+ newobj = True
+ else:
+ tag = 'tag:yaml.org,2002:python/object/apply:'
+ newobj = False
+ function_name = '%s.%s' % (function.__module__, function.__name__)
+ if not args and not listitems and not dictitems \
+ and isinstance(state, dict) and newobj:
+ return self.represent_mapping(
+ 'tag:yaml.org,2002:python/object:'+function_name, state)
+ if not listitems and not dictitems \
+ and isinstance(state, dict) and not state:
+ return self.represent_sequence(tag+function_name, args)
+ value = {}
+ if args:
+ value['args'] = args
+ if state or not isinstance(state, dict):
+ value['state'] = state
+ if listitems:
+ value['listitems'] = listitems
+ if dictitems:
+ value['dictitems'] = dictitems
+ return self.represent_mapping(tag+function_name, value)
+
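+    # Reduction sketch (illustrative, with a hypothetical class):
+    #   class Point:
+    #       def __init__(self, x, y):
+    #           self.x, self.y = x, y
+    # Point(1, 2).__reduce_ex__(2) takes the '__newobj__' branch above with
+    # no remaining args, no list/dict items and a plain dict state, so the
+    # result is simply: !!python/object:__main__.Point {x: 1, y: 2}
+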
+ def represent_ordered_dict(self, data):
+ # Provide uniform representation across different Python versions.
+ data_type = type(data)
+ tag = 'tag:yaml.org,2002:python/object/apply:%s.%s' \
+ % (data_type.__module__, data_type.__name__)
+ items = [[key, value] for key, value in data.items()]
+ return self.represent_sequence(tag, [items])
+
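+# E.g. collections.OrderedDict([('a', 1)]) is emitted as a
+# !!python/object/apply:collections.OrderedDict node whose single argument
+# is the item list [['a', 1]]; applying the constructor to it on load
+# rebuilds the ordered dict.
+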
+Representer.add_representer(complex,
+ Representer.represent_complex)
+
+Representer.add_representer(tuple,
+ Representer.represent_tuple)
+
+Representer.add_representer(type,
+ Representer.represent_name)
+
+Representer.add_representer(collections.OrderedDict,
+ Representer.represent_ordered_dict)
+
+Representer.add_representer(types.FunctionType,
+ Representer.represent_name)
+
+Representer.add_representer(types.BuiltinFunctionType,
+ Representer.represent_name)
+
+Representer.add_representer(types.ModuleType,
+ Representer.represent_module)
+
+Representer.add_multi_representer(object,
+ Representer.represent_object)
+
diff --git a/libs/yaml/resolver.py b/libs/yaml/resolver.py
new file mode 100644
index 000000000..02b82e73e
--- /dev/null
+++ b/libs/yaml/resolver.py
@@ -0,0 +1,227 @@
+
+__all__ = ['BaseResolver', 'Resolver']
+
+from .error import *
+from .nodes import *
+
+import re
+
+class ResolverError(YAMLError):
+ pass
+
+class BaseResolver:
+
+ DEFAULT_SCALAR_TAG = 'tag:yaml.org,2002:str'
+ DEFAULT_SEQUENCE_TAG = 'tag:yaml.org,2002:seq'
+ DEFAULT_MAPPING_TAG = 'tag:yaml.org,2002:map'
+
+ yaml_implicit_resolvers = {}
+ yaml_path_resolvers = {}
+
+ def __init__(self):
+ self.resolver_exact_paths = []
+ self.resolver_prefix_paths = []
+
+ @classmethod
+ def add_implicit_resolver(cls, tag, regexp, first):
+ if not 'yaml_implicit_resolvers' in cls.__dict__:
+ implicit_resolvers = {}
+ for key in cls.yaml_implicit_resolvers:
+ implicit_resolvers[key] = cls.yaml_implicit_resolvers[key][:]
+ cls.yaml_implicit_resolvers = implicit_resolvers
+ if first is None:
+ first = [None]
+ for ch in first:
+ cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))
+
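+    # Implicit resolvers are bucketed by possible first characters so that
+    # resolve() only scans the patterns that can match. A sketch with a
+    # hypothetical tag on a hypothetical subclass:
+    #   class MyResolver(Resolver):
+    #       pass
+    #   MyResolver.add_implicit_resolver('tag:example.com,2020:regexp',
+    #                                    re.compile(r'^/.*/$'), ['/'])
+    #   # -> appended to yaml_implicit_resolvers['/'] only.
+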
+ @classmethod
+ def add_path_resolver(cls, tag, path, kind=None):
+        # Note: `add_path_resolver` is experimental. The API could be changed.
+        # `path` is a pattern that is matched against the path from the
+        # root to the node that is being considered. Its elements are
+        # tuples `(node_check, index_check)`. `node_check` is a node class:
+        # `ScalarNode`, `SequenceNode`, `MappingNode` or `None`. `None`
+        # matches any kind of node. `index_check` could be `None`, a boolean
+        # value, a string value, or a number. `None` and `False` match against
+        # any _value_ of sequence and mapping nodes. `True` matches against
+        # any _key_ of a mapping node. A string `index_check` matches against
+        # a mapping value that corresponds to a scalar key whose content is
+        # equal to the `index_check` value. An integer `index_check` matches
+        # against a sequence value with the index equal to `index_check`.
+ if not 'yaml_path_resolvers' in cls.__dict__:
+ cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy()
+ new_path = []
+ for element in path:
+ if isinstance(element, (list, tuple)):
+ if len(element) == 2:
+ node_check, index_check = element
+ elif len(element) == 1:
+ node_check = element[0]
+ index_check = True
+ else:
+ raise ResolverError("Invalid path element: %s" % element)
+ else:
+ node_check = None
+ index_check = element
+ if node_check is str:
+ node_check = ScalarNode
+ elif node_check is list:
+ node_check = SequenceNode
+ elif node_check is dict:
+ node_check = MappingNode
+ elif node_check not in [ScalarNode, SequenceNode, MappingNode] \
+ and not isinstance(node_check, str) \
+ and node_check is not None:
+ raise ResolverError("Invalid node checker: %s" % node_check)
+ if not isinstance(index_check, (str, int)) \
+ and index_check is not None:
+ raise ResolverError("Invalid index checker: %s" % index_check)
+ new_path.append((node_check, index_check))
+ if kind is str:
+ kind = ScalarNode
+ elif kind is list:
+ kind = SequenceNode
+ elif kind is dict:
+ kind = MappingNode
+ elif kind not in [ScalarNode, SequenceNode, MappingNode] \
+ and kind is not None:
+ raise ResolverError("Invalid node kind: %s" % kind)
+ cls.yaml_path_resolvers[tuple(new_path), kind] = tag
+
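+    # Usage sketch for the experimental API above (hypothetical subclass):
+    # resolve the scalar value of a root-level "port" key as !!int
+    # (consulted after the implicit resolvers):
+    #   class MyResolver(Resolver):
+    #       pass
+    #   MyResolver.add_path_resolver('tag:yaml.org,2002:int',
+    #                                ['port'], kind=ScalarNode)
+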
+ def descend_resolver(self, current_node, current_index):
+ if not self.yaml_path_resolvers:
+ return
+ exact_paths = {}
+ prefix_paths = []
+ if current_node:
+ depth = len(self.resolver_prefix_paths)
+ for path, kind in self.resolver_prefix_paths[-1]:
+ if self.check_resolver_prefix(depth, path, kind,
+ current_node, current_index):
+ if len(path) > depth:
+ prefix_paths.append((path, kind))
+ else:
+ exact_paths[kind] = self.yaml_path_resolvers[path, kind]
+ else:
+ for path, kind in self.yaml_path_resolvers:
+ if not path:
+ exact_paths[kind] = self.yaml_path_resolvers[path, kind]
+ else:
+ prefix_paths.append((path, kind))
+ self.resolver_exact_paths.append(exact_paths)
+ self.resolver_prefix_paths.append(prefix_paths)
+
+ def ascend_resolver(self):
+ if not self.yaml_path_resolvers:
+ return
+ self.resolver_exact_paths.pop()
+ self.resolver_prefix_paths.pop()
+
+ def check_resolver_prefix(self, depth, path, kind,
+ current_node, current_index):
+ node_check, index_check = path[depth-1]
+ if isinstance(node_check, str):
+ if current_node.tag != node_check:
+ return
+ elif node_check is not None:
+ if not isinstance(current_node, node_check):
+ return
+ if index_check is True and current_index is not None:
+ return
+ if (index_check is False or index_check is None) \
+ and current_index is None:
+ return
+ if isinstance(index_check, str):
+ if not (isinstance(current_index, ScalarNode)
+ and index_check == current_index.value):
+ return
+ elif isinstance(index_check, int) and not isinstance(index_check, bool):
+ if index_check != current_index:
+ return
+ return True
+
+ def resolve(self, kind, value, implicit):
+ if kind is ScalarNode and implicit[0]:
+ if value == '':
+ resolvers = self.yaml_implicit_resolvers.get('', [])
+ else:
+ resolvers = self.yaml_implicit_resolvers.get(value[0], [])
+ resolvers += self.yaml_implicit_resolvers.get(None, [])
+ for tag, regexp in resolvers:
+ if regexp.match(value):
+ return tag
+ implicit = implicit[1]
+ if self.yaml_path_resolvers:
+ exact_paths = self.resolver_exact_paths[-1]
+ if kind in exact_paths:
+ return exact_paths[kind]
+ if None in exact_paths:
+ return exact_paths[None]
+ if kind is ScalarNode:
+ return self.DEFAULT_SCALAR_TAG
+ elif kind is SequenceNode:
+ return self.DEFAULT_SEQUENCE_TAG
+ elif kind is MappingNode:
+ return self.DEFAULT_MAPPING_TAG
+
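+    # Resolution sketch (with the default registrations below): the first
+    # character picks the bucket, then each (tag, regexp) pair is tried in
+    # registration order:
+    #   resolve(ScalarNode, '123', (True, False))    # -> ...:int
+    #   resolve(ScalarNode, 'yes', (True, False))    # -> ...:bool
+    #   resolve(ScalarNode, 'hello', (True, False))  # -> ...:str (default)
+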
+class Resolver(BaseResolver):
+ pass
+
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:bool',
+ re.compile(r'''^(?:yes|Yes|YES|no|No|NO
+ |true|True|TRUE|false|False|FALSE
+ |on|On|ON|off|Off|OFF)$''', re.X),
+ list('yYnNtTfFoO'))
+
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:float',
+ re.compile(r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)?
+ |\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
+ |[-+]?\.(?:inf|Inf|INF)
+ |\.(?:nan|NaN|NAN))$''', re.X),
+ list('-+0123456789.'))
+
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:int',
+ re.compile(r'''^(?:[-+]?0b[0-1_]+
+ |[-+]?0[0-7_]+
+ |[-+]?(?:0|[1-9][0-9_]*)
+ |[-+]?0x[0-9a-fA-F_]+
+ |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X),
+ list('-+0123456789'))
+
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:merge',
+ re.compile(r'^(?:<<)$'),
+ ['<'])
+
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:null',
+ re.compile(r'''^(?: ~
+ |null|Null|NULL
+ | )$''', re.X),
+ ['~', 'n', 'N', ''])
+
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:timestamp',
+ re.compile(r'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]
+ |[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]?
+ (?:[Tt]|[ \t]+)[0-9][0-9]?
+ :[0-9][0-9] :[0-9][0-9] (?:\.[0-9]*)?
+ (?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X),
+ list('0123456789'))
+
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:value',
+ re.compile(r'^(?:=)$'),
+ ['='])
+
+# The following resolver is only for documentation purposes. It cannot work
+# because plain scalars cannot start with '!', '&', or '*'.
+Resolver.add_implicit_resolver(
+ 'tag:yaml.org,2002:yaml',
+ re.compile(r'^(?:!|&|\*)$'),
+ list('!&*'))
+
diff --git a/libs/yaml/scanner.py b/libs/yaml/scanner.py
new file mode 100644
index 000000000..775dbcc6d
--- /dev/null
+++ b/libs/yaml/scanner.py
@@ -0,0 +1,1435 @@
+
+# Scanner produces tokens of the following types:
+# STREAM-START
+# STREAM-END
+# DIRECTIVE(name, value)
+# DOCUMENT-START
+# DOCUMENT-END
+# BLOCK-SEQUENCE-START
+# BLOCK-MAPPING-START
+# BLOCK-END
+# FLOW-SEQUENCE-START
+# FLOW-MAPPING-START
+# FLOW-SEQUENCE-END
+# FLOW-MAPPING-END
+# BLOCK-ENTRY
+# FLOW-ENTRY
+# KEY
+# VALUE
+# ALIAS(value)
+# ANCHOR(value)
+# TAG(value)
+# SCALAR(value, plain, style)
+#
+# Read comments in the Scanner code for more details.
+#
+
+__all__ = ['Scanner', 'ScannerError']
+
+from .error import MarkedYAMLError
+from .tokens import *
+
+class ScannerError(MarkedYAMLError):
+ pass
+
+class SimpleKey:
+    # See the simple keys treatment below.
+
+ def __init__(self, token_number, required, index, line, column, mark):
+ self.token_number = token_number
+ self.required = required
+ self.index = index
+ self.line = line
+ self.column = column
+ self.mark = mark
+
+class Scanner:
+
+ def __init__(self):
+ """Initialize the scanner."""
+ # It is assumed that Scanner and Reader will have a common descendant.
+        # Reader does the dirty work of checking for BOM and converting the
+        # input data to Unicode. It also adds NUL to the end.
+ #
+ # Reader supports the following methods
+ # self.peek(i=0) # peek the next i-th character
+ # self.prefix(l=1) # peek the next l characters
+ # self.forward(l=1) # read the next l characters and move the pointer.
+
+        # Have we reached the end of the stream?
+ self.done = False
+
+ # The number of unclosed '{' and '['. `flow_level == 0` means block
+ # context.
+ self.flow_level = 0
+
+ # List of processed tokens that are not yet emitted.
+ self.tokens = []
+
+ # Add the STREAM-START token.
+ self.fetch_stream_start()
+
+ # Number of tokens that were emitted through the `get_token` method.
+ self.tokens_taken = 0
+
+ # The current indentation level.
+ self.indent = -1
+
+ # Past indentation levels.
+ self.indents = []
+
+ # Variables related to simple keys treatment.
+
+ # A simple key is a key that is not denoted by the '?' indicator.
+ # Example of simple keys:
+ # ---
+ # block simple key: value
+ # ? not a simple key:
+ # : { flow simple key: value }
+ # We emit the KEY token before all keys, so when we find a potential
+ # simple key, we try to locate the corresponding ':' indicator.
+ # Simple keys should be limited to a single line and 1024 characters.
+
+ # Can a simple key start at the current position? A simple key may
+ # start:
+ # - at the beginning of the line, not counting indentation spaces
+ # (in block context),
+ # - after '{', '[', ',' (in the flow context),
+ # - after '?', ':', '-' (in the block context).
+ # In the block context, this flag also signifies if a block collection
+ # may start at the current position.
+ self.allow_simple_key = True
+
+ # Keep track of possible simple keys. This is a dictionary. The key
+        # is `flow_level`; there can be no more than one possible simple key
+ # for each level. The value is a SimpleKey record:
+ # (token_number, required, index, line, column, mark)
+ # A simple key may start with ALIAS, ANCHOR, TAG, SCALAR(flow),
+ # '[', or '{' tokens.
+ self.possible_simple_keys = {}
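+        # Worked example (illustrative) for the document "a: 1": scanning
+        # 'a' saves a possible simple key; when the ':' arrives, fetch_value
+        # inserts KEY (and BLOCK-MAPPING-START) *before* the queued SCALAR,
+        # so get_token() yields: STREAM-START, BLOCK-MAPPING-START, KEY,
+        # SCALAR('a'), VALUE, SCALAR('1'), BLOCK-END, STREAM-END.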
+
+ # Public methods.
+
+ def check_token(self, *choices):
+ # Check if the next token is one of the given types.
+ while self.need_more_tokens():
+ self.fetch_more_tokens()
+ if self.tokens:
+ if not choices:
+ return True
+ for choice in choices:
+ if isinstance(self.tokens[0], choice):
+ return True
+ return False
+
+ def peek_token(self):
+        # Return the next token, but do not delete it from the queue.
+ # Return None if no more tokens.
+ while self.need_more_tokens():
+ self.fetch_more_tokens()
+ if self.tokens:
+ return self.tokens[0]
+ else:
+ return None
+
+ def get_token(self):
+ # Return the next token.
+ while self.need_more_tokens():
+ self.fetch_more_tokens()
+ if self.tokens:
+ self.tokens_taken += 1
+ return self.tokens.pop(0)
+
+ # Private methods.
+
+ def need_more_tokens(self):
+ if self.done:
+ return False
+ if not self.tokens:
+ return True
+ # The current token may be a potential simple key, so we
+ # need to look further.
+ self.stale_possible_simple_keys()
+ if self.next_possible_simple_key() == self.tokens_taken:
+ return True
+
+ def fetch_more_tokens(self):
+
+ # Eat whitespaces and comments until we reach the next token.
+ self.scan_to_next_token()
+
+ # Remove obsolete possible simple keys.
+ self.stale_possible_simple_keys()
+
+ # Compare the current indentation and column. It may add some tokens
+ # and decrease the current indentation level.
+ self.unwind_indent(self.column)
+
+ # Peek the next character.
+ ch = self.peek()
+
+ # Is it the end of stream?
+ if ch == '\0':
+ return self.fetch_stream_end()
+
+ # Is it a directive?
+ if ch == '%' and self.check_directive():
+ return self.fetch_directive()
+
+ # Is it the document start?
+ if ch == '-' and self.check_document_start():
+ return self.fetch_document_start()
+
+ # Is it the document end?
+ if ch == '.' and self.check_document_end():
+ return self.fetch_document_end()
+
+ # TODO: support for BOM within a stream.
+ #if ch == '\uFEFF':
+ # return self.fetch_bom() <-- issue BOMToken
+
+ # Note: the order of the following checks is NOT significant.
+
+ # Is it the flow sequence start indicator?
+ if ch == '[':
+ return self.fetch_flow_sequence_start()
+
+ # Is it the flow mapping start indicator?
+ if ch == '{':
+ return self.fetch_flow_mapping_start()
+
+ # Is it the flow sequence end indicator?
+ if ch == ']':
+ return self.fetch_flow_sequence_end()
+
+ # Is it the flow mapping end indicator?
+ if ch == '}':
+ return self.fetch_flow_mapping_end()
+
+ # Is it the flow entry indicator?
+ if ch == ',':
+ return self.fetch_flow_entry()
+
+ # Is it the block entry indicator?
+ if ch == '-' and self.check_block_entry():
+ return self.fetch_block_entry()
+
+ # Is it the key indicator?
+ if ch == '?' and self.check_key():
+ return self.fetch_key()
+
+ # Is it the value indicator?
+ if ch == ':' and self.check_value():
+ return self.fetch_value()
+
+ # Is it an alias?
+ if ch == '*':
+ return self.fetch_alias()
+
+ # Is it an anchor?
+ if ch == '&':
+ return self.fetch_anchor()
+
+ # Is it a tag?
+ if ch == '!':
+ return self.fetch_tag()
+
+ # Is it a literal scalar?
+ if ch == '|' and not self.flow_level:
+ return self.fetch_literal()
+
+ # Is it a folded scalar?
+ if ch == '>' and not self.flow_level:
+ return self.fetch_folded()
+
+ # Is it a single quoted scalar?
+ if ch == '\'':
+ return self.fetch_single()
+
+ # Is it a double quoted scalar?
+ if ch == '\"':
+ return self.fetch_double()
+
+ # It must be a plain scalar then.
+ if self.check_plain():
+ return self.fetch_plain()
+
+ # No? It's an error. Let's produce a nice error message.
+ raise ScannerError("while scanning for the next token", None,
+ "found character %r that cannot start any token" % ch,
+ self.get_mark())
+
+ # Simple keys treatment.
+
+ def next_possible_simple_key(self):
+ # Return the number of the nearest possible simple key. Actually we
+ # don't need to loop through the whole dictionary. We may replace it
+ # with the following code:
+ # if not self.possible_simple_keys:
+ # return None
+ # return self.possible_simple_keys[
+ # min(self.possible_simple_keys.keys())].token_number
+ min_token_number = None
+ for level in self.possible_simple_keys:
+ key = self.possible_simple_keys[level]
+ if min_token_number is None or key.token_number < min_token_number:
+ min_token_number = key.token_number
+ return min_token_number
+
+ def stale_possible_simple_keys(self):
+ # Remove entries that are no longer possible simple keys. According to
+ # the YAML specification, simple keys
+ # - should be limited to a single line,
+ # - should be no longer than 1024 characters.
+ # Disabling this procedure will allow simple keys of any length and
+ # height (may cause problems if indentation is broken though).
+ for level in list(self.possible_simple_keys):
+ key = self.possible_simple_keys[level]
+ if key.line != self.line \
+ or self.index-key.index > 1024:
+ if key.required:
+ raise ScannerError("while scanning a simple key", key.mark,
+ "could not find expected ':'", self.get_mark())
+ del self.possible_simple_keys[level]
+
+ def save_possible_simple_key(self):
+ # The next token may start a simple key. We check if it's possible
+ # and save its position. This function is called for
+ # ALIAS, ANCHOR, TAG, SCALAR(flow), '[', and '{'.
+
+ # Check if a simple key is required at the current position.
+ required = not self.flow_level and self.indent == self.column
+
+        # The next token might be a simple key. Let's save its number and
+ # position.
+ if self.allow_simple_key:
+ self.remove_possible_simple_key()
+ token_number = self.tokens_taken+len(self.tokens)
+ key = SimpleKey(token_number, required,
+ self.index, self.line, self.column, self.get_mark())
+ self.possible_simple_keys[self.flow_level] = key
+
+ def remove_possible_simple_key(self):
+ # Remove the saved possible key position at the current flow level.
+ if self.flow_level in self.possible_simple_keys:
+ key = self.possible_simple_keys[self.flow_level]
+
+ if key.required:
+ raise ScannerError("while scanning a simple key", key.mark,
+ "could not find expected ':'", self.get_mark())
+
+ del self.possible_simple_keys[self.flow_level]
+
+ # Indentation functions.
+
+ def unwind_indent(self, column):
+
+ ## In flow context, tokens should respect indentation.
+ ## Actually the condition should be `self.indent >= column` according to
+ ## the spec. But this condition will prohibit intuitively correct
+ ## constructions such as
+ ## key : {
+ ## }
+ #if self.flow_level and self.indent > column:
+ # raise ScannerError(None, None,
+        #            "invalid indentation or unclosed '[' or '{'",
+ # self.get_mark())
+
+ # In the flow context, indentation is ignored. We make the scanner less
+        # restrictive than the specification requires.
+ if self.flow_level:
+ return
+
+ # In block context, we may need to issue the BLOCK-END tokens.
+ while self.indent > column:
+ mark = self.get_mark()
+ self.indent = self.indents.pop()
+ self.tokens.append(BlockEndToken(mark, mark))
+
+ def add_indent(self, column):
+ # Check if we need to increase indentation.
+ if self.indent < column:
+ self.indents.append(self.indent)
+ self.indent = column
+ return True
+ return False
+
+ # Fetchers.
+
+ def fetch_stream_start(self):
+ # We always add STREAM-START as the first token and STREAM-END as the
+ # last token.
+
+ # Read the token.
+ mark = self.get_mark()
+
+ # Add STREAM-START.
+ self.tokens.append(StreamStartToken(mark, mark,
+ encoding=self.encoding))
+
+
+ def fetch_stream_end(self):
+
+        # Set the current indentation to -1.
+ self.unwind_indent(-1)
+
+ # Reset simple keys.
+ self.remove_possible_simple_key()
+ self.allow_simple_key = False
+ self.possible_simple_keys = {}
+
+ # Read the token.
+ mark = self.get_mark()
+
+ # Add STREAM-END.
+ self.tokens.append(StreamEndToken(mark, mark))
+
+        # The stream is finished.
+ self.done = True
+
+ def fetch_directive(self):
+
+        # Set the current indentation to -1.
+ self.unwind_indent(-1)
+
+ # Reset simple keys.
+ self.remove_possible_simple_key()
+ self.allow_simple_key = False
+
+ # Scan and add DIRECTIVE.
+ self.tokens.append(self.scan_directive())
+
+ def fetch_document_start(self):
+ self.fetch_document_indicator(DocumentStartToken)
+
+ def fetch_document_end(self):
+ self.fetch_document_indicator(DocumentEndToken)
+
+ def fetch_document_indicator(self, TokenClass):
+
+        # Set the current indentation to -1.
+ self.unwind_indent(-1)
+
+        # Reset simple keys. Note that there cannot be a block collection
+ # after '---'.
+ self.remove_possible_simple_key()
+ self.allow_simple_key = False
+
+ # Add DOCUMENT-START or DOCUMENT-END.
+ start_mark = self.get_mark()
+ self.forward(3)
+ end_mark = self.get_mark()
+ self.tokens.append(TokenClass(start_mark, end_mark))
+
+ def fetch_flow_sequence_start(self):
+ self.fetch_flow_collection_start(FlowSequenceStartToken)
+
+ def fetch_flow_mapping_start(self):
+ self.fetch_flow_collection_start(FlowMappingStartToken)
+
+ def fetch_flow_collection_start(self, TokenClass):
+
+ # '[' and '{' may start a simple key.
+ self.save_possible_simple_key()
+
+ # Increase the flow level.
+ self.flow_level += 1
+
+ # Simple keys are allowed after '[' and '{'.
+ self.allow_simple_key = True
+
+ # Add FLOW-SEQUENCE-START or FLOW-MAPPING-START.
+ start_mark = self.get_mark()
+ self.forward()
+ end_mark = self.get_mark()
+ self.tokens.append(TokenClass(start_mark, end_mark))
+
+ def fetch_flow_sequence_end(self):
+ self.fetch_flow_collection_end(FlowSequenceEndToken)
+
+ def fetch_flow_mapping_end(self):
+ self.fetch_flow_collection_end(FlowMappingEndToken)
+
+ def fetch_flow_collection_end(self, TokenClass):
+
+ # Reset possible simple key on the current level.
+ self.remove_possible_simple_key()
+
+ # Decrease the flow level.
+ self.flow_level -= 1
+
+ # No simple keys after ']' or '}'.
+ self.allow_simple_key = False
+
+ # Add FLOW-SEQUENCE-END or FLOW-MAPPING-END.
+ start_mark = self.get_mark()
+ self.forward()
+ end_mark = self.get_mark()
+ self.tokens.append(TokenClass(start_mark, end_mark))
+
+ def fetch_flow_entry(self):
+
+ # Simple keys are allowed after ','.
+ self.allow_simple_key = True
+
+ # Reset possible simple key on the current level.
+ self.remove_possible_simple_key()
+
+ # Add FLOW-ENTRY.
+ start_mark = self.get_mark()
+ self.forward()
+ end_mark = self.get_mark()
+ self.tokens.append(FlowEntryToken(start_mark, end_mark))
+
+ def fetch_block_entry(self):
+
+ # Block context needs additional checks.
+ if not self.flow_level:
+
+ # Are we allowed to start a new entry?
+ if not self.allow_simple_key:
+ raise ScannerError(None, None,
+ "sequence entries are not allowed here",
+ self.get_mark())
+
+ # We may need to add BLOCK-SEQUENCE-START.
+ if self.add_indent(self.column):
+ mark = self.get_mark()
+ self.tokens.append(BlockSequenceStartToken(mark, mark))
+
+ # It's an error for the block entry to occur in the flow context,
+ # but we let the parser detect this.
+ else:
+ pass
+
+ # Simple keys are allowed after '-'.
+ self.allow_simple_key = True
+
+ # Reset possible simple key on the current level.
+ self.remove_possible_simple_key()
+
+ # Add BLOCK-ENTRY.
+ start_mark = self.get_mark()
+ self.forward()
+ end_mark = self.get_mark()
+ self.tokens.append(BlockEntryToken(start_mark, end_mark))
+
+ def fetch_key(self):
+
+ # Block context needs additional checks.
+ if not self.flow_level:
+
+            # Are we allowed to start a key (not necessarily a simple one)?
+ if not self.allow_simple_key:
+ raise ScannerError(None, None,
+ "mapping keys are not allowed here",
+ self.get_mark())
+
+ # We may need to add BLOCK-MAPPING-START.
+ if self.add_indent(self.column):
+ mark = self.get_mark()
+ self.tokens.append(BlockMappingStartToken(mark, mark))
+
+ # Simple keys are allowed after '?' in the block context.
+ self.allow_simple_key = not self.flow_level
+
+ # Reset possible simple key on the current level.
+ self.remove_possible_simple_key()
+
+ # Add KEY.
+ start_mark = self.get_mark()
+ self.forward()
+ end_mark = self.get_mark()
+ self.tokens.append(KeyToken(start_mark, end_mark))
+
+ def fetch_value(self):
+
+        # Is a simple key pending at the current flow level?
+ if self.flow_level in self.possible_simple_keys:
+
+ # Add KEY.
+ key = self.possible_simple_keys[self.flow_level]
+ del self.possible_simple_keys[self.flow_level]
+ self.tokens.insert(key.token_number-self.tokens_taken,
+ KeyToken(key.mark, key.mark))
+
+ # If this key starts a new block mapping, we need to add
+ # BLOCK-MAPPING-START.
+ if not self.flow_level:
+ if self.add_indent(key.column):
+ self.tokens.insert(key.token_number-self.tokens_taken,
+ BlockMappingStartToken(key.mark, key.mark))
+
+ # There cannot be two simple keys one after another.
+ self.allow_simple_key = False
+
+        # Otherwise, it must be part of a complex key.
+ else:
+
+ # Block context needs additional checks.
+ # (Do we really need them? They will be caught by the parser
+ # anyway.)
+ if not self.flow_level:
+
+ # We are allowed to start a complex value if and only if
+ # we can start a simple key.
+ if not self.allow_simple_key:
+ raise ScannerError(None, None,
+ "mapping values are not allowed here",
+ self.get_mark())
+
+ # If this value starts a new block mapping, we need to add
+ # BLOCK-MAPPING-START. It will be detected as an error later by
+ # the parser.
+ if not self.flow_level:
+ if self.add_indent(self.column):
+ mark = self.get_mark()
+ self.tokens.append(BlockMappingStartToken(mark, mark))
+
+ # Simple keys are allowed after ':' in the block context.
+ self.allow_simple_key = not self.flow_level
+
+ # Reset possible simple key on the current level.
+ self.remove_possible_simple_key()
+
+ # Add VALUE.
+ start_mark = self.get_mark()
+ self.forward()
+ end_mark = self.get_mark()
+ self.tokens.append(ValueToken(start_mark, end_mark))
+
+ def fetch_alias(self):
+
+ # ALIAS could be a simple key.
+ self.save_possible_simple_key()
+
+ # No simple keys after ALIAS.
+ self.allow_simple_key = False
+
+ # Scan and add ALIAS.
+ self.tokens.append(self.scan_anchor(AliasToken))
+
+ def fetch_anchor(self):
+
+ # ANCHOR could start a simple key.
+ self.save_possible_simple_key()
+
+ # No simple keys after ANCHOR.
+ self.allow_simple_key = False
+
+ # Scan and add ANCHOR.
+ self.tokens.append(self.scan_anchor(AnchorToken))
+
+ def fetch_tag(self):
+
+ # TAG could start a simple key.
+ self.save_possible_simple_key()
+
+ # No simple keys after TAG.
+ self.allow_simple_key = False
+
+ # Scan and add TAG.
+ self.tokens.append(self.scan_tag())
+
+ def fetch_literal(self):
+ self.fetch_block_scalar(style='|')
+
+ def fetch_folded(self):
+ self.fetch_block_scalar(style='>')
+
+ def fetch_block_scalar(self, style):
+
+ # A simple key may follow a block scalar.
+ self.allow_simple_key = True
+
+ # Reset possible simple key on the current level.
+ self.remove_possible_simple_key()
+
+ # Scan and add SCALAR.
+ self.tokens.append(self.scan_block_scalar(style))
+
+ def fetch_single(self):
+ self.fetch_flow_scalar(style='\'')
+
+ def fetch_double(self):
+ self.fetch_flow_scalar(style='"')
+
+ def fetch_flow_scalar(self, style):
+
+ # A flow scalar could be a simple key.
+ self.save_possible_simple_key()
+
+ # No simple keys after flow scalars.
+ self.allow_simple_key = False
+
+ # Scan and add SCALAR.
+ self.tokens.append(self.scan_flow_scalar(style))
+
+ def fetch_plain(self):
+
+ # A plain scalar could be a simple key.
+ self.save_possible_simple_key()
+
+ # No simple keys after plain scalars. But note that `scan_plain` will
+ # change this flag if the scan is finished at the beginning of the
+ # line.
+ self.allow_simple_key = False
+
+ # Scan and add SCALAR. May change `allow_simple_key`.
+ self.tokens.append(self.scan_plain())
+
+ # Checkers.
+
+ def check_directive(self):
+
+ # DIRECTIVE: ^ '%' ...
+ # The '%' indicator is already checked.
+ if self.column == 0:
+ return True
+
+ def check_document_start(self):
+
+ # DOCUMENT-START: ^ '---' (' '|'\n')
+ if self.column == 0:
+ if self.prefix(3) == '---' \
+ and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029':
+ return True
+
+ def check_document_end(self):
+
+ # DOCUMENT-END: ^ '...' (' '|'\n')
+ if self.column == 0:
+ if self.prefix(3) == '...' \
+ and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029':
+ return True
+
+ def check_block_entry(self):
+
+ # BLOCK-ENTRY: '-' (' '|'\n')
+ return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029'
+
+ def check_key(self):
+
+ # KEY(flow context): '?'
+ if self.flow_level:
+ return True
+
+ # KEY(block context): '?' (' '|'\n')
+ else:
+ return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029'
+
+ def check_value(self):
+
+ # VALUE(flow context): ':'
+ if self.flow_level:
+ return True
+
+ # VALUE(block context): ':' (' '|'\n')
+ else:
+ return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029'
+
+ def check_plain(self):
+
+ # A plain scalar may start with any non-space character except:
+ # '-', '?', ':', ',', '[', ']', '{', '}',
+ # '#', '&', '*', '!', '|', '>', '\'', '\"',
+ # '%', '@', '`'.
+ #
+ # It may also start with
+ # '-', '?', ':'
+ # if it is followed by a non-space character.
+ #
+ # Note that we limit the last rule to the block context (except the
+ # '-' character) because we want the flow context to be space
+ # independent.
+ ch = self.peek()
+ return ch not in '\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>\'\"%@`' \
+ or (self.peek(1) not in '\0 \t\r\n\x85\u2028\u2029'
+ and (ch == '-' or (not self.flow_level and ch in '?:')))
+
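+    # Illustrative consequences of check_plain: "-foo" starts a plain scalar
+    # (the '-' is followed by a non-space), "- foo" is a block entry, and in
+    # flow context '?' and ':' can never start a plain scalar.
+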
+ # Scanners.
+
+ def scan_to_next_token(self):
+ # We ignore spaces, line breaks and comments.
+ # If we find a line break in the block context, we set the flag
+ # `allow_simple_key` on.
+ # The byte order mark is stripped if it's the first character in the
+ # stream. We do not yet support BOM inside the stream as the
+        # specification requires. Any such mark will be considered part
+        # of the document.
+ #
+ # TODO: We need to make tab handling rules more sane. A good rule is
+ # Tabs cannot precede tokens
+ # BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END,
+ # KEY(block), VALUE(block), BLOCK-ENTRY
+ # So the checking code is
+ # if <TAB>:
+ # self.allow_simple_keys = False
+ # We also need to add the check for `allow_simple_keys == True` to
+ # `unwind_indent` before issuing BLOCK-END.
+ # Scanners for block, flow, and plain scalars need to be modified.
+
+ if self.index == 0 and self.peek() == '\uFEFF':
+ self.forward()
+ found = False
+ while not found:
+ while self.peek() == ' ':
+ self.forward()
+ if self.peek() == '#':
+ while self.peek() not in '\0\r\n\x85\u2028\u2029':
+ self.forward()
+ if self.scan_line_break():
+ if not self.flow_level:
+ self.allow_simple_key = True
+ else:
+ found = True
+
+ def scan_directive(self):
+ # See the specification for details.
+ start_mark = self.get_mark()
+ self.forward()
+ name = self.scan_directive_name(start_mark)
+ value = None
+ if name == 'YAML':
+ value = self.scan_yaml_directive_value(start_mark)
+ end_mark = self.get_mark()
+ elif name == 'TAG':
+ value = self.scan_tag_directive_value(start_mark)
+ end_mark = self.get_mark()
+ else:
+ end_mark = self.get_mark()
+ while self.peek() not in '\0\r\n\x85\u2028\u2029':
+ self.forward()
+ self.scan_directive_ignored_line(start_mark)
+ return DirectiveToken(name, value, start_mark, end_mark)
+
+ def scan_directive_name(self, start_mark):
+ # See the specification for details.
+ length = 0
+ ch = self.peek(length)
+ while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-_':
+ length += 1
+ ch = self.peek(length)
+ if not length:
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected alphabetic or numeric character, but found %r"
+ % ch, self.get_mark())
+ value = self.prefix(length)
+ self.forward(length)
+ ch = self.peek()
+ if ch not in '\0 \r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected alphabetic or numeric character, but found %r"
+ % ch, self.get_mark())
+ return value
+
+ def scan_yaml_directive_value(self, start_mark):
+ # See the specification for details.
+ while self.peek() == ' ':
+ self.forward()
+ major = self.scan_yaml_directive_number(start_mark)
+ if self.peek() != '.':
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected a digit or '.', but found %r" % self.peek(),
+ self.get_mark())
+ self.forward()
+ minor = self.scan_yaml_directive_number(start_mark)
+ if self.peek() not in '\0 \r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected a digit or ' ', but found %r" % self.peek(),
+ self.get_mark())
+ return (major, minor)
+
+ def scan_yaml_directive_number(self, start_mark):
+ # See the specification for details.
+ ch = self.peek()
+ if not ('0' <= ch <= '9'):
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected a digit, but found %r" % ch, self.get_mark())
+ length = 0
+ while '0' <= self.peek(length) <= '9':
+ length += 1
+ value = int(self.prefix(length))
+ self.forward(length)
+ return value
+
+ def scan_tag_directive_value(self, start_mark):
+ # See the specification for details.
+ while self.peek() == ' ':
+ self.forward()
+ handle = self.scan_tag_directive_handle(start_mark)
+ while self.peek() == ' ':
+ self.forward()
+ prefix = self.scan_tag_directive_prefix(start_mark)
+ return (handle, prefix)
+
+ def scan_tag_directive_handle(self, start_mark):
+ # See the specification for details.
+ value = self.scan_tag_handle('directive', start_mark)
+ ch = self.peek()
+ if ch != ' ':
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected ' ', but found %r" % ch, self.get_mark())
+ return value
+
+ def scan_tag_directive_prefix(self, start_mark):
+ # See the specification for details.
+ value = self.scan_tag_uri('directive', start_mark)
+ ch = self.peek()
+ if ch not in '\0 \r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected ' ', but found %r" % ch, self.get_mark())
+ return value
+
+ def scan_directive_ignored_line(self, start_mark):
+ # See the specification for details.
+ while self.peek() == ' ':
+ self.forward()
+ if self.peek() == '#':
+ while self.peek() not in '\0\r\n\x85\u2028\u2029':
+ self.forward()
+ ch = self.peek()
+ if ch not in '\0\r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a directive", start_mark,
+ "expected a comment or a line break, but found %r"
+ % ch, self.get_mark())
+ self.scan_line_break()
+
+ def scan_anchor(self, TokenClass):
+ # The specification does not restrict characters for anchors and
+ # aliases. This may lead to problems, for instance, the document:
+ # [ *alias, value ]
+ # can be interpreted in two ways, as
+ # [ "value" ]
+ # and
+ # [ *alias , "value" ]
+ # Therefore we restrict aliases to numbers and ASCII letters.
+ start_mark = self.get_mark()
+ indicator = self.peek()
+ if indicator == '*':
+ name = 'alias'
+ else:
+ name = 'anchor'
+ self.forward()
+ length = 0
+ ch = self.peek(length)
+ while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-_':
+ length += 1
+ ch = self.peek(length)
+ if not length:
+ raise ScannerError("while scanning an %s" % name, start_mark,
+ "expected alphabetic or numeric character, but found %r"
+ % ch, self.get_mark())
+ value = self.prefix(length)
+ self.forward(length)
+ ch = self.peek()
+ if ch not in '\0 \t\r\n\x85\u2028\u2029?:,]}%@`':
+ raise ScannerError("while scanning an %s" % name, start_mark,
+ "expected alphabetic or numeric character, but found %r"
+ % ch, self.get_mark())
+ end_mark = self.get_mark()
+ return TokenClass(value, start_mark, end_mark)
+
+ def scan_tag(self):
+ # See the specification for details.
+ start_mark = self.get_mark()
+ ch = self.peek(1)
+ if ch == '<':
+ handle = None
+ self.forward(2)
+ suffix = self.scan_tag_uri('tag', start_mark)
+ if self.peek() != '>':
+                raise ScannerError("while scanning a tag", start_mark,
+ "expected '>', but found %r" % self.peek(),
+ self.get_mark())
+ self.forward()
+ elif ch in '\0 \t\r\n\x85\u2028\u2029':
+ handle = None
+ suffix = '!'
+ self.forward()
+ else:
+ length = 1
+ use_handle = False
+ while ch not in '\0 \r\n\x85\u2028\u2029':
+ if ch == '!':
+ use_handle = True
+ break
+ length += 1
+ ch = self.peek(length)
+ handle = '!'
+ if use_handle:
+ handle = self.scan_tag_handle('tag', start_mark)
+ else:
+ handle = '!'
+ self.forward()
+ suffix = self.scan_tag_uri('tag', start_mark)
+ ch = self.peek()
+ if ch not in '\0 \r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a tag", start_mark,
+ "expected ' ', but found %r" % ch, self.get_mark())
+ value = (handle, suffix)
+ end_mark = self.get_mark()
+ return TagToken(value, start_mark, end_mark)
+
+ def scan_block_scalar(self, style):
+ # See the specification for details.
+
+ if style == '>':
+ folded = True
+ else:
+ folded = False
+
+ chunks = []
+ start_mark = self.get_mark()
+
+ # Scan the header.
+ self.forward()
+ chomping, increment = self.scan_block_scalar_indicators(start_mark)
+ self.scan_block_scalar_ignored_line(start_mark)
+
+ # Determine the indentation level and go to the first non-empty line.
+ min_indent = self.indent+1
+ if min_indent < 1:
+ min_indent = 1
+ if increment is None:
+ breaks, max_indent, end_mark = self.scan_block_scalar_indentation()
+ indent = max(min_indent, max_indent)
+ else:
+ indent = min_indent+increment-1
+ breaks, end_mark = self.scan_block_scalar_breaks(indent)
+ line_break = ''
+
+ # Scan the inner part of the block scalar.
+ while self.column == indent and self.peek() != '\0':
+ chunks.extend(breaks)
+ leading_non_space = self.peek() not in ' \t'
+ length = 0
+ while self.peek(length) not in '\0\r\n\x85\u2028\u2029':
+ length += 1
+ chunks.append(self.prefix(length))
+ self.forward(length)
+ line_break = self.scan_line_break()
+ breaks, end_mark = self.scan_block_scalar_breaks(indent)
+ if self.column == indent and self.peek() != '\0':
+
+ # Unfortunately, folding rules are ambiguous.
+ #
+ # This is the folding according to the specification:
+
+ if folded and line_break == '\n' \
+ and leading_non_space and self.peek() not in ' \t':
+ if not breaks:
+ chunks.append(' ')
+ else:
+ chunks.append(line_break)
+
+ # This is Clark Evans's interpretation (also in the spec
+ # examples):
+ #
+ #if folded and line_break == '\n':
+ # if not breaks:
+ # if self.peek() not in ' \t':
+ # chunks.append(' ')
+ # else:
+ # chunks.append(line_break)
+ #else:
+ # chunks.append(line_break)
+ else:
+ break
+
+ # Chomp the tail.
+ if chomping is not False:
+ chunks.append(line_break)
+ if chomping is True:
+ chunks.extend(breaks)
+
+ # We are done.
+ return ScalarToken(''.join(chunks), False, start_mark, end_mark,
+ style)
+
+ def scan_block_scalar_indicators(self, start_mark):
+ # See the specification for details.
+ chomping = None
+ increment = None
+ ch = self.peek()
+ if ch in '+-':
+ if ch == '+':
+ chomping = True
+ else:
+ chomping = False
+ self.forward()
+ ch = self.peek()
+ if ch in '0123456789':
+ increment = int(ch)
+ if increment == 0:
+ raise ScannerError("while scanning a block scalar", start_mark,
+ "expected indentation indicator in the range 1-9, but found 0",
+ self.get_mark())
+ self.forward()
+ elif ch in '0123456789':
+ increment = int(ch)
+ if increment == 0:
+ raise ScannerError("while scanning a block scalar", start_mark,
+ "expected indentation indicator in the range 1-9, but found 0",
+ self.get_mark())
+ self.forward()
+ ch = self.peek()
+ if ch in '+-':
+ if ch == '+':
+ chomping = True
+ else:
+ chomping = False
+ self.forward()
+ ch = self.peek()
+ if ch not in '\0 \r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a block scalar", start_mark,
+ "expected chomping or indentation indicators, but found %r"
+ % ch, self.get_mark())
+ return chomping, increment
+
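+    # Header sketch: the two indicators may come in either order after the
+    # style character. E.g. "|2-" yields (False, 2): strip the final line
+    # break and indent the content two columns beyond the enclosing indent;
+    # ">+" yields (True, None): keep trailing breaks, auto-detect the indent.
+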
+ def scan_block_scalar_ignored_line(self, start_mark):
+ # See the specification for details.
+ while self.peek() == ' ':
+ self.forward()
+ if self.peek() == '#':
+ while self.peek() not in '\0\r\n\x85\u2028\u2029':
+ self.forward()
+ ch = self.peek()
+ if ch not in '\0\r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a block scalar", start_mark,
+ "expected a comment or a line break, but found %r" % ch,
+ self.get_mark())
+ self.scan_line_break()
+
+ def scan_block_scalar_indentation(self):
+ # See the specification for details.
+ chunks = []
+ max_indent = 0
+ end_mark = self.get_mark()
+ while self.peek() in ' \r\n\x85\u2028\u2029':
+ if self.peek() != ' ':
+ chunks.append(self.scan_line_break())
+ end_mark = self.get_mark()
+ else:
+ self.forward()
+ if self.column > max_indent:
+ max_indent = self.column
+ return chunks, max_indent, end_mark
+
+ def scan_block_scalar_breaks(self, indent):
+ # See the specification for details.
+ chunks = []
+ end_mark = self.get_mark()
+ while self.column < indent and self.peek() == ' ':
+ self.forward()
+ while self.peek() in '\r\n\x85\u2028\u2029':
+ chunks.append(self.scan_line_break())
+ end_mark = self.get_mark()
+ while self.column < indent and self.peek() == ' ':
+ self.forward()
+ return chunks, end_mark
+
+ def scan_flow_scalar(self, style):
+ # See the specification for details.
+        # Note that we loosen indentation rules for quoted scalars. Quoted
+        # scalars don't need to adhere to indentation because " and ' clearly
+        # mark their beginning and end. Therefore we are less
+        # restrictive than the specification requires. We only need to check
+        # that document separators are not included in scalars.
+ if style == '"':
+ double = True
+ else:
+ double = False
+ chunks = []
+ start_mark = self.get_mark()
+ quote = self.peek()
+ self.forward()
+ chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark))
+ while self.peek() != quote:
+ chunks.extend(self.scan_flow_scalar_spaces(double, start_mark))
+ chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark))
+ self.forward()
+ end_mark = self.get_mark()
+ return ScalarToken(''.join(chunks), False, start_mark, end_mark,
+ style)
+
+ ESCAPE_REPLACEMENTS = {
+ '0': '\0',
+ 'a': '\x07',
+ 'b': '\x08',
+ 't': '\x09',
+ '\t': '\x09',
+ 'n': '\x0A',
+ 'v': '\x0B',
+ 'f': '\x0C',
+ 'r': '\x0D',
+ 'e': '\x1B',
+ ' ': '\x20',
+ '\"': '\"',
+ '\\': '\\',
+ '/': '/',
+ 'N': '\x85',
+ '_': '\xA0',
+ 'L': '\u2028',
+ 'P': '\u2029',
+ }
+
+ ESCAPE_CODES = {
+ 'x': 2,
+ 'u': 4,
+ 'U': 8,
+ }
+
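+    # Escape sketch: ESCAPE_CODES maps an escape letter to the number of hex
+    # digits that must follow it in a double-quoted scalar, e.g.
+    #   "\x41"       -> 'A'
+    #   "\u263A"     -> '\u263a'
+    #   "\U0001F600" -> chr(0x1F600)
+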
+ def scan_flow_scalar_non_spaces(self, double, start_mark):
+ # See the specification for details.
+ chunks = []
+ while True:
+ length = 0
+ while self.peek(length) not in '\'\"\\\0 \t\r\n\x85\u2028\u2029':
+ length += 1
+ if length:
+ chunks.append(self.prefix(length))
+ self.forward(length)
+ ch = self.peek()
+ if not double and ch == '\'' and self.peek(1) == '\'':
+ chunks.append('\'')
+ self.forward(2)
+ elif (double and ch == '\'') or (not double and ch in '\"\\'):
+ chunks.append(ch)
+ self.forward()
+ elif double and ch == '\\':
+ self.forward()
+ ch = self.peek()
+ if ch in self.ESCAPE_REPLACEMENTS:
+ chunks.append(self.ESCAPE_REPLACEMENTS[ch])
+ self.forward()
+ elif ch in self.ESCAPE_CODES:
+ length = self.ESCAPE_CODES[ch]
+ self.forward()
+ for k in range(length):
+ if self.peek(k) not in '0123456789ABCDEFabcdef':
+ raise ScannerError("while scanning a double-quoted scalar", start_mark,
+                                    "expected escape sequence of %d hexadecimal digits, but found %r" %
+ (length, self.peek(k)), self.get_mark())
+ code = int(self.prefix(length), 16)
+ chunks.append(chr(code))
+ self.forward(length)
+ elif ch in '\r\n\x85\u2028\u2029':
+ self.scan_line_break()
+ chunks.extend(self.scan_flow_scalar_breaks(double, start_mark))
+ else:
+ raise ScannerError("while scanning a double-quoted scalar", start_mark,
+ "found unknown escape character %r" % ch, self.get_mark())
+ else:
+ return chunks
+
+ def scan_flow_scalar_spaces(self, double, start_mark):
+ # See the specification for details.
+ chunks = []
+ length = 0
+ while self.peek(length) in ' \t':
+ length += 1
+ whitespaces = self.prefix(length)
+ self.forward(length)
+ ch = self.peek()
+ if ch == '\0':
+ raise ScannerError("while scanning a quoted scalar", start_mark,
+ "found unexpected end of stream", self.get_mark())
+ elif ch in '\r\n\x85\u2028\u2029':
+ line_break = self.scan_line_break()
+ breaks = self.scan_flow_scalar_breaks(double, start_mark)
+ if line_break != '\n':
+ chunks.append(line_break)
+ elif not breaks:
+ chunks.append(' ')
+ chunks.extend(breaks)
+ else:
+ chunks.append(whitespaces)
+ return chunks
+
+ def scan_flow_scalar_breaks(self, double, start_mark):
+ # See the specification for details.
+ chunks = []
+ while True:
+ # Instead of checking indentation, we check for document
+ # separators.
+ prefix = self.prefix(3)
+ if (prefix == '---' or prefix == '...') \
+ and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029':
+ raise ScannerError("while scanning a quoted scalar", start_mark,
+ "found unexpected document separator", self.get_mark())
+ while self.peek() in ' \t':
+ self.forward()
+ if self.peek() in '\r\n\x85\u2028\u2029':
+ chunks.append(self.scan_line_break())
+ else:
+ return chunks
+
+ def scan_plain(self):
+ # See the specification for details.
+ # We add an additional restriction for the flow context:
+        #   plain scalars in the flow context cannot contain ',', '?',
+        #   or any of '[]{}'.
+ # We also keep track of the `allow_simple_key` flag here.
+        # Indentation rules are loosened for the flow context.
+ chunks = []
+ start_mark = self.get_mark()
+ end_mark = start_mark
+ indent = self.indent+1
+ # We allow zero indentation for scalars, but then we need to check for
+ # document separators at the beginning of the line.
+ #if indent == 0:
+ # indent = 1
+ spaces = []
+ while True:
+ length = 0
+ if self.peek() == '#':
+ break
+ while True:
+ ch = self.peek(length)
+ if ch in '\0 \t\r\n\x85\u2028\u2029' \
+ or (ch == ':' and
+ self.peek(length+1) in '\0 \t\r\n\x85\u2028\u2029'
+                                      + (',[]{}' if self.flow_level else ''))\
+ or (self.flow_level and ch in ',?[]{}'):
+ break
+ length += 1
+ if length == 0:
+ break
+ self.allow_simple_key = False
+ chunks.extend(spaces)
+ chunks.append(self.prefix(length))
+ self.forward(length)
+ end_mark = self.get_mark()
+ spaces = self.scan_plain_spaces(indent, start_mark)
+ if not spaces or self.peek() == '#' \
+ or (not self.flow_level and self.column < indent):
+ break
+ return ScalarToken(''.join(chunks), True, start_mark, end_mark)
+
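+    # Illustrative example (not part of the original module): a plain scalar
+    # stops at a comment, and in flow context also at flow indicators:
+    #     >>> yaml.safe_load('key: plain value # trailing comment')
+    #     {'key': 'plain value'}
+    #     >>> yaml.safe_load('[a, b]')
+    #     ['a', 'b']
+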
+ def scan_plain_spaces(self, indent, start_mark):
+ # See the specification for details.
+ # The specification is really confusing about tabs in plain scalars.
+ # We just forbid them completely. Do not use tabs in YAML!
+ chunks = []
+ length = 0
+        while self.peek(length) == ' ':
+ length += 1
+ whitespaces = self.prefix(length)
+ self.forward(length)
+ ch = self.peek()
+ if ch in '\r\n\x85\u2028\u2029':
+ line_break = self.scan_line_break()
+ self.allow_simple_key = True
+ prefix = self.prefix(3)
+ if (prefix == '---' or prefix == '...') \
+ and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029':
+ return
+ breaks = []
+ while self.peek() in ' \r\n\x85\u2028\u2029':
+ if self.peek() == ' ':
+ self.forward()
+ else:
+ breaks.append(self.scan_line_break())
+ prefix = self.prefix(3)
+ if (prefix == '---' or prefix == '...') \
+ and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029':
+ return
+ if line_break != '\n':
+ chunks.append(line_break)
+ elif not breaks:
+ chunks.append(' ')
+ chunks.extend(breaks)
+ elif whitespaces:
+ chunks.append(whitespaces)
+ return chunks
+
+ def scan_tag_handle(self, name, start_mark):
+ # See the specification for details.
+        # For some strange reason, the specification does not allow '_' in
+ # tag handles. I have allowed it anyway.
+ ch = self.peek()
+ if ch != '!':
+ raise ScannerError("while scanning a %s" % name, start_mark,
+ "expected '!', but found %r" % ch, self.get_mark())
+ length = 1
+ ch = self.peek(length)
+ if ch != ' ':
+ while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-_':
+ length += 1
+ ch = self.peek(length)
+ if ch != '!':
+ self.forward(length)
+ raise ScannerError("while scanning a %s" % name, start_mark,
+ "expected '!', but found %r" % ch, self.get_mark())
+ length += 1
+ value = self.prefix(length)
+ self.forward(length)
+ return value
+
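+    # Illustrative example (not part of the original module): '!!' is the
+    # secondary tag handle scanned here, so:
+    #     >>> yaml.safe_load('!!str 123')
+    #     '123'
+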
+ def scan_tag_uri(self, name, start_mark):
+ # See the specification for details.
+ # Note: we do not check if URI is well-formed.
+ chunks = []
+ length = 0
+ ch = self.peek(length)
+ while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
+ or ch in '-;/?:@&=+$,_.!~*\'()[]%':
+ if ch == '%':
+ chunks.append(self.prefix(length))
+ self.forward(length)
+ length = 0
+ chunks.append(self.scan_uri_escapes(name, start_mark))
+ else:
+ length += 1
+ ch = self.peek(length)
+ if length:
+ chunks.append(self.prefix(length))
+ self.forward(length)
+ length = 0
+ if not chunks:
+            raise ScannerError("while scanning a %s" % name, start_mark,
+ "expected URI, but found %r" % ch, self.get_mark())
+ return ''.join(chunks)
+
+ def scan_uri_escapes(self, name, start_mark):
+ # See the specification for details.
+ codes = []
+ mark = self.get_mark()
+ while self.peek() == '%':
+ self.forward()
+ for k in range(2):
+ if self.peek(k) not in '0123456789ABCDEFabcdef':
+ raise ScannerError("while scanning a %s" % name, start_mark,
+                        "expected URI escape sequence of 2 hexadecimal digits, but found %r"
+ % self.peek(k), self.get_mark())
+ codes.append(int(self.prefix(2), 16))
+ self.forward(2)
+ try:
+ value = bytes(codes).decode('utf-8')
+ except UnicodeDecodeError as exc:
+ raise ScannerError("while scanning a %s" % name, start_mark, str(exc), mark)
+ return value
+
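+    # Illustrative note (not part of the original module): a run of %-escapes
+    # decodes as UTF-8, so '%E2%82%AC' in a tag URI yields the single
+    # character '\u20ac' (EURO SIGN).
+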
+ def scan_line_break(self):
+ # Transforms:
+ # '\r\n' : '\n'
+ # '\r' : '\n'
+ # '\n' : '\n'
+ # '\x85' : '\n'
+ # '\u2028' : '\u2028'
+        #   '\u2029'    :   '\u2029'
+ # default : ''
+ ch = self.peek()
+ if ch in '\r\n\x85':
+ if self.prefix(2) == '\r\n':
+ self.forward(2)
+ else:
+ self.forward()
+ return '\n'
+ elif ch in '\u2028\u2029':
+ self.forward()
+ return ch
+ return ''
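+
+    # Illustrative example (not part of the original module): because '\r\n',
+    # '\r', '\n' and '\x85' all normalize to '\n', an escaped line break in a
+    # double-quoted scalar folds away regardless of platform:
+    #     >>> yaml.safe_load('a: "one\\\ntwo"')
+    #     {'a': 'onetwo'}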
diff --git a/libs/yaml/serializer.py b/libs/yaml/serializer.py
new file mode 100644
index 000000000..fe911e67a
--- /dev/null
+++ b/libs/yaml/serializer.py
@@ -0,0 +1,111 @@
+
+__all__ = ['Serializer', 'SerializerError']
+
+from .error import YAMLError
+from .events import *
+from .nodes import *
+
+class SerializerError(YAMLError):
+ pass
+
+class Serializer:
+
+ ANCHOR_TEMPLATE = 'id%03d'
+
+ def __init__(self, encoding=None,
+ explicit_start=None, explicit_end=None, version=None, tags=None):
+ self.use_encoding = encoding
+ self.use_explicit_start = explicit_start
+ self.use_explicit_end = explicit_end
+ self.use_version = version
+ self.use_tags = tags
+ self.serialized_nodes = {}
+ self.anchors = {}
+ self.last_anchor_id = 0
+ self.closed = None
+
+ def open(self):
+ if self.closed is None:
+ self.emit(StreamStartEvent(encoding=self.use_encoding))
+ self.closed = False
+ elif self.closed:
+ raise SerializerError("serializer is closed")
+ else:
+ raise SerializerError("serializer is already opened")
+
+ def close(self):
+ if self.closed is None:
+ raise SerializerError("serializer is not opened")
+ elif not self.closed:
+ self.emit(StreamEndEvent())
+ self.closed = True
+
+ #def __del__(self):
+ # self.close()
+
+ def serialize(self, node):
+ if self.closed is None:
+ raise SerializerError("serializer is not opened")
+ elif self.closed:
+ raise SerializerError("serializer is closed")
+ self.emit(DocumentStartEvent(explicit=self.use_explicit_start,
+ version=self.use_version, tags=self.use_tags))
+ self.anchor_node(node)
+ self.serialize_node(node, None, None)
+ self.emit(DocumentEndEvent(explicit=self.use_explicit_end))
+ self.serialized_nodes = {}
+ self.anchors = {}
+ self.last_anchor_id = 0
+
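+    # Illustrative usage (not part of the original module): the open /
+    # serialize / close protocol above is what the yaml.serialize() shortcut
+    # drives, e.g.:
+    #     >>> import yaml
+    #     >>> yaml.serialize(yaml.compose('a: 1'))
+    #     'a: 1\n'
+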
+ def anchor_node(self, node):
+ if node in self.anchors:
+ if self.anchors[node] is None:
+ self.anchors[node] = self.generate_anchor(node)
+ else:
+ self.anchors[node] = None
+ if isinstance(node, SequenceNode):
+ for item in node.value:
+ self.anchor_node(item)
+ elif isinstance(node, MappingNode):
+ for key, value in node.value:
+ self.anchor_node(key)
+ self.anchor_node(value)
+
+ def generate_anchor(self, node):
+ self.last_anchor_id += 1
+ return self.ANCHOR_TEMPLATE % self.last_anchor_id
+
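+    # Illustrative example (not part of the original module): a node that is
+    # reached twice gets an anchor from the template above, so dumping an
+    # aliased list yields &id001 / *id001:
+    #     >>> x = [1]
+    #     >>> yaml.dump({'a': x, 'b': x})
+    #     'a: &id001\n- 1\nb: *id001\n'
+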
+ def serialize_node(self, node, parent, index):
+ alias = self.anchors[node]
+ if node in self.serialized_nodes:
+ self.emit(AliasEvent(alias))
+ else:
+ self.serialized_nodes[node] = True
+ self.descend_resolver(parent, index)
+ if isinstance(node, ScalarNode):
+ detected_tag = self.resolve(ScalarNode, node.value, (True, False))
+ default_tag = self.resolve(ScalarNode, node.value, (False, True))
+ implicit = (node.tag == detected_tag), (node.tag == default_tag)
+ self.emit(ScalarEvent(alias, node.tag, implicit, node.value,
+ style=node.style))
+ elif isinstance(node, SequenceNode):
+ implicit = (node.tag
+ == self.resolve(SequenceNode, node.value, True))
+ self.emit(SequenceStartEvent(alias, node.tag, implicit,
+ flow_style=node.flow_style))
+ index = 0
+ for item in node.value:
+ self.serialize_node(item, node, index)
+ index += 1
+ self.emit(SequenceEndEvent())
+ elif isinstance(node, MappingNode):
+ implicit = (node.tag
+ == self.resolve(MappingNode, node.value, True))
+ self.emit(MappingStartEvent(alias, node.tag, implicit,
+ flow_style=node.flow_style))
+ for key, value in node.value:
+ self.serialize_node(key, node, None)
+ self.serialize_node(value, node, key)
+ self.emit(MappingEndEvent())
+ self.ascend_resolver()
+
diff --git a/libs/yaml/tokens.py b/libs/yaml/tokens.py
new file mode 100644
index 000000000..4d0b48a39
--- /dev/null
+++ b/libs/yaml/tokens.py
@@ -0,0 +1,104 @@
+
+class Token:
+ def __init__(self, start_mark, end_mark):
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ def __repr__(self):
+ attributes = [key for key in self.__dict__
+ if not key.endswith('_mark')]
+ attributes.sort()
+ arguments = ', '.join(['%s=%r' % (key, getattr(self, key))
+ for key in attributes])
+ return '%s(%s)' % (self.__class__.__name__, arguments)
+
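+# Illustrative example (not part of the original module): __repr__ above
+# skips the *_mark attributes and sorts the rest:
+#     >>> ScalarToken('abc', True, None, None)
+#     ScalarToken(plain=True, style=None, value='abc')
+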
+#class BOMToken(Token):
+# id = '<byte order mark>'
+
+class DirectiveToken(Token):
+ id = '<directive>'
+ def __init__(self, name, value, start_mark, end_mark):
+ self.name = name
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+
+class DocumentStartToken(Token):
+ id = '<document start>'
+
+class DocumentEndToken(Token):
+ id = '<document end>'
+
+class StreamStartToken(Token):
+ id = '<stream start>'
+ def __init__(self, start_mark=None, end_mark=None,
+ encoding=None):
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.encoding = encoding
+
+class StreamEndToken(Token):
+ id = '<stream end>'
+
+class BlockSequenceStartToken(Token):
+ id = '<block sequence start>'
+
+class BlockMappingStartToken(Token):
+ id = '<block mapping start>'
+
+class BlockEndToken(Token):
+ id = '<block end>'
+
+class FlowSequenceStartToken(Token):
+ id = '['
+
+class FlowMappingStartToken(Token):
+ id = '{'
+
+class FlowSequenceEndToken(Token):
+ id = ']'
+
+class FlowMappingEndToken(Token):
+ id = '}'
+
+class KeyToken(Token):
+ id = '?'
+
+class ValueToken(Token):
+ id = ':'
+
+class BlockEntryToken(Token):
+ id = '-'
+
+class FlowEntryToken(Token):
+ id = ','
+
+class AliasToken(Token):
+ id = '<alias>'
+ def __init__(self, value, start_mark, end_mark):
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+
+class AnchorToken(Token):
+ id = '<anchor>'
+ def __init__(self, value, start_mark, end_mark):
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+
+class TagToken(Token):
+ id = '<tag>'
+ def __init__(self, value, start_mark, end_mark):
+ self.value = value
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+
+class ScalarToken(Token):
+ id = '<scalar>'
+ def __init__(self, value, plain, start_mark, end_mark, style=None):
+ self.value = value
+ self.plain = plain
+ self.start_mark = start_mark
+ self.end_mark = end_mark
+ self.style = style
+