diff options
author | Louis Vézina <[email protected]> | 2020-01-29 20:07:26 -0500 |
---|---|---|
committer | Louis Vézina <[email protected]> | 2020-01-29 20:07:26 -0500 |
commit | 83c95cc77dfd5ed18b439b1635f95bac129d0ce2 (patch) | |
tree | ea557727572cf3479a0af6d11434d0b486132eb8 | |
parent | 95b8aadb239bdce8d7f7a03e4ab995e56bf4e820 (diff) | |
download | bazarr-83c95cc77dfd5ed18b439b1635f95bac129d0ce2.tar.gz bazarr-83c95cc77dfd5ed18b439b1635f95bac129d0ce2.zip |
WIP
81 files changed, 24572 insertions, 0 deletions
diff --git a/libs/bs4/__init__.py b/libs/bs4/__init__.py new file mode 100644 index 000000000..95ca229c1 --- /dev/null +++ b/libs/bs4/__init__.py @@ -0,0 +1,616 @@ +"""Beautiful Soup +Elixir and Tonic +"The Screen-Scraper's Friend" +http://www.crummy.com/software/BeautifulSoup/ + +Beautiful Soup uses a pluggable XML or HTML parser to parse a +(possibly invalid) document into a tree representation. Beautiful Soup +provides methods and Pythonic idioms that make it easy to navigate, +search, and modify the parse tree. + +Beautiful Soup works with Python 2.7 and up. It works better if lxml +and/or html5lib is installed. + +For more than you ever wanted to know about Beautiful Soup, see the +documentation: +http://www.crummy.com/software/BeautifulSoup/bs4/doc/ + +""" + +__author__ = "Leonard Richardson ([email protected])" +__version__ = "4.8.0" +__copyright__ = "Copyright (c) 2004-2019 Leonard Richardson" +# Use of this source code is governed by the MIT license. +__license__ = "MIT" + +__all__ = ['BeautifulSoup'] + +import os +import re +import sys +import traceback +import warnings + +from .builder import builder_registry, ParserRejectedMarkup +from .dammit import UnicodeDammit +from .element import ( + CData, + Comment, + DEFAULT_OUTPUT_ENCODING, + Declaration, + Doctype, + NavigableString, + PageElement, + ProcessingInstruction, + ResultSet, + SoupStrainer, + Tag, + ) + +# The very first thing we do is give a useful error if someone is +# running this code under Python 3 without converting it. +'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'!='You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).' + +class BeautifulSoup(Tag): + """ + This class defines the basic interface called by the tree builders. 
+ + These methods will be called by the parser: + reset() + feed(markup) + + The tree builder may call these methods from its feed() implementation: + handle_starttag(name, attrs) # See note about return value + handle_endtag(name) + handle_data(data) # Appends to the current data node + endData(containerClass=NavigableString) # Ends the current data node + + No matter how complicated the underlying parser is, you should be + able to build a tree using 'start tag' events, 'end tag' events, + 'data' events, and "done with data" events. + + If you encounter an empty-element tag (aka a self-closing tag, + like HTML's <br> tag), call handle_starttag and then + handle_endtag. + """ + ROOT_TAG_NAME = '[document]' + + # If the end-user gives no indication which tree builder they + # want, look for one with these features. + DEFAULT_BUILDER_FEATURES = ['html', 'fast'] + + ASCII_SPACES = '\x20\x0a\x09\x0c\x0d' + + NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nThe code that caused this warning is on line %(line_number)s of the file %(filename)s. To get rid of this warning, pass the additional argument 'features=\"%(parser)s\"' to the BeautifulSoup constructor.\n" + + def __init__(self, markup="", features=None, builder=None, + parse_only=None, from_encoding=None, exclude_encodings=None, + **kwargs): + """Constructor. + + :param markup: A string or a file-like object representing + markup to be parsed. + + :param features: Desirable features of the parser to be used. This + may be the name of a specific parser ("lxml", "lxml-xml", + "html.parser", or "html5lib") or it may be the type of markup + to be used ("html", "html5", "xml"). 
It's recommended that you + name a specific parser, so that Beautiful Soup gives you the + same results across platforms and virtual environments. + + :param builder: A TreeBuilder subclass to instantiate (or + instance to use) instead of looking one up based on + `features`. You only need to use this if you've implemented a + custom TreeBuilder. + + :param parse_only: A SoupStrainer. Only parts of the document + matching the SoupStrainer will be considered. This is useful + when parsing part of a document that would otherwise be too + large to fit into memory. + + :param from_encoding: A string indicating the encoding of the + document to be parsed. Pass this in if Beautiful Soup is + guessing wrongly about the document's encoding. + + :param exclude_encodings: A list of strings indicating + encodings known to be wrong. Pass this in if you don't know + the document's encoding but you know Beautiful Soup's guess is + wrong. + + :param kwargs: For backwards compatibility purposes, the + constructor accepts certain keyword arguments used in + Beautiful Soup 3. None of these arguments do anything in + Beautiful Soup 4; they will result in a warning and then be ignored. + + Apart from this, any keyword arguments passed into the BeautifulSoup + constructor are propagated to the TreeBuilder constructor. This + makes it possible to configure a TreeBuilder beyond saying + which one to use. + + """ + + if 'convertEntities' in kwargs: + del kwargs['convertEntities'] + warnings.warn( + "BS4 does not respect the convertEntities argument to the " + "BeautifulSoup constructor. Entities are always converted " + "to Unicode characters.") + + if 'markupMassage' in kwargs: + del kwargs['markupMassage'] + warnings.warn( + "BS4 does not respect the markupMassage argument to the " + "BeautifulSoup constructor. 
The tree builder is responsible " + "for any necessary markup massage.") + + if 'smartQuotesTo' in kwargs: + del kwargs['smartQuotesTo'] + warnings.warn( + "BS4 does not respect the smartQuotesTo argument to the " + "BeautifulSoup constructor. Smart quotes are always converted " + "to Unicode characters.") + + if 'selfClosingTags' in kwargs: + del kwargs['selfClosingTags'] + warnings.warn( + "BS4 does not respect the selfClosingTags argument to the " + "BeautifulSoup constructor. The tree builder is responsible " + "for understanding self-closing tags.") + + if 'isHTML' in kwargs: + del kwargs['isHTML'] + warnings.warn( + "BS4 does not respect the isHTML argument to the " + "BeautifulSoup constructor. Suggest you use " + "features='lxml' for HTML and features='lxml-xml' for " + "XML.") + + def deprecated_argument(old_name, new_name): + if old_name in kwargs: + warnings.warn( + 'The "%s" argument to the BeautifulSoup constructor ' + 'has been renamed to "%s."' % (old_name, new_name)) + value = kwargs[old_name] + del kwargs[old_name] + return value + return None + + parse_only = parse_only or deprecated_argument( + "parseOnlyThese", "parse_only") + + from_encoding = from_encoding or deprecated_argument( + "fromEncoding", "from_encoding") + + if from_encoding and isinstance(markup, str): + warnings.warn("You provided Unicode markup but also provided a value for from_encoding. Your from_encoding will be ignored.") + from_encoding = None + + # We need this information to track whether or not the builder + # was specified well enough that we can omit the 'you need to + # specify a parser' warning. + original_builder = builder + original_features = features + + if isinstance(builder, type): + # A builder class was passed in; it needs to be instantiated. 
+ builder_class = builder + builder = None + elif builder is None: + if isinstance(features, str): + features = [features] + if features is None or len(features) == 0: + features = self.DEFAULT_BUILDER_FEATURES + builder_class = builder_registry.lookup(*features) + if builder_class is None: + raise FeatureNotFound( + "Couldn't find a tree builder with the features you " + "requested: %s. Do you need to install a parser library?" + % ",".join(features)) + + # At this point either we have a TreeBuilder instance in + # builder, or we have a builder_class that we can instantiate + # with the remaining **kwargs. + if builder is None: + builder = builder_class(**kwargs) + if not original_builder and not ( + original_features == builder.NAME or + original_features in builder.ALTERNATE_NAMES + ): + if builder.is_xml: + markup_type = "XML" + else: + markup_type = "HTML" + + # This code adapted from warnings.py so that we get the same line + # of code as our warnings.warn() call gets, even if the answer is wrong + # (as it may be in a multithreading situation). + caller = None + try: + caller = sys._getframe(1) + except ValueError: + pass + if caller: + globals = caller.f_globals + line_number = caller.f_lineno + else: + globals = sys.__dict__ + line_number= 1 + filename = globals.get('__file__') + if filename: + fnl = filename.lower() + if fnl.endswith((".pyc", ".pyo")): + filename = filename[:-1] + if filename: + # If there is no filename at all, the user is most likely in a REPL, + # and the warning is not necessary. + values = dict( + filename=filename, + line_number=line_number, + parser=builder.NAME, + markup_type=markup_type + ) + warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % values, stacklevel=2) + else: + if kwargs: + warnings.warn("Keyword arguments to the BeautifulSoup constructor will be ignored. 
These would normally be passed into the TreeBuilder constructor, but a TreeBuilder instance was passed in as `builder`.") + + self.builder = builder + self.is_xml = builder.is_xml + self.known_xml = self.is_xml + self._namespaces = dict() + self.parse_only = parse_only + + self.builder.initialize_soup(self) + + if hasattr(markup, 'read'): # It's a file-type object. + markup = markup.read() + elif len(markup) <= 256 and ( + (isinstance(markup, bytes) and not b'<' in markup) + or (isinstance(markup, str) and not '<' in markup) + ): + # Print out warnings for a couple beginner problems + # involving passing non-markup to Beautiful Soup. + # Beautiful Soup will still parse the input as markup, + # just in case that's what the user really wants. + if (isinstance(markup, str) + and not os.path.supports_unicode_filenames): + possible_filename = markup.encode("utf8") + else: + possible_filename = markup + is_file = False + try: + is_file = os.path.exists(possible_filename) + except Exception as e: + # This is almost certainly a problem involving + # characters not valid in filenames on this + # system. Just let it go. + pass + if is_file: + if isinstance(markup, str): + markup = markup.encode("utf8") + warnings.warn( + '"%s" looks like a filename, not markup. You should' + ' probably open this file and pass the filehandle into' + ' Beautiful Soup.' % markup) + self._check_markup_is_url(markup) + + for (self.markup, self.original_encoding, self.declared_html_encoding, + self.contains_replacement_characters) in ( + self.builder.prepare_markup( + markup, from_encoding, exclude_encodings=exclude_encodings)): + self.reset() + try: + self._feed() + break + except ParserRejectedMarkup: + pass + + # Clear out the markup and remove the builder's circular + # reference to this object. 
+ self.markup = None + self.builder.soup = None + + def __copy__(self): + copy = type(self)( + self.encode('utf-8'), builder=self.builder, from_encoding='utf-8' + ) + + # Although we encoded the tree to UTF-8, that may not have + # been the encoding of the original markup. Set the copy's + # .original_encoding to reflect the original object's + # .original_encoding. + copy.original_encoding = self.original_encoding + return copy + + def __getstate__(self): + # Frequently a tree builder can't be pickled. + d = dict(self.__dict__) + if 'builder' in d and not self.builder.picklable: + d['builder'] = None + return d + + @staticmethod + def _check_markup_is_url(markup): + """ + Check if markup looks like it's actually a url and raise a warning + if so. Markup can be unicode or str (py2) / bytes (py3). + """ + if isinstance(markup, bytes): + space = b' ' + cant_start_with = (b"http:", b"https:") + elif isinstance(markup, str): + space = ' ' + cant_start_with = ("http:", "https:") + else: + return + + if any(markup.startswith(prefix) for prefix in cant_start_with): + if not space in markup: + if isinstance(markup, bytes): + decoded_markup = markup.decode('utf-8', 'replace') + else: + decoded_markup = markup + warnings.warn( + '"%s" looks like a URL. Beautiful Soup is not an' + ' HTTP client. You should probably use an HTTP client like' + ' requests to get the document behind the URL, and feed' + ' that document to Beautiful Soup.' % decoded_markup + ) + + def _feed(self): + # Convert the document to Unicode. + self.builder.reset() + + self.builder.feed(self.markup) + # Close out any unfinished strings and close all the open tags. 
+ self.endData() + while self.currentTag.name != self.ROOT_TAG_NAME: + self.popTag() + + def reset(self): + Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME) + self.hidden = 1 + self.builder.reset() + self.current_data = [] + self.currentTag = None + self.tagStack = [] + self.preserve_whitespace_tag_stack = [] + self.pushTag(self) + + def new_tag(self, name, namespace=None, nsprefix=None, attrs={}, **kwattrs): + """Create a new tag associated with this soup.""" + kwattrs.update(attrs) + return Tag(None, self.builder, name, namespace, nsprefix, kwattrs) + + def new_string(self, s, subclass=NavigableString): + """Create a new NavigableString associated with this soup.""" + return subclass(s) + + def insert_before(self, successor): + raise NotImplementedError("BeautifulSoup objects don't support insert_before().") + + def insert_after(self, successor): + raise NotImplementedError("BeautifulSoup objects don't support insert_after().") + + def popTag(self): + tag = self.tagStack.pop() + if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]: + self.preserve_whitespace_tag_stack.pop() + #print "Pop", tag.name + if self.tagStack: + self.currentTag = self.tagStack[-1] + return self.currentTag + + def pushTag(self, tag): + #print "Push", tag.name + if self.currentTag is not None: + self.currentTag.contents.append(tag) + self.tagStack.append(tag) + self.currentTag = self.tagStack[-1] + if tag.name in self.builder.preserve_whitespace_tags: + self.preserve_whitespace_tag_stack.append(tag) + + def endData(self, containerClass=NavigableString): + if self.current_data: + current_data = ''.join(self.current_data) + # If whitespace is not preserved, and this string contains + # nothing but ASCII spaces, replace it with a single space + # or newline. 
+ if not self.preserve_whitespace_tag_stack: + strippable = True + for i in current_data: + if i not in self.ASCII_SPACES: + strippable = False + break + if strippable: + if '\n' in current_data: + current_data = '\n' + else: + current_data = ' ' + + # Reset the data collector. + self.current_data = [] + + # Should we add this string to the tree at all? + if self.parse_only and len(self.tagStack) <= 1 and \ + (not self.parse_only.text or \ + not self.parse_only.search(current_data)): + return + + o = containerClass(current_data) + self.object_was_parsed(o) + + def object_was_parsed(self, o, parent=None, most_recent_element=None): + """Add an object to the parse tree.""" + if parent is None: + parent = self.currentTag + if most_recent_element is not None: + previous_element = most_recent_element + else: + previous_element = self._most_recent_element + + next_element = previous_sibling = next_sibling = None + if isinstance(o, Tag): + next_element = o.next_element + next_sibling = o.next_sibling + previous_sibling = o.previous_sibling + if previous_element is None: + previous_element = o.previous_element + + fix = parent.next_element is not None + + o.setup(parent, previous_element, next_element, previous_sibling, next_sibling) + + self._most_recent_element = o + parent.contents.append(o) + + # Check if we are inserting into an already parsed node. + if fix: + self._linkage_fixer(parent) + + def _linkage_fixer(self, el): + """Make sure linkage of this fragment is sound.""" + + first = el.contents[0] + child = el.contents[-1] + descendant = child + + if child is first and el.parent is not None: + # Parent should be linked to first child + el.next_element = child + # We are no longer linked to whatever this element is + prev_el = child.previous_element + if prev_el is not None and prev_el is not el: + prev_el.next_element = None + # First child should be linked to the parent, and no previous siblings. 
+ child.previous_element = el + child.previous_sibling = None + + # We have no sibling as we've been appended as the last. + child.next_sibling = None + + # This index is a tag, dig deeper for a "last descendant" + if isinstance(child, Tag) and child.contents: + descendant = child._last_descendant(False) + + # As the final step, link last descendant. It should be linked + # to the parent's next sibling (if found), else walk up the chain + # and find a parent with a sibling. It should have no next sibling. + descendant.next_element = None + descendant.next_sibling = None + target = el + while True: + if target is None: + break + elif target.next_sibling is not None: + descendant.next_element = target.next_sibling + target.next_sibling.previous_element = child + break + target = target.parent + + def _popToTag(self, name, nsprefix=None, inclusivePop=True): + """Pops the tag stack up to and including the most recent + instance of the given tag. If inclusivePop is false, pops the tag + stack up to but *not* including the most recent instqance of + the given tag.""" + #print "Popping to %s" % name + if name == self.ROOT_TAG_NAME: + # The BeautifulSoup object itself can never be popped. + return + + most_recently_popped = None + + stack_size = len(self.tagStack) + for i in range(stack_size - 1, 0, -1): + t = self.tagStack[i] + if (name == t.name and nsprefix == t.prefix): + if inclusivePop: + most_recently_popped = self.popTag() + break + most_recently_popped = self.popTag() + + return most_recently_popped + + def handle_starttag(self, name, namespace, nsprefix, attrs): + """Push a start tag on to the stack. + + If this method returns None, the tag was rejected by the + SoupStrainer. You should proceed as if the tag had not occurred + in the document. For instance, if this was a self-closing tag, + don't call handle_endtag. 
+ """ + + # print "Start tag %s: %s" % (name, attrs) + self.endData() + + if (self.parse_only and len(self.tagStack) <= 1 + and (self.parse_only.text + or not self.parse_only.search_tag(name, attrs))): + return None + + tag = Tag(self, self.builder, name, namespace, nsprefix, attrs, + self.currentTag, self._most_recent_element) + if tag is None: + return tag + if self._most_recent_element is not None: + self._most_recent_element.next_element = tag + self._most_recent_element = tag + self.pushTag(tag) + return tag + + def handle_endtag(self, name, nsprefix=None): + #print "End tag: " + name + self.endData() + self._popToTag(name, nsprefix) + + def handle_data(self, data): + self.current_data.append(data) + + def decode(self, pretty_print=False, + eventual_encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Returns a string or Unicode representation of this document. + To get Unicode, pass None for encoding.""" + + if self.is_xml: + # Print the XML declaration + encoding_part = '' + if eventual_encoding != None: + encoding_part = ' encoding="%s"' % eventual_encoding + prefix = '<?xml version="1.0"%s?>\n' % encoding_part + else: + prefix = '' + if not pretty_print: + indent_level = None + else: + indent_level = 0 + return prefix + super(BeautifulSoup, self).decode( + indent_level, eventual_encoding, formatter) + +# Alias to make it easier to type import: 'from bs4 import _soup' +_s = BeautifulSoup +_soup = BeautifulSoup + +class BeautifulStoneSoup(BeautifulSoup): + """Deprecated interface to an XML parser.""" + + def __init__(self, *args, **kwargs): + kwargs['features'] = 'xml' + warnings.warn( + 'The BeautifulStoneSoup class is deprecated. Instead of using ' + 'it, pass features="xml" into the BeautifulSoup constructor.') + super(BeautifulStoneSoup, self).__init__(*args, **kwargs) + + +class StopParsing(Exception): + pass + +class FeatureNotFound(ValueError): + pass + + +#By default, act as an HTML pretty-printer. 
+if __name__ == '__main__': + import sys + soup = BeautifulSoup(sys.stdin) + print(soup.prettify()) diff --git a/libs/bs4/builder/__init__.py b/libs/bs4/builder/__init__.py new file mode 100644 index 000000000..cc497cf0b --- /dev/null +++ b/libs/bs4/builder/__init__.py @@ -0,0 +1,367 @@ +# Use of this source code is governed by the MIT license. +__license__ = "MIT" + +from collections import defaultdict +import itertools +import sys +from bs4.element import ( + CharsetMetaAttributeValue, + ContentMetaAttributeValue, + nonwhitespace_re + ) + +__all__ = [ + 'HTMLTreeBuilder', + 'SAXTreeBuilder', + 'TreeBuilder', + 'TreeBuilderRegistry', + ] + +# Some useful features for a TreeBuilder to have. +FAST = 'fast' +PERMISSIVE = 'permissive' +STRICT = 'strict' +XML = 'xml' +HTML = 'html' +HTML_5 = 'html5' + + +class TreeBuilderRegistry(object): + + def __init__(self): + self.builders_for_feature = defaultdict(list) + self.builders = [] + + def register(self, treebuilder_class): + """Register a treebuilder based on its advertised features.""" + for feature in treebuilder_class.features: + self.builders_for_feature[feature].insert(0, treebuilder_class) + self.builders.insert(0, treebuilder_class) + + def lookup(self, *features): + if len(self.builders) == 0: + # There are no builders at all. + return None + + if len(features) == 0: + # They didn't ask for any features. Give them the most + # recently registered builder. + return self.builders[0] + + # Go down the list of features in order, and eliminate any builders + # that don't match every feature. + features = list(features) + features.reverse() + candidates = None + candidate_set = None + while len(features) > 0: + feature = features.pop() + we_have_the_feature = self.builders_for_feature.get(feature, []) + if len(we_have_the_feature) > 0: + if candidates is None: + candidates = we_have_the_feature + candidate_set = set(candidates) + else: + # Eliminate any candidates that don't have this feature. 
+ candidate_set = candidate_set.intersection( + set(we_have_the_feature)) + + # The only valid candidates are the ones in candidate_set. + # Go through the original list of candidates and pick the first one + # that's in candidate_set. + if candidate_set is None: + return None + for candidate in candidates: + if candidate in candidate_set: + return candidate + return None + +# The BeautifulSoup class will take feature lists from developers and use them +# to look up builders in this registry. +builder_registry = TreeBuilderRegistry() + +class TreeBuilder(object): + """Turn a document into a Beautiful Soup object tree.""" + + NAME = "[Unknown tree builder]" + ALTERNATE_NAMES = [] + features = [] + + is_xml = False + picklable = False + empty_element_tags = None # A tag will be considered an empty-element + # tag when and only when it has no contents. + + # A value for these tag/attribute combinations is a space- or + # comma-separated list of CDATA, rather than a single CDATA. + DEFAULT_CDATA_LIST_ATTRIBUTES = {} + + DEFAULT_PRESERVE_WHITESPACE_TAGS = set() + + USE_DEFAULT = object() + + def __init__(self, multi_valued_attributes=USE_DEFAULT, preserve_whitespace_tags=USE_DEFAULT): + """Constructor. + + :param multi_valued_attributes: If this is set to None, the + TreeBuilder will not turn any values for attributes like + 'class' into lists. Setting this do a dictionary will + customize this behavior; look at DEFAULT_CDATA_LIST_ATTRIBUTES + for an example. + + Internally, these are called "CDATA list attributes", but that + probably doesn't make sense to an end-user, so the argument name + is `multi_valued_attributes`. 
+ + :param preserve_whitespace_tags: + """ + self.soup = None + if multi_valued_attributes is self.USE_DEFAULT: + multi_valued_attributes = self.DEFAULT_CDATA_LIST_ATTRIBUTES + self.cdata_list_attributes = multi_valued_attributes + if preserve_whitespace_tags is self.USE_DEFAULT: + preserve_whitespace_tags = self.DEFAULT_PRESERVE_WHITESPACE_TAGS + self.preserve_whitespace_tags = preserve_whitespace_tags + + def initialize_soup(self, soup): + """The BeautifulSoup object has been initialized and is now + being associated with the TreeBuilder. + """ + self.soup = soup + + def reset(self): + pass + + def can_be_empty_element(self, tag_name): + """Might a tag with this name be an empty-element tag? + + The final markup may or may not actually present this tag as + self-closing. + + For instance: an HTMLBuilder does not consider a <p> tag to be + an empty-element tag (it's not in + HTMLBuilder.empty_element_tags). This means an empty <p> tag + will be presented as "<p></p>", not "<p />". + + The default implementation has no opinion about which tags are + empty-element tags, so a tag will be presented as an + empty-element tag if and only if it has no contents. + "<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will + be left alone. + """ + if self.empty_element_tags is None: + return True + return tag_name in self.empty_element_tags + + def feed(self, markup): + raise NotImplementedError() + + def prepare_markup(self, markup, user_specified_encoding=None, + document_declared_encoding=None): + return markup, None, None, False + + def test_fragment_to_document(self, fragment): + """Wrap an HTML fragment to make it look like a document. + + Different parsers do this differently. For instance, lxml + introduces an empty <head> tag, and html5lib + doesn't. Abstracting this away lets us write simple tests + which run HTML fragments through the parser and compare the + results against other HTML fragments. + + This method should not be used outside of tests. 
+ """ + return fragment + + def set_up_substitutions(self, tag): + return False + + def _replace_cdata_list_attribute_values(self, tag_name, attrs): + """Replaces class="foo bar" with class=["foo", "bar"] + + Modifies its input in place. + """ + if not attrs: + return attrs + if self.cdata_list_attributes: + universal = self.cdata_list_attributes.get('*', []) + tag_specific = self.cdata_list_attributes.get( + tag_name.lower(), None) + for attr in list(attrs.keys()): + if attr in universal or (tag_specific and attr in tag_specific): + # We have a "class"-type attribute whose string + # value is a whitespace-separated list of + # values. Split it into a list. + value = attrs[attr] + if isinstance(value, str): + values = nonwhitespace_re.findall(value) + else: + # html5lib sometimes calls setAttributes twice + # for the same tag when rearranging the parse + # tree. On the second call the attribute value + # here is already a list. If this happens, + # leave the value alone rather than trying to + # split it again. + values = value + attrs[attr] = values + return attrs + +class SAXTreeBuilder(TreeBuilder): + """A Beautiful Soup treebuilder that listens for SAX events.""" + + def feed(self, markup): + raise NotImplementedError() + + def close(self): + pass + + def startElement(self, name, attrs): + attrs = dict((key[1], value) for key, value in list(attrs.items())) + #print "Start %s, %r" % (name, attrs) + self.soup.handle_starttag(name, attrs) + + def endElement(self, name): + #print "End %s" % name + self.soup.handle_endtag(name) + + def startElementNS(self, nsTuple, nodeName, attrs): + # Throw away (ns, nodeName) for now. + self.startElement(nodeName, attrs) + + def endElementNS(self, nsTuple, nodeName): + # Throw away (ns, nodeName) for now. + self.endElement(nodeName) + #handler.endElementNS((ns, node.nodeName), node.nodeName) + + def startPrefixMapping(self, prefix, nodeValue): + # Ignore the prefix for now. 
+ pass + + def endPrefixMapping(self, prefix): + # Ignore the prefix for now. + # handler.endPrefixMapping(prefix) + pass + + def characters(self, content): + self.soup.handle_data(content) + + def startDocument(self): + pass + + def endDocument(self): + pass + + +class HTMLTreeBuilder(TreeBuilder): + """This TreeBuilder knows facts about HTML. + + Such as which tags are empty-element tags. + """ + + empty_element_tags = set([ + # These are from HTML5. + 'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr', + + # These are from earlier versions of HTML and are removed in HTML5. + 'basefont', 'bgsound', 'command', 'frame', 'image', 'isindex', 'nextid', 'spacer' + ]) + + # The HTML standard defines these as block-level elements. Beautiful + # Soup does not treat these elements differently from other elements, + # but it may do so eventually, and this information is available if + # you need to use it. + block_elements = set(["address", "article", "aside", "blockquote", "canvas", "dd", "div", "dl", "dt", "fieldset", "figcaption", "figure", "footer", "form", "h1", "h2", "h3", "h4", "h5", "h6", "header", "hr", "li", "main", "nav", "noscript", "ol", "output", "p", "pre", "section", "table", "tfoot", "ul", "video"]) + + # The HTML standard defines these attributes as containing a + # space-separated list of values, not a single value. That is, + # class="foo bar" means that the 'class' attribute has two values, + # 'foo' and 'bar', not the single value 'foo bar'. When we + # encounter one of these attributes, we will parse its value into + # a list of values if possible. Upon output, the list will be + # converted back into a string. 
+ DEFAULT_CDATA_LIST_ATTRIBUTES = { + "*" : ['class', 'accesskey', 'dropzone'], + "a" : ['rel', 'rev'], + "link" : ['rel', 'rev'], + "td" : ["headers"], + "th" : ["headers"], + "td" : ["headers"], + "form" : ["accept-charset"], + "object" : ["archive"], + + # These are HTML5 specific, as are *.accesskey and *.dropzone above. + "area" : ["rel"], + "icon" : ["sizes"], + "iframe" : ["sandbox"], + "output" : ["for"], + } + + DEFAULT_PRESERVE_WHITESPACE_TAGS = set(['pre', 'textarea']) + + def set_up_substitutions(self, tag): + # We are only interested in <meta> tags + if tag.name != 'meta': + return False + + http_equiv = tag.get('http-equiv') + content = tag.get('content') + charset = tag.get('charset') + + # We are interested in <meta> tags that say what encoding the + # document was originally in. This means HTML 5-style <meta> + # tags that provide the "charset" attribute. It also means + # HTML 4-style <meta> tags that provide the "content" + # attribute and have "http-equiv" set to "content-type". + # + # In both cases we will replace the value of the appropriate + # attribute with a standin object that can take on any + # encoding. + meta_encoding = None + if charset is not None: + # HTML 5 style: + # <meta charset="utf8"> + meta_encoding = charset + tag['charset'] = CharsetMetaAttributeValue(charset) + + elif (content is not None and http_equiv is not None + and http_equiv.lower() == 'content-type'): + # HTML 4 style: + # <meta http-equiv="content-type" content="text/html; charset=utf8"> + tag['content'] = ContentMetaAttributeValue(content) + + return (meta_encoding is not None) + +def register_treebuilders_from(module): + """Copy TreeBuilders from the given module into this module.""" + # I'm fairly sure this is not the best way to do this. 
+ this_module = sys.modules['bs4.builder'] + for name in module.__all__: + obj = getattr(module, name) + + if issubclass(obj, TreeBuilder): + setattr(this_module, name, obj) + this_module.__all__.append(name) + # Register the builder while we're at it. + this_module.builder_registry.register(obj) + +class ParserRejectedMarkup(Exception): + pass + +# Builders are registered in reverse order of priority, so that custom +# builder registrations will take precedence. In general, we want lxml +# to take precedence over html5lib, because it's faster. And we only +# want to use HTMLParser as a last result. +from . import _htmlparser +register_treebuilders_from(_htmlparser) +try: + from . import _html5lib + register_treebuilders_from(_html5lib) +except ImportError: + # They don't have html5lib installed. + pass +try: + from . import _lxml + register_treebuilders_from(_lxml) +except ImportError: + # They don't have lxml installed. + pass diff --git a/libs/bs4/builder/_html5lib.py b/libs/bs4/builder/_html5lib.py new file mode 100644 index 000000000..090bb61a8 --- /dev/null +++ b/libs/bs4/builder/_html5lib.py @@ -0,0 +1,426 @@ +# Use of this source code is governed by the MIT license. 
+__license__ = "MIT" + +__all__ = [ + 'HTML5TreeBuilder', + ] + +import warnings +import re +from bs4.builder import ( + PERMISSIVE, + HTML, + HTML_5, + HTMLTreeBuilder, + ) +from bs4.element import ( + NamespacedAttribute, + nonwhitespace_re, +) +import html5lib +from html5lib.constants import ( + namespaces, + prefixes, + ) +from bs4.element import ( + Comment, + Doctype, + NavigableString, + Tag, + ) + +try: + # Pre-0.99999999 + from html5lib.treebuilders import _base as treebuilder_base + new_html5lib = False +except ImportError as e: + # 0.99999999 and up + from html5lib.treebuilders import base as treebuilder_base + new_html5lib = True + +class HTML5TreeBuilder(HTMLTreeBuilder): + """Use html5lib to build a tree.""" + + NAME = "html5lib" + + features = [NAME, PERMISSIVE, HTML_5, HTML] + + def prepare_markup(self, markup, user_specified_encoding, + document_declared_encoding=None, exclude_encodings=None): + # Store the user-specified encoding for use later on. + self.user_specified_encoding = user_specified_encoding + + # document_declared_encoding and exclude_encodings aren't used + # ATM because the html5lib TreeBuilder doesn't use + # UnicodeDammit. + if exclude_encodings: + warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.") + yield (markup, None, None, False) + + # These methods are defined by Beautiful Soup. + def feed(self, markup): + if self.soup.parse_only is not None: + warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. 
The entire document will be parsed.") + parser = html5lib.HTMLParser(tree=self.create_treebuilder) + + extra_kwargs = dict() + if not isinstance(markup, str): + if new_html5lib: + extra_kwargs['override_encoding'] = self.user_specified_encoding + else: + extra_kwargs['encoding'] = self.user_specified_encoding + doc = parser.parse(markup, **extra_kwargs) + + # Set the character encoding detected by the tokenizer. + if isinstance(markup, str): + # We need to special-case this because html5lib sets + # charEncoding to UTF-8 if it gets Unicode input. + doc.original_encoding = None + else: + original_encoding = parser.tokenizer.stream.charEncoding[0] + if not isinstance(original_encoding, str): + # In 0.99999999 and up, the encoding is an html5lib + # Encoding object. We want to use a string for compatibility + # with other tree builders. + original_encoding = original_encoding.name + doc.original_encoding = original_encoding + + def create_treebuilder(self, namespaceHTMLElements): + self.underlying_builder = TreeBuilderForHtml5lib( + namespaceHTMLElements, self.soup) + return self.underlying_builder + + def test_fragment_to_document(self, fragment): + """See `TreeBuilder`.""" + return '<html><head></head><body>%s</body></html>' % fragment + + +class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder): + + def __init__(self, namespaceHTMLElements, soup=None): + if soup: + self.soup = soup + else: + from bs4 import BeautifulSoup + self.soup = BeautifulSoup("", "html.parser") + super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements) + + def documentClass(self): + self.soup.reset() + return Element(self.soup, self.soup, None) + + def insertDoctype(self, token): + name = token["name"] + publicId = token["publicId"] + systemId = token["systemId"] + + doctype = Doctype.for_name_and_ids(name, publicId, systemId) + self.soup.object_was_parsed(doctype) + + def elementClass(self, name, namespace): + tag = self.soup.new_tag(name, namespace) + return Element(tag, 
self.soup, namespace) + + def commentClass(self, data): + return TextNode(Comment(data), self.soup) + + def fragmentClass(self): + from bs4 import BeautifulSoup + self.soup = BeautifulSoup("", "html.parser") + self.soup.name = "[document_fragment]" + return Element(self.soup, self.soup, None) + + def appendChild(self, node): + # XXX This code is not covered by the BS4 tests. + self.soup.append(node.element) + + def getDocument(self): + return self.soup + + def getFragment(self): + return treebuilder_base.TreeBuilder.getFragment(self).element + + def testSerializer(self, element): + from bs4 import BeautifulSoup + rv = [] + doctype_re = re.compile(r'^(.*?)(?: PUBLIC "(.*?)"(?: "(.*?)")?| SYSTEM "(.*?)")?$') + + def serializeElement(element, indent=0): + if isinstance(element, BeautifulSoup): + pass + if isinstance(element, Doctype): + m = doctype_re.match(element) + if m: + name = m.group(1) + if m.lastindex > 1: + publicId = m.group(2) or "" + systemId = m.group(3) or m.group(4) or "" + rv.append("""|%s<!DOCTYPE %s "%s" "%s">""" % + (' ' * indent, name, publicId, systemId)) + else: + rv.append("|%s<!DOCTYPE %s>" % (' ' * indent, name)) + else: + rv.append("|%s<!DOCTYPE >" % (' ' * indent,)) + elif isinstance(element, Comment): + rv.append("|%s<!-- %s -->" % (' ' * indent, element)) + elif isinstance(element, NavigableString): + rv.append("|%s\"%s\"" % (' ' * indent, element)) + else: + if element.namespace: + name = "%s %s" % (prefixes[element.namespace], + element.name) + else: + name = element.name + rv.append("|%s<%s>" % (' ' * indent, name)) + if element.attrs: + attributes = [] + for name, value in list(element.attrs.items()): + if isinstance(name, NamespacedAttribute): + name = "%s %s" % (prefixes[name.namespace], name.name) + if isinstance(value, list): + value = " ".join(value) + attributes.append((name, value)) + + for name, value in sorted(attributes): + rv.append('|%s%s="%s"' % (' ' * (indent + 2), name, value)) + indent += 2 + for child in 
element.children: + serializeElement(child, indent) + serializeElement(element, 0) + + return "\n".join(rv) + +class AttrList(object): + def __init__(self, element): + self.element = element + self.attrs = dict(self.element.attrs) + def __iter__(self): + return list(self.attrs.items()).__iter__() + def __setitem__(self, name, value): + # If this attribute is a multi-valued attribute for this element, + # turn its value into a list. + list_attr = self.element.cdata_list_attributes + if (name in list_attr['*'] + or (self.element.name in list_attr + and name in list_attr[self.element.name])): + # A node that is being cloned may have already undergone + # this procedure. + if not isinstance(value, list): + value = nonwhitespace_re.findall(value) + self.element[name] = value + def items(self): + return list(self.attrs.items()) + def keys(self): + return list(self.attrs.keys()) + def __len__(self): + return len(self.attrs) + def __getitem__(self, name): + return self.attrs[name] + def __contains__(self, name): + return name in list(self.attrs.keys()) + + +class Element(treebuilder_base.Node): + def __init__(self, element, soup, namespace): + treebuilder_base.Node.__init__(self, element.name) + self.element = element + self.soup = soup + self.namespace = namespace + + def appendChild(self, node): + string_child = child = None + if isinstance(node, str): + # Some other piece of code decided to pass in a string + # instead of creating a TextElement object to contain the + # string. + string_child = child = node + elif isinstance(node, Tag): + # Some other piece of code decided to pass in a Tag + # instead of creating an Element object to contain the + # Tag. 
+ child = node + elif node.element.__class__ == NavigableString: + string_child = child = node.element + node.parent = self + else: + child = node.element + node.parent = self + + if not isinstance(child, str) and child.parent is not None: + node.element.extract() + + if (string_child is not None and self.element.contents + and self.element.contents[-1].__class__ == NavigableString): + # We are appending a string onto another string. + # TODO This has O(n^2) performance, for input like + # "a</a>a</a>a</a>..." + old_element = self.element.contents[-1] + new_element = self.soup.new_string(old_element + string_child) + old_element.replace_with(new_element) + self.soup._most_recent_element = new_element + else: + if isinstance(node, str): + # Create a brand new NavigableString from this string. + child = self.soup.new_string(node) + + # Tell Beautiful Soup to act as if it parsed this element + # immediately after the parent's last descendant. (Or + # immediately after the parent, if it has no children.) + if self.element.contents: + most_recent_element = self.element._last_descendant(False) + elif self.element.next_element is not None: + # Something from further ahead in the parse tree is + # being inserted into this earlier element. This is + # very annoying because it means an expensive search + # for the last element in the tree. 
+ most_recent_element = self.soup._last_descendant() + else: + most_recent_element = self.element + + self.soup.object_was_parsed( + child, parent=self.element, + most_recent_element=most_recent_element) + + def getAttributes(self): + if isinstance(self.element, Comment): + return {} + return AttrList(self.element) + + def setAttributes(self, attributes): + + if attributes is not None and len(attributes) > 0: + + converted_attributes = [] + for name, value in list(attributes.items()): + if isinstance(name, tuple): + new_name = NamespacedAttribute(*name) + del attributes[name] + attributes[new_name] = value + + self.soup.builder._replace_cdata_list_attribute_values( + self.name, attributes) + for name, value in list(attributes.items()): + self.element[name] = value + + # The attributes may contain variables that need substitution. + # Call set_up_substitutions manually. + # + # The Tag constructor called this method when the Tag was created, + # but we just set/changed the attributes, so call it again. 
+ self.soup.builder.set_up_substitutions(self.element) + attributes = property(getAttributes, setAttributes) + + def insertText(self, data, insertBefore=None): + text = TextNode(self.soup.new_string(data), self.soup) + if insertBefore: + self.insertBefore(text, insertBefore) + else: + self.appendChild(text) + + def insertBefore(self, node, refNode): + index = self.element.index(refNode.element) + if (node.element.__class__ == NavigableString and self.element.contents + and self.element.contents[index-1].__class__ == NavigableString): + # (See comments in appendChild) + old_node = self.element.contents[index-1] + new_str = self.soup.new_string(old_node + node.element) + old_node.replace_with(new_str) + else: + self.element.insert(index, node.element) + node.parent = self + + def removeChild(self, node): + node.element.extract() + + def reparentChildren(self, new_parent): + """Move all of this tag's children into another tag.""" + # print "MOVE", self.element.contents + # print "FROM", self.element + # print "TO", new_parent.element + + element = self.element + new_parent_element = new_parent.element + # Determine what this tag's next_element will be once all the children + # are removed. + final_next_element = element.next_sibling + + new_parents_last_descendant = new_parent_element._last_descendant(False, False) + if len(new_parent_element.contents) > 0: + # The new parent already contains children. We will be + # appending this tag's children to the end. + new_parents_last_child = new_parent_element.contents[-1] + new_parents_last_descendant_next_element = new_parents_last_descendant.next_element + else: + # The new parent contains no children. 
+ new_parents_last_child = None + new_parents_last_descendant_next_element = new_parent_element.next_element + + to_append = element.contents + if len(to_append) > 0: + # Set the first child's previous_element and previous_sibling + # to elements within the new parent + first_child = to_append[0] + if new_parents_last_descendant is not None: + first_child.previous_element = new_parents_last_descendant + else: + first_child.previous_element = new_parent_element + first_child.previous_sibling = new_parents_last_child + if new_parents_last_descendant is not None: + new_parents_last_descendant.next_element = first_child + else: + new_parent_element.next_element = first_child + if new_parents_last_child is not None: + new_parents_last_child.next_sibling = first_child + + # Find the very last element being moved. It is now the + # parent's last descendant. It has no .next_sibling and + # its .next_element is whatever the previous last + # descendant had. + last_childs_last_descendant = to_append[-1]._last_descendant(False, True) + + last_childs_last_descendant.next_element = new_parents_last_descendant_next_element + if new_parents_last_descendant_next_element is not None: + # TODO: This code has no test coverage and I'm not sure + # how to get html5lib to go through this path, but it's + # just the other side of the previous line. + new_parents_last_descendant_next_element.previous_element = last_childs_last_descendant + last_childs_last_descendant.next_sibling = None + + for child in to_append: + child.parent = new_parent_element + new_parent_element.contents.append(child) + + # Now that this element has no children, change its .next_element. 
+ element.contents = [] + element.next_element = final_next_element + + # print "DONE WITH MOVE" + # print "FROM", self.element + # print "TO", new_parent_element + + def cloneNode(self): + tag = self.soup.new_tag(self.element.name, self.namespace) + node = Element(tag, self.soup, self.namespace) + for key,value in self.attributes: + node.attributes[key] = value + return node + + def hasContent(self): + return self.element.contents + + def getNameTuple(self): + if self.namespace == None: + return namespaces["html"], self.name + else: + return self.namespace, self.name + + nameTuple = property(getNameTuple) + +class TextNode(Element): + def __init__(self, element, soup): + treebuilder_base.Node.__init__(self, None) + self.element = element + self.soup = soup + + def cloneNode(self): + raise NotImplementedError diff --git a/libs/bs4/builder/_htmlparser.py b/libs/bs4/builder/_htmlparser.py new file mode 100644 index 000000000..ea549c356 --- /dev/null +++ b/libs/bs4/builder/_htmlparser.py @@ -0,0 +1,350 @@ +# encoding: utf-8 +"""Use the HTMLParser library to parse HTML files that aren't too bad.""" + +# Use of this source code is governed by the MIT license. +__license__ = "MIT" + +__all__ = [ + 'HTMLParserTreeBuilder', + ] + +from html.parser import HTMLParser + +try: + from html.parser import HTMLParseError +except ImportError as e: + # HTMLParseError is removed in Python 3.5. Since it can never be + # thrown in 3.5, we can just define our own class as a placeholder. + class HTMLParseError(Exception): + pass + +import sys +import warnings + +# Starting in Python 3.2, the HTMLParser constructor takes a 'strict' +# argument, which we'd like to set to False. Unfortunately, +# http://bugs.python.org/issue13273 makes strict=True a better bet +# before Python 3.2.3. +# +# At the end of this file, we monkeypatch HTMLParser so that +# strict=True works well on Python 3.2.2. 
+major, minor, release = sys.version_info[:3] +CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3 +CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3 +CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4 + + +from bs4.element import ( + CData, + Comment, + Declaration, + Doctype, + ProcessingInstruction, + ) +from bs4.dammit import EntitySubstitution, UnicodeDammit + +from bs4.builder import ( + HTML, + HTMLTreeBuilder, + STRICT, + ) + + +HTMLPARSER = 'html.parser' + +class BeautifulSoupHTMLParser(HTMLParser): + + def __init__(self, *args, **kwargs): + HTMLParser.__init__(self, *args, **kwargs) + + # Keep a list of empty-element tags that were encountered + # without an explicit closing tag. If we encounter a closing tag + # of this type, we'll associate it with one of those entries. + # + # This isn't a stack because we don't care about the + # order. It's a list of closing tags we've already handled and + # will ignore, assuming they ever show up. + self.already_closed_empty_element = [] + + def error(self, msg): + """In Python 3, HTMLParser subclasses must implement error(), although this + requirement doesn't appear to be documented. + + In Python 2, HTMLParser implements error() as raising an exception. + + In any event, this method is called only on very strange markup and our best strategy + is to pretend it didn't happen and keep going. + """ + warnings.warn(msg) + + def handle_startendtag(self, name, attrs): + # This is only called when the markup looks like + # <tag/>. + + # is_startend() tells handle_starttag not to close the tag + # just because its name matches a known empty-element tag. We + # know that this is an empty-element tag and we want to call + # handle_endtag ourselves. 
+ tag = self.handle_starttag(name, attrs, handle_empty_element=False) + self.handle_endtag(name) + + def handle_starttag(self, name, attrs, handle_empty_element=True): + # XXX namespace + attr_dict = {} + for key, value in attrs: + # Change None attribute values to the empty string + # for consistency with the other tree builders. + if value is None: + value = '' + attr_dict[key] = value + attrvalue = '""' + #print "START", name + tag = self.soup.handle_starttag(name, None, None, attr_dict) + if tag and tag.is_empty_element and handle_empty_element: + # Unlike other parsers, html.parser doesn't send separate end tag + # events for empty-element tags. (It's handled in + # handle_startendtag, but only if the original markup looked like + # <tag/>.) + # + # So we need to call handle_endtag() ourselves. Since we + # know the start event is identical to the end event, we + # don't want handle_endtag() to cross off any previous end + # events for tags of this name. + self.handle_endtag(name, check_already_closed=False) + + # But we might encounter an explicit closing tag for this tag + # later on. If so, we want to ignore it. + self.already_closed_empty_element.append(name) + + def handle_endtag(self, name, check_already_closed=True): + #print "END", name + if check_already_closed and name in self.already_closed_empty_element: + # This is a redundant end tag for an empty-element tag. + # We've already called handle_endtag() for it, so just + # check it off the list. + # print "ALREADY CLOSED", name + self.already_closed_empty_element.remove(name) + else: + self.soup.handle_endtag(name) + + def handle_data(self, data): + self.soup.handle_data(data) + + def handle_charref(self, name): + # XXX workaround for a bug in HTMLParser. Remove this once + # it's fixed in all supported versions. 
+ # http://bugs.python.org/issue13633 + if name.startswith('x'): + real_name = int(name.lstrip('x'), 16) + elif name.startswith('X'): + real_name = int(name.lstrip('X'), 16) + else: + real_name = int(name) + + data = None + if real_name < 256: + # HTML numeric entities are supposed to reference Unicode + # code points, but sometimes they reference code points in + # some other encoding (ahem, Windows-1252). E.g. “ + # instead of É for LEFT DOUBLE QUOTATION MARK. This + # code tries to detect this situation and compensate. + for encoding in (self.soup.original_encoding, 'windows-1252'): + if not encoding: + continue + try: + data = bytearray([real_name]).decode(encoding) + except UnicodeDecodeError as e: + pass + if not data: + try: + data = chr(real_name) + except (ValueError, OverflowError) as e: + pass + data = data or "\N{REPLACEMENT CHARACTER}" + self.handle_data(data) + + def handle_entityref(self, name): + character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name) + if character is not None: + data = character + else: + # If this were XML, it would be ambiguous whether "&foo" + # was an character entity reference with a missing + # semicolon or the literal string "&foo". Since this is + # HTML, we have a complete list of all character entity references, + # and this one wasn't found, so assume it's the literal string "&foo". + data = "&%s" % name + self.handle_data(data) + + def handle_comment(self, data): + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(Comment) + + def handle_decl(self, data): + self.soup.endData() + if data.startswith("DOCTYPE "): + data = data[len("DOCTYPE "):] + elif data == 'DOCTYPE': + # i.e. 
"<!DOCTYPE>" + data = '' + self.soup.handle_data(data) + self.soup.endData(Doctype) + + def unknown_decl(self, data): + if data.upper().startswith('CDATA['): + cls = CData + data = data[len('CDATA['):] + else: + cls = Declaration + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(cls) + + def handle_pi(self, data): + self.soup.endData() + self.soup.handle_data(data) + self.soup.endData(ProcessingInstruction) + + +class HTMLParserTreeBuilder(HTMLTreeBuilder): + + is_xml = False + picklable = True + NAME = HTMLPARSER + features = [NAME, HTML, STRICT] + + def __init__(self, parser_args=None, parser_kwargs=None, **kwargs): + super(HTMLParserTreeBuilder, self).__init__(**kwargs) + parser_args = parser_args or [] + parser_kwargs = parser_kwargs or {} + if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED: + parser_kwargs['strict'] = False + if CONSTRUCTOR_TAKES_CONVERT_CHARREFS: + parser_kwargs['convert_charrefs'] = False + self.parser_args = (parser_args, parser_kwargs) + + def prepare_markup(self, markup, user_specified_encoding=None, + document_declared_encoding=None, exclude_encodings=None): + """ + :return: A 4-tuple (markup, original encoding, encoding + declared within markup, whether any characters had to be + replaced with REPLACEMENT CHARACTER). 
+ """ + if isinstance(markup, str): + yield (markup, None, None, False) + return + + try_encodings = [user_specified_encoding, document_declared_encoding] + dammit = UnicodeDammit(markup, try_encodings, is_html=True, + exclude_encodings=exclude_encodings) + yield (dammit.markup, dammit.original_encoding, + dammit.declared_html_encoding, + dammit.contains_replacement_characters) + + def feed(self, markup): + args, kwargs = self.parser_args + parser = BeautifulSoupHTMLParser(*args, **kwargs) + parser.soup = self.soup + try: + parser.feed(markup) + parser.close() + except HTMLParseError as e: + warnings.warn(RuntimeWarning( + "Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help.")) + raise e + parser.already_closed_empty_element = [] + +# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some +# 3.2.3 code. This ensures they don't treat markup like <p></p> as a +# string. +# +# XXX This code can be removed once most Python 3 users are on 3.2.3. +if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT: + import re + attrfind_tolerant = re.compile( + r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*' + r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?') + HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant + + locatestarttagend = re.compile(r""" + <[a-zA-Z][-.a-zA-Z0-9:_]* # tag name + (?:\s+ # whitespace before attribute name + (?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name + (?:\s*=\s* # value indicator + (?:'[^']*' # LITA-enclosed value + |\"[^\"]*\" # LIT-enclosed value + |[^'\">\s]+ # bare value + ) + )? 
+ ) + )* + \s* # trailing whitespace +""", re.VERBOSE) + BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend + + from html.parser import tagfind, attrfind + + def parse_starttag(self, i): + self.__starttag_text = None + endpos = self.check_for_whole_start_tag(i) + if endpos < 0: + return endpos + rawdata = self.rawdata + self.__starttag_text = rawdata[i:endpos] + + # Now parse the data between i+1 and j into a tag and attrs + attrs = [] + match = tagfind.match(rawdata, i+1) + assert match, 'unexpected call to parse_starttag()' + k = match.end() + self.lasttag = tag = rawdata[i+1:k].lower() + while k < endpos: + if self.strict: + m = attrfind.match(rawdata, k) + else: + m = attrfind_tolerant.match(rawdata, k) + if not m: + break + attrname, rest, attrvalue = m.group(1, 2, 3) + if not rest: + attrvalue = None + elif attrvalue[:1] == '\'' == attrvalue[-1:] or \ + attrvalue[:1] == '"' == attrvalue[-1:]: + attrvalue = attrvalue[1:-1] + if attrvalue: + attrvalue = self.unescape(attrvalue) + attrs.append((attrname.lower(), attrvalue)) + k = m.end() + + end = rawdata[k:endpos].strip() + if end not in (">", "/>"): + lineno, offset = self.getpos() + if "\n" in self.__starttag_text: + lineno = lineno + self.__starttag_text.count("\n") + offset = len(self.__starttag_text) \ + - self.__starttag_text.rfind("\n") + else: + offset = offset + len(self.__starttag_text) + if self.strict: + self.error("junk characters in start tag: %r" + % (rawdata[k:endpos][:20],)) + self.handle_data(rawdata[i:endpos]) + return endpos + if end.endswith('/>'): + # XHTML-style empty tag: <span attr="value" /> + self.handle_startendtag(tag, attrs) + else: + self.handle_starttag(tag, attrs) + if tag in self.CDATA_CONTENT_ELEMENTS: + self.set_cdata_mode(tag) + return endpos + + def set_cdata_mode(self, elem): + self.cdata_elem = elem.lower() + self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I) + + BeautifulSoupHTMLParser.parse_starttag = parse_starttag + 
BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode + + CONSTRUCTOR_TAKES_STRICT = True diff --git a/libs/bs4/builder/_lxml.py b/libs/bs4/builder/_lxml.py new file mode 100644 index 000000000..a490e2301 --- /dev/null +++ b/libs/bs4/builder/_lxml.py @@ -0,0 +1,296 @@ +# Use of this source code is governed by the MIT license. +__license__ = "MIT" + +__all__ = [ + 'LXMLTreeBuilderForXML', + 'LXMLTreeBuilder', + ] + +try: + from collections.abc import Callable # Python 3.6 +except ImportError as e: + from collections import Callable + +from io import BytesIO +from io import StringIO +from lxml import etree +from bs4.element import ( + Comment, + Doctype, + NamespacedAttribute, + ProcessingInstruction, + XMLProcessingInstruction, +) +from bs4.builder import ( + FAST, + HTML, + HTMLTreeBuilder, + PERMISSIVE, + ParserRejectedMarkup, + TreeBuilder, + XML) +from bs4.dammit import EncodingDetector + +LXML = 'lxml' + +def _invert(d): + "Invert a dictionary." + return dict((v,k) for k, v in list(d.items())) + +class LXMLTreeBuilderForXML(TreeBuilder): + DEFAULT_PARSER_CLASS = etree.XMLParser + + is_xml = True + processing_instruction_class = XMLProcessingInstruction + + NAME = "lxml-xml" + ALTERNATE_NAMES = ["xml"] + + # Well, it's permissive by XML parser standards. + features = [NAME, LXML, XML, FAST, PERMISSIVE] + + CHUNK_SIZE = 512 + + # This namespace mapping is specified in the XML Namespace + # standard. + DEFAULT_NSMAPS = dict(xml='http://www.w3.org/XML/1998/namespace') + + DEFAULT_NSMAPS_INVERTED = _invert(DEFAULT_NSMAPS) + + def initialize_soup(self, soup): + """Let the BeautifulSoup object know about the standard namespace + mapping. + """ + super(LXMLTreeBuilderForXML, self).initialize_soup(soup) + self._register_namespaces(self.DEFAULT_NSMAPS) + + def _register_namespaces(self, mapping): + """Let the BeautifulSoup object know about namespaces encountered + while parsing the document. + + This might be useful later on when creating CSS selectors. 
+ """ + for key, value in list(mapping.items()): + if key and key not in self.soup._namespaces: + # Let the BeautifulSoup object know about a new namespace. + # If there are multiple namespaces defined with the same + # prefix, the first one in the document takes precedence. + self.soup._namespaces[key] = value + + def default_parser(self, encoding): + # This can either return a parser object or a class, which + # will be instantiated with default arguments. + if self._default_parser is not None: + return self._default_parser + return etree.XMLParser( + target=self, strip_cdata=False, recover=True, encoding=encoding) + + def parser_for(self, encoding): + # Use the default parser. + parser = self.default_parser(encoding) + + if isinstance(parser, Callable): + # Instantiate the parser with default arguments + parser = parser(target=self, strip_cdata=False, encoding=encoding) + return parser + + def __init__(self, parser=None, empty_element_tags=None, **kwargs): + # TODO: Issue a warning if parser is present but not a + # callable, since that means there's no way to create new + # parsers for different encodings. + self._default_parser = parser + if empty_element_tags is not None: + self.empty_element_tags = set(empty_element_tags) + self.soup = None + self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED] + super(LXMLTreeBuilderForXML, self).__init__(**kwargs) + + def _getNsTag(self, tag): + # Split the namespace URL out of a fully-qualified lxml tag + # name. Copied from lxml's src/lxml/sax.py. + if tag[0] == '{': + return tuple(tag[1:].split('}', 1)) + else: + return (None, tag) + + def prepare_markup(self, markup, user_specified_encoding=None, + exclude_encodings=None, + document_declared_encoding=None): + """ + :yield: A series of 4-tuples. + (markup, encoding, declared encoding, + has undergone character replacement) + + Each 4-tuple represents a strategy for parsing the document. 
+ """ + # Instead of using UnicodeDammit to convert the bytestring to + # Unicode using different encodings, use EncodingDetector to + # iterate over the encodings, and tell lxml to try to parse + # the document as each one in turn. + is_html = not self.is_xml + if is_html: + self.processing_instruction_class = ProcessingInstruction + else: + self.processing_instruction_class = XMLProcessingInstruction + + if isinstance(markup, str): + # We were given Unicode. Maybe lxml can parse Unicode on + # this system? + yield markup, None, document_declared_encoding, False + + if isinstance(markup, str): + # No, apparently not. Convert the Unicode to UTF-8 and + # tell lxml to parse it as UTF-8. + yield (markup.encode("utf8"), "utf8", + document_declared_encoding, False) + + try_encodings = [user_specified_encoding, document_declared_encoding] + detector = EncodingDetector( + markup, try_encodings, is_html, exclude_encodings) + for encoding in detector.encodings: + yield (detector.markup, encoding, document_declared_encoding, False) + + def feed(self, markup): + if isinstance(markup, bytes): + markup = BytesIO(markup) + elif isinstance(markup, str): + markup = StringIO(markup) + + # Call feed() at least once, even if the markup is empty, + # or the parser won't be initialized. + data = markup.read(self.CHUNK_SIZE) + try: + self.parser = self.parser_for(self.soup.original_encoding) + self.parser.feed(data) + while len(data) != 0: + # Now call feed() on the rest of the data, chunk by chunk. + data = markup.read(self.CHUNK_SIZE) + if len(data) != 0: + self.parser.feed(data) + self.parser.close() + except (UnicodeDecodeError, LookupError, etree.ParserError) as e: + raise ParserRejectedMarkup(str(e)) + + def close(self): + self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED] + + def start(self, name, attrs, nsmap={}): + # Make sure attrs is a mutable dict--lxml may send an immutable dictproxy. + attrs = dict(attrs) + nsprefix = None + # Invert each namespace map as it comes in. 
+ if len(nsmap) == 0 and len(self.nsmaps) > 1: + # There are no new namespaces for this tag, but + # non-default namespaces are in play, so we need a + # separate tag stack to know when they end. + self.nsmaps.append(None) + elif len(nsmap) > 0: + # A new namespace mapping has come into play. + + # First, Let the BeautifulSoup object know about it. + self._register_namespaces(nsmap) + + # Then, add it to our running list of inverted namespace + # mappings. + self.nsmaps.append(_invert(nsmap)) + + # Also treat the namespace mapping as a set of attributes on the + # tag, so we can recreate it later. + attrs = attrs.copy() + for prefix, namespace in list(nsmap.items()): + attribute = NamespacedAttribute( + "xmlns", prefix, "http://www.w3.org/2000/xmlns/") + attrs[attribute] = namespace + + # Namespaces are in play. Find any attributes that came in + # from lxml with namespaces attached to their names, and + # turn then into NamespacedAttribute objects. + new_attrs = {} + for attr, value in list(attrs.items()): + namespace, attr = self._getNsTag(attr) + if namespace is None: + new_attrs[attr] = value + else: + nsprefix = self._prefix_for_namespace(namespace) + attr = NamespacedAttribute(nsprefix, attr, namespace) + new_attrs[attr] = value + attrs = new_attrs + + namespace, name = self._getNsTag(name) + nsprefix = self._prefix_for_namespace(namespace) + self.soup.handle_starttag(name, namespace, nsprefix, attrs) + + def _prefix_for_namespace(self, namespace): + """Find the currently active prefix for the given namespace.""" + if namespace is None: + return None + for inverted_nsmap in reversed(self.nsmaps): + if inverted_nsmap is not None and namespace in inverted_nsmap: + return inverted_nsmap[namespace] + return None + + def end(self, name): + self.soup.endData() + completed_tag = self.soup.tagStack[-1] + namespace, name = self._getNsTag(name) + nsprefix = None + if namespace is not None: + for inverted_nsmap in reversed(self.nsmaps): + if inverted_nsmap is not None 
class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
    """Tree builder that parses HTML with lxml's HTML parser."""

    NAME = LXML
    ALTERNATE_NAMES = ["lxml-html"]

    features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE]
    is_xml = False
    processing_instruction_class = ProcessingInstruction

    def default_parser(self, encoding):
        """Return the parser class to use; `encoding` is ignored here."""
        return etree.HTMLParser

    def feed(self, markup):
        """Hand the whole of `markup` to the parser in a single call.

        :raise ParserRejectedMarkup: If the parser raises
            UnicodeDecodeError, LookupError or etree.ParserError.
        """
        try:
            self.parser = self.parser_for(self.soup.original_encoding)
            self.parser.feed(markup)
            self.parser.close()
        except (UnicodeDecodeError, LookupError, etree.ParserError) as error:
            raise ParserRejectedMarkup(str(error))

    def test_fragment_to_document(self, fragment):
        """See `TreeBuilder`."""
        template = '<html><body>%s</body></html>'
        return template % fragment
It is heavily based on code from Mark Pilgrim's Universal +Feed Parser. It works best on XML and HTML, but it does not rewrite the +XML or HTML to reflect a new encoding; that's the tree builder's job. +""" +# Use of this source code is governed by the MIT license. +__license__ = "MIT" + +import codecs +from html.entities import codepoint2name +import re +import logging +import string + +# Import a library to autodetect character encodings. +chardet_type = None +try: + # First try the fast C implementation. + # PyPI package: cchardet + import cchardet + def chardet_dammit(s): + return cchardet.detect(s)['encoding'] +except ImportError: + try: + # Fall back to the pure Python implementation + # Debian package: python-chardet + # PyPI package: chardet + import chardet + def chardet_dammit(s): + return chardet.detect(s)['encoding'] + #import chardet.constants + #chardet.constants._debug = 1 + except ImportError: + # No chardet available. + def chardet_dammit(s): + return None + +# Available from http://cjkpython.i18n.org/. +try: + import iconv_codec +except ImportError: + pass + +xml_encoding_re = re.compile( + '^<\\?.*encoding=[\'"](.*?)[\'"].*\\?>'.encode(), re.I) +html_meta_re = re.compile( + '<\\s*meta[^>]+charset\\s*=\\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I) + +class EntitySubstitution(object): + + """Substitute XML or HTML entities for the corresponding characters.""" + + def _populate_class_variables(): + lookup = {} + reverse_lookup = {} + characters_for_re = [] + + # &apos is an XHTML entity and an HTML 5, but not an HTML 4 + # entity. We don't want to use it, but we want to recognize it on the way in. + # + # TODO: Ideally we would be able to recognize all HTML 5 named + # entities, but that's a little tricky. 
@classmethod
def quoted_attribute_value(cls, value):
    """Make a value into a quoted XML attribute, possibly escaping it.

    Most strings will be quoted using double quotes.

     Bob's Bar -> "Bob's Bar"

    If a string contains double quotes, it will be quoted using
    single quotes.

     Welcome to "my bar" -> 'Welcome to "my bar"'

    If a string contains both single and double quotes, the
    double quotes will be escaped, and the string will be quoted
    using double quotes.

     Welcome to "Bob's Bar" -> "Welcome to &quot;Bob's Bar&quot;"

    :param value: The attribute value, as a string.
    :return: The value wrapped in the chosen quote character.
    """
    quote_with = '"'
    if '"' in value:
        if "'" in value:
            # The string contains both single and double quotes, so
            # one kind has to become an entity. The double quotes are
            # escaped because their entity is spelled "&quot;" in both
            # HTML and XML.
            value = value.replace('"', "&quot;")
        else:
            # There are double quotes but no single quotes, so single
            # quotes can delimit the attribute without any escaping.
            quote_with = "'"
    return quote_with + value + quote_with
class EncodingDetector:
    """Suggests a number of possible encodings for a bytestring.

    Order of precedence:

    1. Encodings you specifically tell EncodingDetector to try first
    (the override_encodings argument to the constructor).

    2. An encoding implied by a byte-order mark at the start of the
    bytestring.

    3. An encoding declared within the bytestring itself, either in an
    XML declaration (if the bytestring is to be interpreted as an XML
    document), or in a <meta> tag (if the bytestring is to be
    interpreted as an HTML document.)

    4. An encoding detected through textual analysis by chardet,
    cchardet, or a similar external library.

    5. UTF-8.

    6. Windows-1252.
    """
    def __init__(self, markup, override_encodings=None, is_html=False,
                 exclude_encodings=None):
        """Set up the detector and strip any byte-order mark.

        :param markup: The document, as a bytestring or Unicode string.
        :param override_encodings: Encodings to suggest before any other.
        :param is_html: If True, a <meta> tag may also declare the
            encoding.
        :param exclude_encodings: Encodings that must never be suggested
            (compared case-insensitively).
        """
        self.override_encodings = override_encodings or []
        exclude_encodings = exclude_encodings or []
        self.exclude_encodings = {x.lower() for x in exclude_encodings}
        self.chardet_encoding = None
        self.is_html = is_html
        self.declared_encoding = None

        # First order of business: strip a byte-order mark.
        self.markup, self.sniffed_encoding = self.strip_byte_order_mark(markup)

    def _usable(self, encoding, tried):
        """Decide whether `encoding` should be suggested now.

        An encoding is usable when it is not None, is not excluded and
        has not been suggested before. A usable encoding is recorded in
        `tried` (lowercased) as a side effect.
        """
        if encoding is not None:
            encoding = encoding.lower()
            if encoding in self.exclude_encodings:
                return False
            if encoding not in tried:
                tried.add(encoding)
                return True
        return False

    @property
    def encodings(self):
        """Yield a number of encodings that might work for this markup."""
        tried = set()
        for e in self.override_encodings:
            if self._usable(e, tried):
                yield e

        # Did the document originally start with a byte-order mark
        # that indicated its encoding?
        if self._usable(self.sniffed_encoding, tried):
            yield self.sniffed_encoding

        # Look within the document for an XML or HTML encoding
        # declaration.
        if self.declared_encoding is None:
            self.declared_encoding = self.find_declared_encoding(
                self.markup, self.is_html)
        if self._usable(self.declared_encoding, tried):
            yield self.declared_encoding

        # Use third-party character set detection to guess at the
        # encoding.
        if self.chardet_encoding is None:
            self.chardet_encoding = chardet_dammit(self.markup)
        if self._usable(self.chardet_encoding, tried):
            yield self.chardet_encoding

        # As a last-ditch effort, try utf-8 and windows-1252.
        for e in ('utf-8', 'windows-1252'):
            if self._usable(e, tried):
                yield e

    @classmethod
    def strip_byte_order_mark(cls, data):
        """If a byte-order mark is present, strip it and return the encoding it implies.

        :param data: The document, as a bytestring or Unicode string.
        :return: A 2-tuple (data without the mark, encoding name or None).
        """
        encoding = None
        if isinstance(data, str):
            # Unicode data cannot have a byte-order mark.
            return data, encoding
        # The UTF-16 checks must compare against *bytes*. Comparing a
        # bytes slice with a str is always unequal on Python 3, which
        # made a UTF-32LE mark (FF FE 00 00) look like UTF-16LE.
        if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \
               and (data[2:4] != b'\x00\x00'):
            encoding = 'utf-16be'
            data = data[2:]
        elif (len(data) >= 4) and (data[:2] == b'\xff\xfe') \
                 and (data[2:4] != b'\x00\x00'):
            encoding = 'utf-16le'
            data = data[2:]
        elif data[:3] == b'\xef\xbb\xbf':
            encoding = 'utf-8'
            data = data[3:]
        elif data[:4] == b'\x00\x00\xfe\xff':
            encoding = 'utf-32be'
            data = data[4:]
        elif data[:4] == b'\xff\xfe\x00\x00':
            encoding = 'utf-32le'
            data = data[4:]
        return data, encoding

    @classmethod
    def find_declared_encoding(cls, markup, is_html=False, search_entire_document=False):
        """Given a document, tries to find its declared encoding.

        An XML encoding is declared at the beginning of the document.

        An HTML encoding is declared in a <meta> tag, hopefully near the
        beginning of the document.

        :param markup: The document to search.
        :param is_html: If True, fall back to looking for a <meta>
            charset when no XML declaration is found.
        :param search_entire_document: If True, search the whole
            document rather than only its beginning.
        :return: The lowercased declared encoding, or None.
        """
        if search_entire_document:
            xml_endpos = html_endpos = len(markup)
        else:
            # Only the start of the document is searched: the first
            # 1024 units for XML, and the larger of 2048 units or 5%
            # of the document for HTML.
            xml_endpos = 1024
            html_endpos = max(2048, int(len(markup) * 0.05))

        declared_encoding = None
        declared_encoding_match = xml_encoding_re.search(markup, endpos=xml_endpos)
        if not declared_encoding_match and is_html:
            declared_encoding_match = html_meta_re.search(markup, endpos=html_endpos)
        if declared_encoding_match is not None:
            declared_encoding = declared_encoding_match.groups()[0].decode(
                'ascii', 'replace')
        if declared_encoding:
            return declared_encoding.lower()
        return None
def __init__(self, markup, override_encodings=None,
             smart_quotes_to=None, is_html=False, exclude_encodings=None):
    """Detect the encoding of `markup` and convert it to Unicode.

    The result is stored in `unicode_markup` (None if every attempt
    failed) and the encoding that worked in `original_encoding`.

    :param markup: The document, as a bytestring or Unicode string.
    :param override_encodings: Encodings to try before any other.
    :param smart_quotes_to: What Microsoft smart quotes become when
        the encoding may contain them: 'ascii', 'xml', any other
        non-None value for HTML entities, or None to leave them alone.
    :param is_html: If True, treat the document as HTML when looking
        for a declared encoding.
    :param exclude_encodings: Encodings that must never be tried.
    """
    # Lists are created per call rather than shared as default
    # argument values.
    if override_encodings is None:
        override_encodings = []
    if exclude_encodings is None:
        exclude_encodings = []

    self.smart_quotes_to = smart_quotes_to
    self.tried_encodings = []
    self.contains_replacement_characters = False
    self.is_html = is_html
    self.log = logging.getLogger(__name__)
    self.detector = EncodingDetector(
        markup, override_encodings, is_html, exclude_encodings)

    # Short-circuit if the data is in Unicode to begin with.
    if isinstance(markup, str) or markup == '':
        self.markup = markup
        self.unicode_markup = str(markup)
        self.original_encoding = None
        return

    # The encoding detector may have stripped a byte-order mark.
    # Use the stripped markup from this point on.
    self.markup = self.detector.markup

    u = None
    for encoding in self.detector.encodings:
        u = self._convert_from(encoding)
        if u is not None:
            break

    if not u:
        # None of the encodings worked. As an absolute last resort,
        # try them again with character replacement.

        for encoding in self.detector.encodings:
            if encoding != "ascii":
                u = self._convert_from(encoding, "replace")
            if u is not None:
                self.log.warning(
                        "Some characters could not be decoded, and were "
                        "replaced with REPLACEMENT CHARACTER."
                )
                self.contains_replacement_characters = True
                break

    # If none of that worked, we could at this point force it to
    # ASCII, but that would destroy so much data that I think
    # giving up is better.
    self.unicode_markup = u
    if not u:
        self.original_encoding = None
# A partial mapping of ISO-Latin-1 to HTML entities/XML numeric entities.
#
# Keys are single Windows-1252 bytes in the range 0x80-0x9f. A tuple
# value is (HTML entity name, hexadecimal code point for an XML numeric
# entity); a plain string value is used as the replacement as-is.
MS_CHARS = {b'\x80': ('euro', '20AC'),
            b'\x81': ' ',
            b'\x82': ('sbquo', '201A'),
            b'\x83': ('fnof', '192'),
            b'\x84': ('bdquo', '201E'),
            b'\x85': ('hellip', '2026'),
            b'\x86': ('dagger', '2020'),
            b'\x87': ('Dagger', '2021'),
            b'\x88': ('circ', '2C6'),
            b'\x89': ('permil', '2030'),
            b'\x8A': ('Scaron', '160'),
            b'\x8B': ('lsaquo', '2039'),
            b'\x8C': ('OElig', '152'),
            b'\x8D': '?',
            b'\x8E': ('#x17D', '17D'),
            b'\x8F': '?',
            b'\x90': '?',
            b'\x91': ('lsquo', '2018'),
            b'\x92': ('rsquo', '2019'),
            b'\x93': ('ldquo', '201C'),
            b'\x94': ('rdquo', '201D'),
            b'\x95': ('bull', '2022'),
            b'\x96': ('ndash', '2013'),
            b'\x97': ('mdash', '2014'),
            b'\x98': ('tilde', '2DC'),
            b'\x99': ('trade', '2122'),
            b'\x9a': ('scaron', '161'),
            b'\x9b': ('rsaquo', '203A'),
            b'\x9c': ('oelig', '153'),
            b'\x9d': '?',
            b'\x9e': ('#x17E', '17E'),
            # U+0178 LATIN CAPITAL LETTER Y WITH DIAERESIS. An empty
            # code point here would produce the invalid entity "&#x;".
            b'\x9f': ('Yuml', '178'),}
+ MS_CHARS_TO_ASCII = { + b'\x80' : 'EUR', + b'\x81' : ' ', + b'\x82' : ',', + b'\x83' : 'f', + b'\x84' : ',,', + b'\x85' : '...', + b'\x86' : '+', + b'\x87' : '++', + b'\x88' : '^', + b'\x89' : '%', + b'\x8a' : 'S', + b'\x8b' : '<', + b'\x8c' : 'OE', + b'\x8d' : '?', + b'\x8e' : 'Z', + b'\x8f' : '?', + b'\x90' : '?', + b'\x91' : "'", + b'\x92' : "'", + b'\x93' : '"', + b'\x94' : '"', + b'\x95' : '*', + b'\x96' : '-', + b'\x97' : '--', + b'\x98' : '~', + b'\x99' : '(TM)', + b'\x9a' : 's', + b'\x9b' : '>', + b'\x9c' : 'oe', + b'\x9d' : '?', + b'\x9e' : 'z', + b'\x9f' : 'Y', + b'\xa0' : ' ', + b'\xa1' : '!', + b'\xa2' : 'c', + b'\xa3' : 'GBP', + b'\xa4' : '$', #This approximation is especially parochial--this is the + #generic currency symbol. + b'\xa5' : 'YEN', + b'\xa6' : '|', + b'\xa7' : 'S', + b'\xa8' : '..', + b'\xa9' : '', + b'\xaa' : '(th)', + b'\xab' : '<<', + b'\xac' : '!', + b'\xad' : ' ', + b'\xae' : '(R)', + b'\xaf' : '-', + b'\xb0' : 'o', + b'\xb1' : '+-', + b'\xb2' : '2', + b'\xb3' : '3', + b'\xb4' : ("'", 'acute'), + b'\xb5' : 'u', + b'\xb6' : 'P', + b'\xb7' : '*', + b'\xb8' : ',', + b'\xb9' : '1', + b'\xba' : '(th)', + b'\xbb' : '>>', + b'\xbc' : '1/4', + b'\xbd' : '1/2', + b'\xbe' : '3/4', + b'\xbf' : '?', + b'\xc0' : 'A', + b'\xc1' : 'A', + b'\xc2' : 'A', + b'\xc3' : 'A', + b'\xc4' : 'A', + b'\xc5' : 'A', + b'\xc6' : 'AE', + b'\xc7' : 'C', + b'\xc8' : 'E', + b'\xc9' : 'E', + b'\xca' : 'E', + b'\xcb' : 'E', + b'\xcc' : 'I', + b'\xcd' : 'I', + b'\xce' : 'I', + b'\xcf' : 'I', + b'\xd0' : 'D', + b'\xd1' : 'N', + b'\xd2' : 'O', + b'\xd3' : 'O', + b'\xd4' : 'O', + b'\xd5' : 'O', + b'\xd6' : 'O', + b'\xd7' : '*', + b'\xd8' : 'O', + b'\xd9' : 'U', + b'\xda' : 'U', + b'\xdb' : 'U', + b'\xdc' : 'U', + b'\xdd' : 'Y', + b'\xde' : 'b', + b'\xdf' : 'B', + b'\xe0' : 'a', + b'\xe1' : 'a', + b'\xe2' : 'a', + b'\xe3' : 'a', + b'\xe4' : 'a', + b'\xe5' : 'a', + b'\xe6' : 'ae', + b'\xe7' : 'c', + b'\xe8' : 'e', + b'\xe9' : 'e', + b'\xea' : 'e', + b'\xeb' : 'e', + 
# A map used when removing rogue Windows-1252/ISO-8859-1
# characters in otherwise UTF-8 documents.
#
# Keys are Windows-1252 byte values (ints) from 0x80 through 0xfe;
# values are the UTF-8 encoding of the same character. 0xff has never
# been part of this table and is still left out.
#
# Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in
# Windows-1252.
#
# The table is derived from the codec rather than typed out by hand.
# The hand-written table mapped 0xe1 to b'\xa1', which is not valid
# UTF-8; the correct value is b'\xc3\xa1'.
WINDOWS_1252_TO_UTF8 = {
    byte: bytes([byte]).decode("windows-1252").encode("utf-8")
    for byte in range(0x80, 0xff)
    if byte not in (0x81, 0x8d, 0x8f, 0x90, 0x9d)
}
with F0-F4 + ] + + FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0] + LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1] + + @classmethod + def detwingle(cls, in_bytes, main_encoding="utf8", + embedded_encoding="windows-1252"): + """Fix characters from one encoding embedded in some other encoding. + + Currently the only situation supported is Windows-1252 (or its + subset ISO-8859-1), embedded in UTF-8. + + The input must be a bytestring. If you've already converted + the document to Unicode, you're too late. + + The output is a bytestring in which `embedded_encoding` + characters have been converted to their `main_encoding` + equivalents. + """ + if embedded_encoding.replace('_', '-').lower() not in ( + 'windows-1252', 'windows_1252'): + raise NotImplementedError( + "Windows-1252 and ISO-8859-1 are the only currently supported " + "embedded encodings.") + + if main_encoding.lower() not in ('utf8', 'utf-8'): + raise NotImplementedError( + "UTF-8 is the only currently supported main encoding.") + + byte_chunks = [] + + chunk_start = 0 + pos = 0 + while pos < len(in_bytes): + byte = in_bytes[pos] + if not isinstance(byte, int): + # Python 2.x + byte = ord(byte) + if (byte >= cls.FIRST_MULTIBYTE_MARKER + and byte <= cls.LAST_MULTIBYTE_MARKER): + # This is the start of a UTF-8 multibyte character. Skip + # to the end. + for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES: + if byte >= start and byte <= end: + pos += size + break + elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8: + # We found a Windows-1252 character! + # Save the string up to this point as a chunk. + byte_chunks.append(in_bytes[chunk_start:pos]) + + # Now translate the Windows-1252 character into UTF-8 + # and add it as another, one-byte chunk. + byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte]) + pos += 1 + chunk_start = pos + else: + # Go on to the next character. + pos += 1 + if chunk_start == 0: + # The string is unchanged. 
def diagnose(data):
    """Diagnostic suite for isolating common problems.

    Prints version information, reports which of the basic parsers
    have no registered tree builder, and then parses `data` with every
    available parser, printing either the prettified result or the
    traceback.

    :param data: The markup to diagnose: an object with a read()
        method, a string of markup, or a filename. A string that looks
        like a URL is reported and not processed.
    """
    print("Diagnostic running on Beautiful Soup %s" % __version__)
    print("Python version %s" % sys.version)

    basic_parsers = ["html.parser", "html5lib", "lxml"]
    # Iterate over a copy: removing items from the list being iterated
    # would make the loop skip the parser that follows a missing one.
    for name in list(basic_parsers):
        for builder in builder_registry.builders:
            if name in builder.features:
                break
        else:
            basic_parsers.remove(name)
            print((
                "I noticed that %s is not installed. Installing it may help." %
                name))

    if 'lxml' in basic_parsers:
        basic_parsers.append("lxml-xml")
        try:
            from lxml import etree
            print("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION)))
        except ImportError:
            print (
                "lxml is not installed or couldn't be imported.")

    if 'html5lib' in basic_parsers:
        try:
            import html5lib
            print("Found html5lib version %s" % html5lib.__version__)
        except ImportError:
            print (
                "html5lib is not installed or couldn't be imported.")

    if hasattr(data, 'read'):
        data = data.read()
    elif data.startswith(("http:", "https:")):
        print('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data)
        print("You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup.")
        return
    else:
        try:
            if os.path.exists(data):
                print('"%s" looks like a filename. Reading data from the file.' % data)
                with open(data) as fp:
                    data = fp.read()
        except ValueError:
            # This can happen on some platforms when the 'filename' is
            # too long. Assume it's data and not a filename.
            pass
        print()

    for parser in basic_parsers:
        print("Trying to parse your markup with %s" % parser)
        success = False
        try:
            soup = BeautifulSoup(data, features=parser)
            success = True
        except Exception:
            # Deliberately broad: this tool exists to show whatever
            # went wrong inside the parser.
            print("%s could not parse the markup." % parser)
            traceback.print_exc()
        if success:
            print("Here's what %s did with the markup:" % parser)
            print(soup.prettify())

        print("-" * 80)
class AnnouncingParser(HTMLParser):
    """Report each HTMLParser event on standard output and do nothing more."""

    def _p(self, s):
        # Single output point used by every handler below.
        print(s)

    def handle_starttag(self, name, attrs):
        message = "%s START" % name
        self._p(message)

    def handle_endtag(self, name):
        message = "%s END" % name
        self._p(message)

    def handle_data(self, data):
        message = "%s DATA" % data
        self._p(message)

    def handle_charref(self, name):
        message = "%s CHARREF" % name
        self._p(message)

    def handle_entityref(self, name):
        message = "%s ENTITYREF" % name
        self._p(message)

    def handle_comment(self, data):
        message = "%s COMMENT" % data
        self._p(message)

    def handle_decl(self, data):
        message = "%s DECL" % data
        self._p(message)

    def unknown_decl(self, data):
        message = "%s UNKNOWN-DECL" % data
        self._p(message)

    def handle_pi(self, data):
        message = "%s PI" % data
        self._p(message)
def profile(num_elements=100000, parser="lxml"):
    """Profile Beautiful Soup parsing a randomly generated document.

    The profile is written to a temporary file, sorted by cumulative
    time, and the top 50 entries matching '_html5lib|bs4' are printed.

    :param num_elements: Passed to rdoc() to size the generated document.
    :param parser: The parser name handed to BeautifulSoup.
    """
    data = rdoc(num_elements)
    namespace = dict(bs4=bs4, data=data, parser=parser)

    # The context manager closes the temporary file once the statistics
    # have been printed, instead of leaving the handle open until it is
    # garbage-collected.
    with tempfile.NamedTemporaryFile() as filehandle:
        filename = filehandle.name
        cProfile.runctx('bs4.BeautifulSoup(data, parser)' , namespace, namespace, filename)

        stats = pstats.Stats(filename)
        stats.sort_stats("cumulative")
        stats.print_stats('_html5lib|bs4', 50)
def _alias(attr):
    """Alias one attribute name to another for backward compatibility.

    :param attr: Name of the attribute the alias forwards to.
    :return: A property whose getter reads ``attr`` from the instance
        and whose setter writes the assigned value to ``attr``.
    """
    @property
    def alias(self):
        return getattr(self, attr)

    @alias.setter
    def alias(self, value):
        # A property setter receives the instance and the value being
        # assigned; forward that value to the real attribute.
        setattr(self, attr, value)
    return alias
+ """ + + def __new__(cls, original_value): + obj = str.__new__(cls, original_value) + obj.original_value = original_value + return obj + + def encode(self, encoding): + return encoding + + +class ContentMetaAttributeValue(AttributeValueWithCharsetSubstitution): + """A generic stand-in for the value of a meta tag's 'content' attribute. + + When Beautiful Soup parses the markup: + <meta http-equiv="content-type" content="text/html; charset=utf8"> + + The value of the 'content' attribute will be one of these objects. + """ + + CHARSET_RE = re.compile(r"((^|;)\s*charset=)([^;]*)", re.M) + + def __new__(cls, original_value): + match = cls.CHARSET_RE.search(original_value) + if match is None: + # No substitution necessary. + return str.__new__(str, original_value) + + obj = str.__new__(cls, original_value) + obj.original_value = original_value + return obj + + def encode(self, encoding): + def rewrite(match): + return match.group(1) + encoding + return self.CHARSET_RE.sub(rewrite, self.original_value) + + +class PageElement(object): + """Contains the navigational information for some part of the page + (either a tag or a piece of text)""" + + def setup(self, parent=None, previous_element=None, next_element=None, + previous_sibling=None, next_sibling=None): + """Sets up the initial relations between this element and + other elements.""" + self.parent = parent + + self.previous_element = previous_element + if previous_element is not None: + self.previous_element.next_element = self + + self.next_element = next_element + if self.next_element is not None: + self.next_element.previous_element = self + + self.next_sibling = next_sibling + if self.next_sibling is not None: + self.next_sibling.previous_sibling = self + + if (previous_sibling is None + and self.parent is not None and self.parent.contents): + previous_sibling = self.parent.contents[-1] + + self.previous_sibling = previous_sibling + if previous_sibling is not None: + self.previous_sibling.next_sibling = self + + 
def formatter_for_name(self, formatter):
    """Resolve ``formatter`` to a Formatter object.

    :param formatter: One of:
        - a Formatter object, which is returned unchanged;
        - a callable, which is passed as ``entity_substitution`` to a
          new XMLFormatter or HTMLFormatter;
        - any other value (typically a string), which is used as a key
          into the ``REGISTRY`` of XMLFormatter or HTMLFormatter.
    :return: The resolved formatter. XMLFormatter is used when
        ``self._is_xml`` is true, HTMLFormatter otherwise.
    """
    if isinstance(formatter, Formatter):
        return formatter
    formatter_class = XMLFormatter if self._is_xml else HTMLFormatter
    if not callable(formatter):
        return formatter_class.REGISTRY[formatter]
    return formatter_class(entity_substitution=formatter)
def unwrap(self):
    """Replace this element with its own children.

    The element is removed from its parent with ``extract()`` and each
    item of ``self.contents`` is inserted into the parent at the index
    this element used to occupy, keeping the children's order.

    :return: This element.
    :raises ValueError: If this element has no parent.
    """
    my_parent = self.parent
    if my_parent is None:
        # The two literals are concatenated, so the first must end with
        # a space to keep "that element" as two words.
        raise ValueError(
            "Cannot replace an element with its contents when that "
            "element is not part of a tree.")
    my_index = my_parent.index(self)
    self.extract()
    # Insert back to front at a fixed index so the children end up in
    # their original order. Iterate over a copy of the list.
    for child in reversed(self.contents[:]):
        my_parent.insert(my_index, child)
    return self
def insert(self, position, new_child):
    """Insert ``new_child`` into this element's contents at ``position``.

    Besides adding ``new_child`` to ``self.contents``, this updates the
    parent, sibling (``previous_sibling``/``next_sibling``) and
    document-order (``previous_element``/``next_element``) links of
    ``new_child`` and of the elements around the insertion point.

    :param position: Index into ``self.contents``. Values past the end
        are clamped to ``len(self.contents)``.
    :param new_child: The element to insert. A plain ``str`` is wrapped
        in a NavigableString. A BeautifulSoup object is not inserted
        itself; its children are inserted one at a time instead. An
        element that already has a parent is extracted from it first.
    :raises ValueError: If ``new_child`` is None or is this element.
    """
    if new_child is None:
        raise ValueError("Cannot insert None into a tag.")
    if new_child is self:
        raise ValueError("Cannot insert a tag into itself.")
    if (isinstance(new_child, str)
        and not isinstance(new_child, NavigableString)):
        new_child = NavigableString(new_child)

    # Imported inside the method rather than at module level --
    # presumably to avoid a circular import, since the bs4 package
    # itself imports from this module.
    from bs4 import BeautifulSoup
    if isinstance(new_child, BeautifulSoup):
        # We don't want to end up with a situation where one BeautifulSoup
        # object contains another. Insert the children one at a time.
        for subchild in list(new_child.contents):
            self.insert(position, subchild)
            position += 1
        return
    position = min(position, len(self.contents))
    if hasattr(new_child, 'parent') and new_child.parent is not None:
        # We're 'inserting' an element that's already one
        # of this object's children.
        if new_child.parent is self:
            current_index = self.index(new_child)
            if current_index < position:
                # We're moving this element further down the list
                # of this object's children. That means that when
                # we extract this element, our target index will
                # jump down one.
                position -= 1
        new_child.extract()

    new_child.parent = self
    previous_child = None
    if position == 0:
        # First child: no previous sibling, and the element just
        # before it in document order is this element.
        new_child.previous_sibling = None
        new_child.previous_element = self
    else:
        previous_child = self.contents[position - 1]
        new_child.previous_sibling = previous_child
        new_child.previous_sibling.next_sibling = new_child
        # The element just before new_child in document order is the
        # last descendant of the previous child.
        new_child.previous_element = previous_child._last_descendant(False)
    if new_child.previous_element is not None:
        new_child.previous_element.next_element = new_child

    new_childs_last_element = new_child._last_descendant(False)

    if position >= len(self.contents):
        new_child.next_sibling = None

        # Walk up through the ancestors until one of them has a next
        # sibling; that sibling follows new_child's subtree.
        parent = self
        parents_next_sibling = None
        while parents_next_sibling is None and parent is not None:
            parents_next_sibling = parent.next_sibling
            parent = parent.parent
            if parents_next_sibling is not None:
                # We found the element that comes next in the document.
                break
        if parents_next_sibling is not None:
            new_childs_last_element.next_element = parents_next_sibling
        else:
            # The last element of this tag is the last element in
            # the document.
            new_childs_last_element.next_element = None
    else:
        next_child = self.contents[position]
        new_child.next_sibling = next_child
        if new_child.next_sibling is not None:
            new_child.next_sibling.previous_sibling = new_child
        new_childs_last_element.next_element = next_child

    if new_childs_last_element.next_element is not None:
        new_childs_last_element.next_element.previous_element = new_childs_last_element
    self.contents.insert(position, new_child)
def find_all_next(self, name=None, attrs={}, text=None, limit=None,
                  **kwargs):
    """Search the elements that come after this one in the document.

    Delegates to ``self._find_all``, using ``self.next_elements`` as
    the source of candidates.

    :param name: Passed through to ``_find_all``.
    :param attrs: Passed through to ``_find_all``; not modified here.
    :param text: Passed through to ``_find_all``.
    :param limit: Passed through to ``_find_all``.
    :param kwargs: Passed through to ``_find_all``.
    :return: Whatever ``_find_all`` returns.
    """
    following = self.next_elements
    return self._find_all(name, attrs, text, limit, following, **kwargs)
def _find_all(self, name, attrs, text, limit, generator, **kwargs):
    """Iterate over ``generator`` looking for things that match.

    :param name: A tag name, True, None, or a SoupStrainer. A
        SoupStrainer is used as the matcher directly; anything else is
        used, with ``attrs``, ``text`` and ``kwargs``, to build one.
    :param attrs: Attribute criteria passed to SoupStrainer.
    :param text: Text criterion passed to SoupStrainer. When None, a
        ``string`` keyword argument is used in its place.
    :param limit: Stop after this many matches; a false value means no
        limit.
    :param generator: Iterable of candidate elements.
    :param kwargs: Extra criteria passed to SoupStrainer.
    :return: A ResultSet of matches.
    """
    if text is None and 'string' in kwargs:
        text = kwargs.pop('string')

    if isinstance(name, SoupStrainer):
        strainer = name
    else:
        strainer = SoupStrainer(name, attrs, text, **kwargs)

    if text is None and not limit and not attrs and not kwargs:
        if name is True or name is None:
            # Optimization to find all tags.
            result = (element for element in generator
                      if isinstance(element, Tag))
            return ResultSet(strainer, result)
        elif isinstance(name, str):
            # Optimization to find all tags with a given name.
            if name.count(':') == 1:
                # This is a name with a prefix. If this is a
                # namespace-aware document, we need to match the local
                # name against tag.name. If not, we need to match the
                # fully-qualified name against tag.name.
                prefix, local_name = name.split(':', 1)
            else:
                prefix = None
                local_name = name
            # Both name tests sit inside one parenthesised group so the
            # isinstance() guard applies to each of them. `and` binds
            # tighter than `or`, so without the outer parentheses the
            # local-name test would also run on non-Tag elements.
            result = (element for element in generator
                      if isinstance(element, Tag)
                      and (
                          element.name == name
                          or (
                              element.name == local_name
                              and (prefix is None
                                   or element.prefix == prefix)
                          )
                      ))
            return ResultSet(strainer, result)

    results = ResultSet(strainer)
    for candidate in generator:
        # Falsy candidates are skipped without consulting the strainer.
        if not candidate:
            continue
        found = strainer.search(candidate)
        if found:
            results.append(found)
            if limit and len(results) >= limit:
                break
    return results
@property
def next_elements(self):
    """Yield the elements reachable by following ``next_element``.

    Starts at ``self.next_element`` (this element itself is not
    yielded) and stops at the first ``next_element`` that is None.
    """
    current = self.next_element
    while True:
        if current is None:
            return
        yield current
        current = current.next_element
def output_ready(self, formatter=None):
    """Return this string wrapped in ``PREFIX`` and ``SUFFIX``.

    The string itself is never altered by the formatter. When
    ``formatter`` is not None the string is still passed through
    ``format_string``, purely for any side effects; the return value
    is ignored.

    :param formatter: Passed to ``format_string`` when not None.
    :return: ``self.PREFIX + self + self.SUFFIX``.
    """
    if formatter is not None:
        # Called only for its side effects; the result is discarded.
        self.format_string(self, formatter)
    return self.PREFIX + self + self.SUFFIX
@classmethod
def for_name_and_ids(cls, name, pub_id, system_id):
    """Build a doctype from a name and optional identifiers.

    :param name: The doctype name; a false value contributes an empty
        string.
    :param pub_id: Public identifier, or None. When given, it is added
        as ``PUBLIC "<pub_id>"``, followed by the quoted ``system_id``
        if that is also given.
    :param system_id: System identifier, or None. When given without
        ``pub_id``, it is added as ``SYSTEM "<system_id>"``.
    :return: An instance of ``cls`` built from the assembled string.
    """
    value = name or ''
    if pub_id is not None:
        value += ' PUBLIC "%s"' % pub_id
        if system_id is not None:
            value += ' "%s"' % system_id
    elif system_id is not None:
        value += ' SYSTEM "%s"' % system_id

    # Instantiate cls rather than a hard-coded class so that calling
    # this through a subclass returns an instance of that subclass.
    return cls(value)
def __copy__(self):
    """A copy of a Tag is a new Tag, unconnected to the parse tree.
    Its contents are a copy of the old Tag's contents.

    The clone is built with this tag's name, namespace, prefix and
    attrs; ``can_be_empty_element`` and ``hidden`` are then copied
    across, and a ``__copy__()`` of every child is appended.

    :return: The new, detached element.
    """
    # Look ``builder`` up without triggering the ``__getattr__``
    # fallback. The constructor shown for this class never assigns
    # ``self.builder``, so a plain ``self.builder`` falls through to
    # ``__getattr__``, which turns the lookup into a search for a
    # child tag named "builder" and could hand a Tag to the
    # constructor as if it were a tree builder.
    try:
        builder = object.__getattribute__(self, 'builder')
    except AttributeError:
        builder = None
    clone = type(self)(None, builder, self.name, self.namespace,
                       self.prefix, self.attrs, is_xml=self._is_xml)
    for attr in ('can_be_empty_element', 'hidden'):
        setattr(clone, attr, getattr(self, attr))
    for child in self.contents:
        clone.append(child.__copy__())
    return clone
def decompose(self):
    """Recursively destroys the contents of this tree.

    Calls ``extract()`` on this element, then follows the
    ``next_element`` chain starting here. Each element visited has its
    instance dictionary emptied and is left with an empty ``contents``
    list.
    """
    self.extract()
    node = self
    while node is not None:
        # Grab the successor first: clearing the instance dictionary
        # erases next_element along with everything else.
        following = node.next_element
        vars(node).clear()
        node.contents = []
        node = following
def smooth(self):
    """Smooth out this element's children by consolidating consecutive strings.

    Every child Tag is smoothed recursively. Each pair of adjacent
    children that are both NavigableString and neither PreformattedString
    is replaced by a single NavigableString holding their concatenation.
    """
    def is_plain_string(node):
        return (isinstance(node, NavigableString)
                and not isinstance(node, PreformattedString))

    # Record the left-hand index of every pair to merge, rather than
    # copying self.contents.
    merge_points = []
    for index, current in enumerate(self.contents):
        if isinstance(current, Tag):
            # Recursively smooth children.
            current.smooth()
        if index == len(self.contents) - 1:
            # Last child: there is nothing after it to merge with.
            continue
        if (is_plain_string(current)
                and is_plain_string(self.contents[index + 1])):
            merge_points.append(index)

    # Merge from the end so that removing items doesn't shift the
    # positions still waiting to be processed.
    for index in reversed(merge_points):
        left = self.contents[index]
        right = self.contents[index + 1]
        right.extract()
        left.replace_with(NavigableString(left + right))
def get_attribute_list(self, key, default=None):
    """Like ``get()``, but the result is always a list.

    :param key: Passed to ``self.get``.
    :param default: Passed to ``self.get``.
    :return: The value from ``self.get(key, default)`` if it is
        already a list; otherwise a one-item list containing it.
    """
    found = self.get(key, default)
    return found if isinstance(found, list) else [found]
def __eq__(self, other):
    """Compare this tag with ``other`` by name, attributes and contents.

    :return: True if ``other`` is this same object, or if it has
        ``name``, ``attrs`` and ``contents`` attributes, its name,
        attrs and length equal this tag's, and no child of this tag
        differs from the item at the same position in
        ``other.contents``. False otherwise.
    """
    if self is other:
        return True
    for required in ('name', 'attrs', 'contents'):
        if not hasattr(other, required):
            return False
    if self.name != other.name:
        return False
    if self.attrs != other.attrs:
        return False
    if len(self) != len(other):
        return False
    return not any(
        my_child != other.contents[position]
        for position, my_child in enumerate(self.contents))
+ u = self.decode(indent_level, encoding, formatter) + return u.encode(encoding, errors) + + def decode(self, indent_level=None, + eventual_encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Returns a Unicode representation of this tag and its contents. + + :param eventual_encoding: The tag is destined to be + encoded into this encoding. This method is _not_ + responsible for performing that encoding. This information + is passed in so that it can be substituted in if the + document contains a <META> tag that mentions the document's + encoding. + """ + + # First off, turn a non-Formatter `formatter` into a Formatter + # object. This will stop the lookup from happening over and + # over again. + if not isinstance(formatter, Formatter): + formatter = self.formatter_for_name(formatter) + attributes = formatter.attributes(self) + attrs = [] + for key, val in attributes: + if val is None: + decoded = key + else: + if isinstance(val, list) or isinstance(val, tuple): + val = ' '.join(val) + elif not isinstance(val, str): + val = str(val) + elif ( + isinstance(val, AttributeValueWithCharsetSubstitution) + and eventual_encoding is not None + ): + val = val.encode(eventual_encoding) + + text = formatter.attribute_value(val) + decoded = ( + str(key) + '=' + + formatter.quoted_attribute_value(text)) + attrs.append(decoded) + close = '' + closeTag = '' + + prefix = '' + if self.prefix: + prefix = self.prefix + ":" + + if self.is_empty_element: + close = formatter.void_element_close_prefix or '' + else: + closeTag = '</%s%s>' % (prefix, self.name) + + pretty_print = self._should_pretty_print(indent_level) + space = '' + indent_space = '' + if indent_level is not None: + indent_space = (' ' * (indent_level - 1)) + if pretty_print: + space = indent_space + indent_contents = indent_level + 1 + else: + indent_contents = None + contents = self.decode_contents( + indent_contents, eventual_encoding, formatter + ) + + if self.hidden: + # This is the 'document root' object. 
+ s = contents + else: + s = [] + attribute_string = '' + if attrs: + attribute_string = ' ' + ' '.join(attrs) + if indent_level is not None: + # Even if this particular tag is not pretty-printed, + # we should indent up to the start of the tag. + s.append(indent_space) + s.append('<%s%s%s%s>' % ( + prefix, self.name, attribute_string, close)) + if pretty_print: + s.append("\n") + s.append(contents) + if pretty_print and contents and contents[-1] != "\n": + s.append("\n") + if pretty_print and closeTag: + s.append(space) + s.append(closeTag) + if indent_level is not None and closeTag and self.next_sibling: + # Even if this particular tag is not pretty-printed, + # we're now done with the tag, and we should add a + # newline if appropriate. + s.append("\n") + s = ''.join(s) + return s + + def _should_pretty_print(self, indent_level): + """Should this tag be pretty-printed?""" + return ( + indent_level is not None + and self.name not in self.preserve_whitespace_tags + ) + + def prettify(self, encoding=None, formatter="minimal"): + if encoding is None: + return self.decode(True, formatter=formatter) + else: + return self.encode(encoding, True, formatter=formatter) + + def decode_contents(self, indent_level=None, + eventual_encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Renders the contents of this tag as a Unicode string. + + :param indent_level: Each line of the rendering will be + indented this many spaces. + + :param eventual_encoding: The tag is destined to be + encoded into this encoding. decode_contents() is _not_ + responsible for performing that encoding. This information + is passed in so that it can be substituted in if the + document contains a <META> tag that mentions the document's + encoding. + + :param formatter: A Formatter object, or a string naming one of + the standard Formatters. + """ + # First off, turn a string formatter into a Formatter object. This + # will stop the lookup from happening over and over again. 
+ if not isinstance(formatter, Formatter): + formatter = self.formatter_for_name(formatter) + + pretty_print = (indent_level is not None) + s = [] + for c in self: + text = None + if isinstance(c, NavigableString): + text = c.output_ready(formatter) + elif isinstance(c, Tag): + s.append(c.decode(indent_level, eventual_encoding, + formatter)) + preserve_whitespace = ( + self.preserve_whitespace_tags and self.name in self.preserve_whitespace_tags + ) + if text and indent_level and not preserve_whitespace: + text = text.strip() + if text: + if pretty_print and not preserve_whitespace: + s.append(" " * (indent_level - 1)) + s.append(text) + if pretty_print and not preserve_whitespace: + s.append("\n") + return ''.join(s) + + def encode_contents( + self, indent_level=None, encoding=DEFAULT_OUTPUT_ENCODING, + formatter="minimal"): + """Renders the contents of this tag as a bytestring. + + :param indent_level: Each line of the rendering will be + indented this many spaces. + + :param eventual_encoding: The bytestring will be in this encoding. + + :param formatter: The output formatter responsible for converting + entities to Unicode characters. + """ + + contents = self.decode_contents(indent_level, encoding, formatter) + return contents.encode(encoding) + + # Old method for BS3 compatibility + def renderContents(self, encoding=DEFAULT_OUTPUT_ENCODING, + prettyPrint=False, indentLevel=0): + if not prettyPrint: + indentLevel = None + return self.encode_contents( + indent_level=indentLevel, encoding=encoding) + + #Soup methods + + def find(self, name=None, attrs={}, recursive=True, text=None, + **kwargs): + """Return only the first child of this Tag matching the given + criteria.""" + r = None + l = self.find_all(name, attrs, recursive, text, 1, **kwargs) + if l: + r = l[0] + return r + findChild = find + + def find_all(self, name=None, attrs={}, recursive=True, text=None, + limit=None, **kwargs): + """Extracts a list of Tag objects that match the given + criteria. 
You can specify the name of the Tag and any + attributes you want the Tag to have. + + The value of a key-value pair in the 'attrs' map can be a + string, a list of strings, a regular expression object, or a + callable that takes a string and returns whether or not the + string matches for some custom definition of 'matches'. The + same is true of the tag name.""" + + generator = self.descendants + if not recursive: + generator = self.children + return self._find_all(name, attrs, text, limit, generator, **kwargs) + findAll = find_all # BS3 + findChildren = find_all # BS2 + + #Generator methods + @property + def children(self): + # return iter() to make the purpose of the method clear + return iter(self.contents) # XXX This seems to be untested. + + @property + def descendants(self): + if not len(self.contents): + return + stopNode = self._last_descendant().next_element + current = self.contents[0] + while current is not stopNode: + yield current + current = current.next_element + + # CSS selector code + def select_one(self, selector, namespaces=None, **kwargs): + """Perform a CSS selection operation on the current element.""" + value = self.select(selector, namespaces, 1, **kwargs) + if value: + return value[0] + return None + + def select(self, selector, namespaces=None, limit=None, **kwargs): + """Perform a CSS selection operation on the current element. + + This uses the SoupSieve library. + + :param selector: A string containing a CSS selector. + + :param namespaces: A dictionary mapping namespace prefixes + used in the CSS selector to namespace URIs. By default, + Beautiful Soup will use the prefixes it encountered while + parsing the document. + + :param limit: After finding this number of results, stop looking. + + :param kwargs: Any extra arguments you'd like to pass in to + soupsieve.select(). 
+ """ + if namespaces is None: + namespaces = self._namespaces + + if limit is None: + limit = 0 + if soupsieve is None: + raise NotImplementedError( + "Cannot execute CSS selectors because the soupsieve package is not installed." + ) + + return soupsieve.select(selector, self, namespaces, limit, **kwargs) + + # Old names for backwards compatibility + def childGenerator(self): + return self.children + + def recursiveChildGenerator(self): + return self.descendants + + def has_key(self, key): + """This was kind of misleading because has_key() (attributes) + was different from __in__ (contents). has_key() is gone in + Python 3, anyway.""" + warnings.warn('has_key is deprecated. Use has_attr("%s") instead.' % ( + key)) + return self.has_attr(key) + +# Next, a couple classes to represent queries and their results. +class SoupStrainer(object): + """Encapsulates a number of ways of matching a markup element (tag or + text).""" + + def __init__(self, name=None, attrs={}, text=None, **kwargs): + self.name = self._normalize_search_value(name) + if not isinstance(attrs, dict): + # Treat a non-dict value for attrs as a search for the 'class' + # attribute. + kwargs['class'] = attrs + attrs = None + + if 'class_' in kwargs: + # Treat class_="foo" as a search for the 'class' + # attribute, overriding any non-dict value for attrs. + kwargs['class'] = kwargs['class_'] + del kwargs['class_'] + + if kwargs: + if attrs: + attrs = attrs.copy() + attrs.update(kwargs) + else: + attrs = kwargs + normalized_attrs = {} + for key, value in list(attrs.items()): + normalized_attrs[key] = self._normalize_search_value(value) + + self.attrs = normalized_attrs + self.text = self._normalize_search_value(text) + + def _normalize_search_value(self, value): + # Leave it alone if it's a Unicode string, a callable, a + # regular expression, a boolean, or None. 
+ if (isinstance(value, str) or isinstance(value, Callable) or hasattr(value, 'match') + or isinstance(value, bool) or value is None): + return value + + # If it's a bytestring, convert it to Unicode, treating it as UTF-8. + if isinstance(value, bytes): + return value.decode("utf8") + + # If it's listlike, convert it into a list of strings. + if hasattr(value, '__iter__'): + new_value = [] + for v in value: + if (hasattr(v, '__iter__') and not isinstance(v, bytes) + and not isinstance(v, str)): + # This is almost certainly the user's mistake. In the + # interests of avoiding infinite loops, we'll let + # it through as-is rather than doing a recursive call. + new_value.append(v) + else: + new_value.append(self._normalize_search_value(v)) + return new_value + + # Otherwise, convert it into a Unicode string. + # The unicode(str()) thing is so this will do the same thing on Python 2 + # and Python 3. + return str(str(value)) + + def __str__(self): + if self.text: + return self.text + else: + return "%s|%s" % (self.name, self.attrs) + + def search_tag(self, markup_name=None, markup_attrs={}): + found = None + markup = None + if isinstance(markup_name, Tag): + markup = markup_name + markup_attrs = markup + call_function_with_tag_data = ( + isinstance(self.name, Callable) + and not isinstance(markup_name, Tag)) + + if ((not self.name) + or call_function_with_tag_data + or (markup and self._matches(markup, self.name)) + or (not markup and self._matches(markup_name, self.name))): + if call_function_with_tag_data: + match = self.name(markup_name, markup_attrs) + else: + match = True + markup_attr_map = None + for attr, match_against in list(self.attrs.items()): + if not markup_attr_map: + if hasattr(markup_attrs, 'get'): + markup_attr_map = markup_attrs + else: + markup_attr_map = {} + for k, v in markup_attrs: + markup_attr_map[k] = v + attr_value = markup_attr_map.get(attr) + if not self._matches(attr_value, match_against): + match = False + break + if match: + if markup: 
+ found = markup + else: + found = markup_name + if found and self.text and not self._matches(found.string, self.text): + found = None + return found + searchTag = search_tag + + def search(self, markup): + # print 'looking for %s in %s' % (self, markup) + found = None + # If given a list of items, scan it for a text element that + # matches. + if hasattr(markup, '__iter__') and not isinstance(markup, (Tag, str)): + for element in markup: + if isinstance(element, NavigableString) \ + and self.search(element): + found = element + break + # If it's a Tag, make sure its name or attributes match. + # Don't bother with Tags if we're searching for text. + elif isinstance(markup, Tag): + if not self.text or self.name or self.attrs: + found = self.search_tag(markup) + # If it's text, make sure the text matches. + elif isinstance(markup, NavigableString) or \ + isinstance(markup, str): + if not self.name and not self.attrs and self._matches(markup, self.text): + found = markup + else: + raise Exception( + "I don't know how to match against a %s" % markup.__class__) + return found + + def _matches(self, markup, match_against, already_tried=None): + # print u"Matching %s against %s" % (markup, match_against) + result = False + if isinstance(markup, list) or isinstance(markup, tuple): + # This should only happen when searching a multi-valued attribute + # like 'class'. + for item in markup: + if self._matches(item, match_against): + return True + # We didn't match any particular value of the multivalue + # attribute, but maybe we match the attribute value when + # considered as a string. + if self._matches(' '.join(markup), match_against): + return True + return False + + if match_against is True: + # True matches any non-None value. + return markup is not None + + if isinstance(match_against, Callable): + return match_against(markup) + + # Custom callables take the tag as an argument, but all + # other ways of matching match the tag name as a string. 
+ original_markup = markup + if isinstance(markup, Tag): + markup = markup.name + + # Ensure that `markup` is either a Unicode string, or None. + markup = self._normalize_search_value(markup) + + if markup is None: + # None matches None, False, an empty string, an empty list, and so on. + return not match_against + + if (hasattr(match_against, '__iter__') + and not isinstance(match_against, str)): + # We're asked to match against an iterable of items. + # The markup must be match at least one item in the + # iterable. We'll try each one in turn. + # + # To avoid infinite recursion we need to keep track of + # items we've already seen. + if not already_tried: + already_tried = set() + for item in match_against: + if item.__hash__: + key = item + else: + key = id(item) + if key in already_tried: + continue + else: + already_tried.add(key) + if self._matches(original_markup, item, already_tried): + return True + else: + return False + + # Beyond this point we might need to run the test twice: once against + # the tag's name and once against its prefixed name. + match = False + + if not match and isinstance(match_against, str): + # Exact string match + match = markup == match_against + + if not match and hasattr(match_against, 'search'): + # Regexp match + return match_against.search(markup) + + if (not match + and isinstance(original_markup, Tag) + and original_markup.prefix): + # Try the whole thing again with the prefixed tag name. + return self._matches( + original_markup.prefix + ':' + original_markup.name, match_against + ) + + return match + + +class ResultSet(list): + """A ResultSet is just a list that keeps track of the SoupStrainer + that created it.""" + def __init__(self, source, result=()): + super(ResultSet, self).__init__(result) + self.source = source + + def __getattr__(self, key): + raise AttributeError( + "ResultSet object has no attribute '%s'. You're probably treating a list of items like a single item. 
Did you call find_all() when you meant to call find()?" % key + ) diff --git a/libs/bs4/formatter.py b/libs/bs4/formatter.py new file mode 100644 index 000000000..7dbaa3850 --- /dev/null +++ b/libs/bs4/formatter.py @@ -0,0 +1,99 @@ +from bs4.dammit import EntitySubstitution + +class Formatter(EntitySubstitution): + """Describes a strategy to use when outputting a parse tree to a string. + + Some parts of this strategy come from the distinction between + HTML4, HTML5, and XML. Others are configurable by the user. + """ + # Registries of XML and HTML formatters. + XML_FORMATTERS = {} + HTML_FORMATTERS = {} + + HTML = 'html' + XML = 'xml' + + HTML_DEFAULTS = dict( + cdata_containing_tags=set(["script", "style"]), + ) + + def _default(self, language, value, kwarg): + if value is not None: + return value + if language == self.XML: + return set() + return self.HTML_DEFAULTS[kwarg] + + def __init__( + self, language=None, entity_substitution=None, + void_element_close_prefix='/', cdata_containing_tags=None, + ): + """ + + :param void_element_close_prefix: By default, represent void + elements as <tag/> rather than <tag> + """ + self.language = language + self.entity_substitution = entity_substitution + self.void_element_close_prefix = void_element_close_prefix + self.cdata_containing_tags = self._default( + language, cdata_containing_tags, 'cdata_containing_tags' + ) + + def substitute(self, ns): + """Process a string that needs to undergo entity substitution.""" + if not self.entity_substitution: + return ns + from .element import NavigableString + if (isinstance(ns, NavigableString) + and ns.parent is not None + and ns.parent.name in self.cdata_containing_tags): + # Do nothing. + return ns + # Substitute. 
+ return self.entity_substitution(ns) + + def attribute_value(self, value): + """Process the value of an attribute.""" + return self.substitute(value) + + def attributes(self, tag): + """Reorder a tag's attributes however you want.""" + return sorted(tag.attrs.items()) + + +class HTMLFormatter(Formatter): + REGISTRY = {} + def __init__(self, *args, **kwargs): + return super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs) + + +class XMLFormatter(Formatter): + REGISTRY = {} + def __init__(self, *args, **kwargs): + return super(XMLFormatter, self).__init__(self.XML, *args, **kwargs) + + +# Set up aliases for the default formatters. +HTMLFormatter.REGISTRY['html'] = HTMLFormatter( + entity_substitution=EntitySubstitution.substitute_html +) +HTMLFormatter.REGISTRY["html5"] = HTMLFormatter( + entity_substitution=EntitySubstitution.substitute_html, + void_element_close_prefix = None +) +HTMLFormatter.REGISTRY["minimal"] = HTMLFormatter( + entity_substitution=EntitySubstitution.substitute_xml +) +HTMLFormatter.REGISTRY[None] = HTMLFormatter( + entity_substitution=None +) +XMLFormatter.REGISTRY["html"] = XMLFormatter( + entity_substitution=EntitySubstitution.substitute_html +) +XMLFormatter.REGISTRY["minimal"] = XMLFormatter( + entity_substitution=EntitySubstitution.substitute_xml +) +XMLFormatter.REGISTRY[None] = Formatter( + Formatter(Formatter.XML, entity_substitution=None) +) diff --git a/libs/bs4/testing.py b/libs/bs4/testing.py new file mode 100644 index 000000000..cc9966601 --- /dev/null +++ b/libs/bs4/testing.py @@ -0,0 +1,992 @@ +# encoding: utf-8 +"""Helper classes for tests.""" + +# Use of this source code is governed by the MIT license. 
+__license__ = "MIT" + +import pickle +import copy +import functools +import unittest +from unittest import TestCase +from bs4 import BeautifulSoup +from bs4.element import ( + CharsetMetaAttributeValue, + Comment, + ContentMetaAttributeValue, + Doctype, + SoupStrainer, + Tag +) + +from bs4.builder import HTMLParserTreeBuilder +default_builder = HTMLParserTreeBuilder + +BAD_DOCUMENT = """A bare string +<!DOCTYPE xsl:stylesheet SYSTEM "htmlent.dtd"> +<!DOCTYPE xsl:stylesheet PUBLIC "htmlent.dtd"> +<div><![CDATA[A CDATA section where it doesn't belong]]></div> +<div><svg><![CDATA[HTML5 does allow CDATA sections in SVG]]></svg></div> +<div>A <meta> tag</div> +<div>A <br> tag that supposedly has contents.</br></div> +<div>AT&T</div> +<div><textarea>Within a textarea, markup like <b> tags and <&<& should be treated as literal</textarea></div> +<div><script>if (i < 2) { alert("<b>Markup within script tags should be treated as literal.</b>"); }</script></div> +<div>This numeric entity is missing the final semicolon: <x t="piñata"></div> +<div><a href="http://example.com/</a> that attribute value never got closed</div> +<div><a href="foo</a>, </a><a href="bar">that attribute value was closed by the subsequent tag</a></div> +<! This document starts with a bogus declaration ><div>a</div> +<div>This document contains <!an incomplete declaration <div>(do you see it?)</div> +<div>This document ends with <!an incomplete declaration +<div><a style={height:21px;}>That attribute value was bogus</a></div> +<! 
DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">The doctype is invalid because it contains extra whitespace +<div><table><td nowrap>That boolean attribute had no value</td></table></div> +<div>Here's a nonexistent entity: &#foo; (do you see it?)</div> +<div>This document ends before the entity finishes: > +<div><p>Paragraphs shouldn't contain block display elements, but this one does: <dl><dt>you see?</dt></p> +<b b="20" a="1" b="10" a="2" a="3" a="4">Multiple values for the same attribute.</b> +<div><table><tr><td>Here's a table</td></tr></table></div> +<div><table id="1"><tr><td>Here's a nested table:<table id="2"><tr><td>foo</td></tr></table></td></div> +<div>This tag contains nothing but whitespace: <b> </b></div> +<div><blockquote><p><b>This p tag is cut off by</blockquote></p>the end of the blockquote tag</div> +<div><table><div>This table contains bare markup</div></table></div> +<div><div id="1">\n <a href="link1">This link is never closed.\n</div>\n<div id="2">\n <div id="3">\n <a href="link2">This link is closed.</a>\n </div>\n</div></div> +<div>This document contains a <!DOCTYPE surprise>surprise doctype</div> +<div><a><B><Cd><EFG>Mixed case tags are folded to lowercase</efg></CD></b></A></div> +<div><our\u2603>Tag name contains Unicode characters</our\u2603></div> +<div><a \u2603="snowman">Attribute name contains Unicode characters</a></div> +<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> +""" + + +class SoupTest(unittest.TestCase): + + @property + def default_builder(self): + return default_builder + + def soup(self, markup, **kwargs): + """Build a Beautiful Soup object from markup.""" + builder = kwargs.pop('builder', self.default_builder) + return BeautifulSoup(markup, builder=builder, **kwargs) + + def document_for(self, markup, **kwargs): + """Turn an HTML fragment into a document. + + The details depend on the builder. 
+ """ + return self.default_builder(**kwargs).test_fragment_to_document(markup) + + def assertSoupEquals(self, to_parse, compare_parsed_to=None): + builder = self.default_builder + obj = BeautifulSoup(to_parse, builder=builder) + if compare_parsed_to is None: + compare_parsed_to = to_parse + + self.assertEqual(obj.decode(), self.document_for(compare_parsed_to)) + + def assertConnectedness(self, element): + """Ensure that next_element and previous_element are properly + set for all descendants of the given element. + """ + earlier = None + for e in element.descendants: + if earlier: + self.assertEqual(e, earlier.next_element) + self.assertEqual(earlier, e.previous_element) + earlier = e + + def linkage_validator(self, el, _recursive_call=False): + """Ensure proper linkage throughout the document.""" + descendant = None + # Document element should have no previous element or previous sibling. + # It also shouldn't have a next sibling. + if el.parent is None: + assert el.previous_element is None,\ + "Bad previous_element\nNODE: {}\nPREV: {}\nEXPECTED: {}".format( + el, el.previous_element, None + ) + assert el.previous_sibling is None,\ + "Bad previous_sibling\nNODE: {}\nPREV: {}\nEXPECTED: {}".format( + el, el.previous_sibling, None + ) + assert el.next_sibling is None,\ + "Bad next_sibling\nNODE: {}\nNEXT: {}\nEXPECTED: {}".format( + el, el.next_sibling, None + ) + + idx = 0 + child = None + last_child = None + last_idx = len(el.contents) - 1 + for child in el.contents: + descendant = None + + # Parent should link next element to their first child + # That child should have no previous sibling + if idx == 0: + if el.parent is not None: + assert el.next_element is child,\ + "Bad next_element\nNODE: {}\nNEXT: {}\nEXPECTED: {}".format( + el, el.next_element, child + ) + assert child.previous_element is el,\ + "Bad previous_element\nNODE: {}\nPREV: {}\nEXPECTED: {}".format( + child, child.previous_element, el + ) + assert child.previous_sibling is None,\ + "Bad 
previous_sibling\nNODE: {}\nPREV {}\nEXPECTED: {}".format( + child, child.previous_sibling, None + ) + + # If not the first child, previous index should link as sibling to this index + # Previous element should match the last index or the last bubbled up descendant + else: + assert child.previous_sibling is el.contents[idx - 1],\ + "Bad previous_sibling\nNODE: {}\nPREV {}\nEXPECTED {}".format( + child, child.previous_sibling, el.contents[idx - 1] + ) + assert el.contents[idx - 1].next_sibling is child,\ + "Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format( + el.contents[idx - 1], el.contents[idx - 1].next_sibling, child + ) + + if last_child is not None: + assert child.previous_element is last_child,\ + "Bad previous_element\nNODE: {}\nPREV {}\nEXPECTED {}\nCONTENTS {}".format( + child, child.previous_element, last_child, child.parent.contents + ) + assert last_child.next_element is child,\ + "Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format( + last_child, last_child.next_element, child + ) + + if isinstance(child, Tag) and child.contents: + descendant = self.linkage_validator(child, True) + # A bubbled up descendant should have no next siblings + assert descendant.next_sibling is None,\ + "Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format( + descendant, descendant.next_sibling, None + ) + + # Mark last child as either the bubbled up descendant or the current child + if descendant is not None: + last_child = descendant + else: + last_child = child + + # If last child, there are non next siblings + if idx == last_idx: + assert child.next_sibling is None,\ + "Bad next_sibling\nNODE: {}\nNEXT {}\nEXPECTED {}".format( + child, child.next_sibling, None + ) + idx += 1 + + child = descendant if descendant is not None else child + if child is None: + child = el + + if not _recursive_call and child is not None: + target = el + while True: + if target is None: + assert child.next_element is None, \ + "Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED 
{}".format( + child, child.next_element, None + ) + break + elif target.next_sibling is not None: + assert child.next_element is target.next_sibling, \ + "Bad next_element\nNODE: {}\nNEXT {}\nEXPECTED {}".format( + child, child.next_element, target.next_sibling + ) + break + target = target.parent + + # We are done, so nothing to return + return None + else: + # Return the child to the recursive caller + return child + + +class HTMLTreeBuilderSmokeTest(object): + + """A basic test of a treebuilder's competence. + + Any HTML treebuilder, present or future, should be able to pass + these tests. With invalid markup, there's room for interpretation, + and different parsers can handle it differently. But with the + markup in these tests, there's not much room for interpretation. + """ + + def test_empty_element_tags(self): + """Verify that all HTML4 and HTML5 empty element (aka void element) tags + are handled correctly. + """ + for name in [ + 'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr', + 'spacer', 'frame' + ]: + soup = self.soup("") + new_tag = soup.new_tag(name) + self.assertEqual(True, new_tag.is_empty_element) + + def test_pickle_and_unpickle_identity(self): + # Pickling a tree, then unpickling it, yields a tree identical + # to the original. + tree = self.soup("<a><b>foo</a>") + dumped = pickle.dumps(tree, 2) + loaded = pickle.loads(dumped) + self.assertEqual(loaded.__class__, BeautifulSoup) + self.assertEqual(loaded.decode(), tree.decode()) + + def assertDoctypeHandled(self, doctype_fragment): + """Assert that a given doctype string is handled correctly.""" + doctype_str, soup = self._document_with_doctype(doctype_fragment) + + # Make sure a Doctype object was created. 
+ doctype = soup.contents[0] + self.assertEqual(doctype.__class__, Doctype) + self.assertEqual(doctype, doctype_fragment) + self.assertEqual(str(soup)[:len(doctype_str)], doctype_str) + + # Make sure that the doctype was correctly associated with the + # parse tree and that the rest of the document parsed. + self.assertEqual(soup.p.contents[0], 'foo') + + def _document_with_doctype(self, doctype_fragment): + """Generate and parse a document with the given doctype.""" + doctype = '<!DOCTYPE %s>' % doctype_fragment + markup = doctype + '\n<p>foo</p>' + soup = self.soup(markup) + return doctype, soup + + def test_normal_doctypes(self): + """Make sure normal, everyday HTML doctypes are handled correctly.""" + self.assertDoctypeHandled("html") + self.assertDoctypeHandled( + 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"') + + def test_empty_doctype(self): + soup = self.soup("<!DOCTYPE>") + doctype = soup.contents[0] + self.assertEqual("", doctype.strip()) + + def test_public_doctype_with_url(self): + doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"' + self.assertDoctypeHandled(doctype) + + def test_system_doctype(self): + self.assertDoctypeHandled('foo SYSTEM "http://www.example.com/"') + + def test_namespaced_system_doctype(self): + # We can handle a namespaced doctype with a system ID. + self.assertDoctypeHandled('xsl:stylesheet SYSTEM "htmlent.dtd"') + + def test_namespaced_public_doctype(self): + # Test a namespaced doctype with a public id. 
+ self.assertDoctypeHandled('xsl:stylesheet PUBLIC "htmlent.dtd"') + + def test_real_xhtml_document(self): + """A real XHTML document should come out more or less the same as it went in.""" + markup = b"""<?xml version="1.0" encoding="utf-8"?> +<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"> +<html xmlns="http://www.w3.org/1999/xhtml"> +<head><title>Hello.</title></head> +<body>Goodbye.</body> +</html>""" + soup = self.soup(markup) + self.assertEqual( + soup.encode("utf-8").replace(b"\n", b""), + markup.replace(b"\n", b"")) + + def test_namespaced_html(self): + """When a namespaced XML document is parsed as HTML it should + be treated as HTML with weird tag names. + """ + markup = b"""<ns1:foo>content</ns1:foo><ns1:foo/><ns2:foo/>""" + soup = self.soup(markup) + self.assertEqual(2, len(soup.find_all("ns1:foo"))) + + def test_processing_instruction(self): + # We test both Unicode and bytestring to verify that + # process_markup correctly sets processing_instruction_class + # even when the markup is already Unicode and there is no + # need to process anything. + markup = """<?PITarget PIContent?>""" + soup = self.soup(markup) + self.assertEqual(markup, soup.decode()) + + markup = b"""<?PITarget PIContent?>""" + soup = self.soup(markup) + self.assertEqual(markup, soup.encode("utf8")) + + def test_deepcopy(self): + """Make sure you can copy the tree builder. + + This is important because the builder is part of a + BeautifulSoup object, and we want to be able to copy that. + """ + copy.deepcopy(self.default_builder) + + def test_p_tag_is_never_empty_element(self): + """A <p> tag is never designated as an empty-element tag. + + Even if the markup shows it as an empty-element tag, it + shouldn't be presented that way. + """ + soup = self.soup("<p/>") + self.assertFalse(soup.p.is_empty_element) + self.assertEqual(str(soup.p), "<p></p>") + + def test_unclosed_tags_get_closed(self): + """A tag that's not closed by the end of the document should be closed. 
+ + This applies to all tags except empty-element tags. + """ + self.assertSoupEquals("<p>", "<p></p>") + self.assertSoupEquals("<b>", "<b></b>") + + self.assertSoupEquals("<br>", "<br/>") + + def test_br_is_always_empty_element_tag(self): + """A <br> tag is designated as an empty-element tag. + + Some parsers treat <br></br> as one <br/> tag, some parsers as + two tags, but it should always be an empty-element tag. + """ + soup = self.soup("<br></br>") + self.assertTrue(soup.br.is_empty_element) + self.assertEqual(str(soup.br), "<br/>") + + def test_nested_formatting_elements(self): + self.assertSoupEquals("<em><em></em></em>") + + def test_double_head(self): + html = '''<!DOCTYPE html> +<html> +<head> +<title>Ordinary HEAD element test</title> +</head> +<script type="text/javascript"> +alert("Help!"); +</script> +<body> +Hello, world! +</body> +</html> +''' + soup = self.soup(html) + self.assertEqual("text/javascript", soup.find('script')['type']) + + def test_comment(self): + # Comments are represented as Comment objects. + markup = "<p>foo<!--foobar-->baz</p>" + self.assertSoupEquals(markup) + + soup = self.soup(markup) + comment = soup.find(text="foobar") + self.assertEqual(comment.__class__, Comment) + + # The comment is properly integrated into the tree. + foo = soup.find(text="foo") + self.assertEqual(comment, foo.next_element) + baz = soup.find(text="baz") + self.assertEqual(comment, baz.previous_element) + + def test_preserved_whitespace_in_pre_and_textarea(self): + """Whitespace must be preserved in <pre> and <textarea> tags, + even if that would mean not prettifying the markup. 
+ """ + pre_markup = "<pre> </pre>" + textarea_markup = "<textarea> woo\nwoo </textarea>" + self.assertSoupEquals(pre_markup) + self.assertSoupEquals(textarea_markup) + + soup = self.soup(pre_markup) + self.assertEqual(soup.pre.prettify(), pre_markup) + + soup = self.soup(textarea_markup) + self.assertEqual(soup.textarea.prettify(), textarea_markup) + + soup = self.soup("<textarea></textarea>") + self.assertEqual(soup.textarea.prettify(), "<textarea></textarea>") + + def test_nested_inline_elements(self): + """Inline elements can be nested indefinitely.""" + b_tag = "<b>Inside a B tag</b>" + self.assertSoupEquals(b_tag) + + nested_b_tag = "<p>A <i>nested <b>tag</b></i></p>" + self.assertSoupEquals(nested_b_tag) + + double_nested_b_tag = "<p>A <a>doubly <i>nested <b>tag</b></i></a></p>" + self.assertSoupEquals(nested_b_tag) + + def test_nested_block_level_elements(self): + """Block elements can be nested.""" + soup = self.soup('<blockquote><p><b>Foo</b></p></blockquote>') + blockquote = soup.blockquote + self.assertEqual(blockquote.p.b.string, 'Foo') + self.assertEqual(blockquote.b.string, 'Foo') + + def test_correctly_nested_tables(self): + """One table can go inside another one.""" + markup = ('<table id="1">' + '<tr>' + "<td>Here's another table:" + '<table id="2">' + '<tr><td>foo</td></tr>' + '</table></td>') + + self.assertSoupEquals( + markup, + '<table id="1"><tr><td>Here\'s another table:' + '<table id="2"><tr><td>foo</td></tr></table>' + '</td></tr></table>') + + self.assertSoupEquals( + "<table><thead><tr><td>Foo</td></tr></thead>" + "<tbody><tr><td>Bar</td></tr></tbody>" + "<tfoot><tr><td>Baz</td></tr></tfoot></table>") + + def test_multivalued_attribute_with_whitespace(self): + # Whitespace separating the values of a multi-valued attribute + # should be ignored. 
+ + markup = '<div class=" foo bar "></a>' + soup = self.soup(markup) + self.assertEqual(['foo', 'bar'], soup.div['class']) + + # If you search by the literal name of the class it's like the whitespace + # wasn't there. + self.assertEqual(soup.div, soup.find('div', class_="foo bar")) + + def test_deeply_nested_multivalued_attribute(self): + # html5lib can set the attributes of the same tag many times + # as it rearranges the tree. This has caused problems with + # multivalued attributes. + markup = '<table><div><div class="css"></div></div></table>' + soup = self.soup(markup) + self.assertEqual(["css"], soup.div.div['class']) + + def test_multivalued_attribute_on_html(self): + # html5lib uses a different API to set the attributes ot the + # <html> tag. This has caused problems with multivalued + # attributes. + markup = '<html class="a b"></html>' + soup = self.soup(markup) + self.assertEqual(["a", "b"], soup.html['class']) + + def test_angle_brackets_in_attribute_values_are_escaped(self): + self.assertSoupEquals('<a b="<a>"></a>', '<a b="<a>"></a>') + + def test_strings_resembling_character_entity_references(self): + # "&T" and "&p" look like incomplete character entities, but they are + # not. + self.assertSoupEquals( + "<p>• AT&T is in the s&p 500</p>", + "<p>\u2022 AT&T is in the s&p 500</p>" + ) + + def test_apos_entity(self): + self.assertSoupEquals( + "<p>Bob's Bar</p>", + "<p>Bob's Bar</p>", + ) + + def test_entities_in_foreign_document_encoding(self): + # “ and ” are invalid numeric entities referencing + # Windows-1252 characters. - references a character common + # to Windows-1252 and Unicode, and ☃ references a + # character only found in Unicode. + # + # All of these entities should be converted to Unicode + # characters. 
+ markup = "<p>“Hello” -☃</p>" + soup = self.soup(markup) + self.assertEqual("“Hello” -☃", soup.p.string) + + def test_entities_in_attributes_converted_to_unicode(self): + expect = '<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>' + self.assertSoupEquals('<p id="piñata"></p>', expect) + self.assertSoupEquals('<p id="piñata"></p>', expect) + self.assertSoupEquals('<p id="piñata"></p>', expect) + self.assertSoupEquals('<p id="piñata"></p>', expect) + + def test_entities_in_text_converted_to_unicode(self): + expect = '<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>' + self.assertSoupEquals("<p>piñata</p>", expect) + self.assertSoupEquals("<p>piñata</p>", expect) + self.assertSoupEquals("<p>piñata</p>", expect) + self.assertSoupEquals("<p>piñata</p>", expect) + + def test_quot_entity_converted_to_quotation_mark(self): + self.assertSoupEquals("<p>I said "good day!"</p>", + '<p>I said "good day!"</p>') + + def test_out_of_range_entity(self): + expect = "\N{REPLACEMENT CHARACTER}" + self.assertSoupEquals("�", expect) + self.assertSoupEquals("�", expect) + self.assertSoupEquals("�", expect) + + def test_multipart_strings(self): + "Mostly to prevent a recurrence of a bug in the html5lib treebuilder." + soup = self.soup("<html><h2>\nfoo</h2><p></p></html>") + self.assertEqual("p", soup.h2.string.next_element.name) + self.assertEqual("p", soup.p.name) + self.assertConnectedness(soup) + + def test_empty_element_tags(self): + """Verify consistent handling of empty-element tags, + no matter how they come in through the markup. + """ + self.assertSoupEquals('<br/><br/><br/>', "<br/><br/><br/>") + self.assertSoupEquals('<br /><br /><br />', "<br/><br/><br/>") + + def test_head_tag_between_head_and_body(self): + "Prevent recurrence of a bug in the html5lib treebuilder." 
+ content = """<html><head></head> + <link></link> + <body>foo</body> +</html> +""" + soup = self.soup(content) + self.assertNotEqual(None, soup.html.body) + self.assertConnectedness(soup) + + def test_multiple_copies_of_a_tag(self): + "Prevent recurrence of a bug in the html5lib treebuilder." + content = """<!DOCTYPE html> +<html> + <body> + <article id="a" > + <div><a href="1"></div> + <footer> + <a href="2"></a> + </footer> + </article> + </body> +</html> +""" + soup = self.soup(content) + self.assertConnectedness(soup.article) + + def test_basic_namespaces(self): + """Parsers don't need to *understand* namespaces, but at the + very least they should not choke on namespaces or lose + data.""" + + markup = b'<html xmlns="http://www.w3.org/1999/xhtml" xmlns:mathml="http://www.w3.org/1998/Math/MathML" xmlns:svg="http://www.w3.org/2000/svg"><head></head><body><mathml:msqrt>4</mathml:msqrt><b svg:fill="red"></b></body></html>' + soup = self.soup(markup) + self.assertEqual(markup, soup.encode()) + html = soup.html + self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns']) + self.assertEqual( + 'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml']) + self.assertEqual( + 'http://www.w3.org/2000/svg', soup.html['xmlns:svg']) + + def test_multivalued_attribute_value_becomes_list(self): + markup = b'<a class="foo bar">' + soup = self.soup(markup) + self.assertEqual(['foo', 'bar'], soup.a['class']) + + # + # Generally speaking, tests below this point are more tests of + # Beautiful Soup than tests of the tree builders. But parsers are + # weird, so we run these tests separately for every tree builder + # to detect any differences between them. + # + + def test_can_parse_unicode_document(self): + # A seemingly innocuous document... but it's in Unicode! And + # it contains characters that can't be represented in the + # encoding found in the declaration! The horror! 
+ markup = '<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>' + soup = self.soup(markup) + self.assertEqual('Sacr\xe9 bleu!', soup.body.string) + + def test_soupstrainer(self): + """Parsers should be able to work with SoupStrainers.""" + strainer = SoupStrainer("b") + soup = self.soup("A <b>bold</b> <meta/> <i>statement</i>", + parse_only=strainer) + self.assertEqual(soup.decode(), "<b>bold</b>") + + def test_single_quote_attribute_values_become_double_quotes(self): + self.assertSoupEquals("<foo attr='bar'></foo>", + '<foo attr="bar"></foo>') + + def test_attribute_values_with_nested_quotes_are_left_alone(self): + text = """<foo attr='bar "brawls" happen'>a</foo>""" + self.assertSoupEquals(text) + + def test_attribute_values_with_double_nested_quotes_get_quoted(self): + text = """<foo attr='bar "brawls" happen'>a</foo>""" + soup = self.soup(text) + soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"' + self.assertSoupEquals( + soup.foo.decode(), + """<foo attr="Brawls happen at "Bob\'s Bar"">a</foo>""") + + def test_ampersand_in_attribute_value_gets_escaped(self): + self.assertSoupEquals('<this is="really messed up & stuff"></this>', + '<this is="really messed up & stuff"></this>') + + self.assertSoupEquals( + '<a href="http://example.org?a=1&b=2;3">foo</a>', + '<a href="http://example.org?a=1&b=2;3">foo</a>') + + def test_escaped_ampersand_in_attribute_value_is_left_alone(self): + self.assertSoupEquals('<a href="http://example.org?a=1&b=2;3"></a>') + + def test_entities_in_strings_converted_during_parsing(self): + # Both XML and HTML entities are converted to Unicode characters + # during parsing. + text = "<p><<sacré bleu!>></p>" + expected = "<p><<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></p>" + self.assertSoupEquals(text, expected) + + def test_smart_quotes_converted_on_the_way_in(self): + # Microsoft smart quotes are converted to Unicode characters during + # parsing. 
+ quote = b"<p>\x91Foo\x92</p>" + soup = self.soup(quote) + self.assertEqual( + soup.p.string, + "\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}") + + def test_non_breaking_spaces_converted_on_the_way_in(self): + soup = self.soup("<a> </a>") + self.assertEqual(soup.a.string, "\N{NO-BREAK SPACE}" * 2) + + def test_entities_converted_on_the_way_out(self): + text = "<p><<sacré bleu!>></p>" + expected = "<p><<sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></p>".encode("utf-8") + soup = self.soup(text) + self.assertEqual(soup.p.encode("utf-8"), expected) + + def test_real_iso_latin_document(self): + # Smoke test of interrelated functionality, using an + # easy-to-understand document. + + # Here it is in Unicode. Note that it claims to be in ISO-Latin-1. + unicode_html = '<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>' + + # That's because we're going to encode it into ISO-Latin-1, and use + # that to test. + iso_latin_html = unicode_html.encode("iso-8859-1") + + # Parse the ISO-Latin-1 HTML. + soup = self.soup(iso_latin_html) + # Encode it to UTF-8. + result = soup.encode("utf-8") + + # What do we expect the result to look like? Well, it would + # look like unicode_html, except that the META tag would say + # UTF-8 instead of ISO-Latin-1. + expected = unicode_html.replace("ISO-Latin-1", "utf-8") + + # And, of course, it would be in UTF-8, not Unicode. + expected = expected.encode("utf-8") + + # Ta-da! + self.assertEqual(result, expected) + + def test_real_shift_jis_document(self): + # Smoke test to make sure the parser can handle a document in + # Shift-JIS encoding, without choking. 
+ shift_jis_html = ( + b'<html><head></head><body><pre>' + b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f' + b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c' + b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B' + b'</pre></body></html>') + unicode_html = shift_jis_html.decode("shift-jis") + soup = self.soup(unicode_html) + + # Make sure the parse tree is correctly encoded to various + # encodings. + self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8")) + self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp")) + + def test_real_hebrew_document(self): + # A real-world test to make sure we can convert ISO-8859-9 (a + # Hebrew encoding) to UTF-8. + hebrew_document = b'<html><head><title>Hebrew (ISO 8859-8) in Visual Directionality</title></head><body><h1>Hebrew (ISO 8859-8) in Visual Directionality</h1>\xed\xe5\xec\xf9</body></html>' + soup = self.soup( + hebrew_document, from_encoding="iso8859-8") + # Some tree builders call it iso8859-8, others call it iso-8859-9. + # That's not a difference we really care about. + assert soup.original_encoding in ('iso8859-8', 'iso-8859-8') + self.assertEqual( + soup.encode('utf-8'), + hebrew_document.decode("iso8859-8").encode("utf-8")) + + def test_meta_tag_reflects_current_encoding(self): + # Here's the <meta> tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('<meta content="text/html; charset=x-sjis" ' + 'http-equiv="Content-type"/>') + + # Here's a document incorporating that meta tag. + shift_jis_html = ( + '<html><head>\n%s\n' + '<meta http-equiv="Content-language" content="ja"/>' + '</head><body>Shift-JIS markup goes here.') % meta_tag + soup = self.soup(shift_jis_html) + + # Parse the document, and the charset is seemingly unaffected. 
+ parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'}) + content = parsed_meta['content'] + self.assertEqual('text/html; charset=x-sjis', content) + + # But that value is actually a ContentMetaAttributeValue object. + self.assertTrue(isinstance(content, ContentMetaAttributeValue)) + + # And it will take on a value that reflects its current + # encoding. + self.assertEqual('text/html; charset=utf8', content.encode("utf8")) + + # For the rest of the story, see TestSubstitutions in + # test_tree.py. + + def test_html5_style_meta_tag_reflects_current_encoding(self): + # Here's the <meta> tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('<meta id="encoding" charset="x-sjis" />') + + # Here's a document incorporating that meta tag. + shift_jis_html = ( + '<html><head>\n%s\n' + '<meta http-equiv="Content-language" content="ja"/>' + '</head><body>Shift-JIS markup goes here.') % meta_tag + soup = self.soup(shift_jis_html) + + # Parse the document, and the charset is seemingly unaffected. + parsed_meta = soup.find('meta', id="encoding") + charset = parsed_meta['charset'] + self.assertEqual('x-sjis', charset) + + # But that value is actually a CharsetMetaAttributeValue object. + self.assertTrue(isinstance(charset, CharsetMetaAttributeValue)) + + # And it will take on a value that reflects its current + # encoding. + self.assertEqual('utf8', charset.encode("utf8")) + + def test_tag_with_no_attributes_can_have_attributes_added(self): + data = self.soup("<a>text</a>") + data.a['foo'] = 'bar' + self.assertEqual('<a foo="bar">text</a>', data.a.decode()) + + def test_worst_case(self): + """Test the worst case (currently) for linking issues.""" + + soup = self.soup(BAD_DOCUMENT) + self.linkage_validator(soup) + + +class XMLTreeBuilderSmokeTest(object): + + def test_pickle_and_unpickle_identity(self): + # Pickling a tree, then unpickling it, yields a tree identical + # to the original. 
+ tree = self.soup("<a><b>foo</a>") + dumped = pickle.dumps(tree, 2) + loaded = pickle.loads(dumped) + self.assertEqual(loaded.__class__, BeautifulSoup) + self.assertEqual(loaded.decode(), tree.decode()) + + def test_docstring_generated(self): + soup = self.soup("<root/>") + self.assertEqual( + soup.encode(), b'<?xml version="1.0" encoding="utf-8"?>\n<root/>') + + def test_xml_declaration(self): + markup = b"""<?xml version="1.0" encoding="utf8"?>\n<foo/>""" + soup = self.soup(markup) + self.assertEqual(markup, soup.encode("utf8")) + + def test_processing_instruction(self): + markup = b"""<?xml version="1.0" encoding="utf8"?>\n<?PITarget PIContent?>""" + soup = self.soup(markup) + self.assertEqual(markup, soup.encode("utf8")) + + def test_real_xhtml_document(self): + """A real XHTML document should come out *exactly* the same as it went in.""" + markup = b"""<?xml version="1.0" encoding="utf-8"?> +<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"> +<html xmlns="http://www.w3.org/1999/xhtml"> +<head><title>Hello.</title></head> +<body>Goodbye.</body> +</html>""" + soup = self.soup(markup) + self.assertEqual( + soup.encode("utf-8"), markup) + + def test_nested_namespaces(self): + doc = b"""<?xml version="1.0" encoding="utf-8"?> +<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd"> +<parent xmlns="http://ns1/"> +<child xmlns="http://ns2/" xmlns:ns3="http://ns3/"> +<grandchild ns3:attr="value" xmlns="http://ns4/"/> +</child> +</parent>""" + soup = self.soup(doc) + self.assertEqual(doc, soup.encode()) + + def test_formatter_processes_script_tag_for_xml_documents(self): + doc = """ + <script type="text/javascript"> + </script> +""" + soup = BeautifulSoup(doc, "lxml-xml") + # lxml would have stripped this while parsing, but we can add + # it later. 
+ soup.script.string = 'console.log("< < hey > > ");' + encoded = soup.encode() + self.assertTrue(b"< < hey > >" in encoded) + + def test_can_parse_unicode_document(self): + markup = '<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>' + soup = self.soup(markup) + self.assertEqual('Sacr\xe9 bleu!', soup.root.string) + + def test_popping_namespaced_tag(self): + markup = '<rss xmlns:dc="foo"><dc:creator>b</dc:creator><dc:date>2012-07-02T20:33:42Z</dc:date><dc:rights>c</dc:rights><image>d</image></rss>' + soup = self.soup(markup) + self.assertEqual( + str(soup.rss), markup) + + def test_docstring_includes_correct_encoding(self): + soup = self.soup("<root/>") + self.assertEqual( + soup.encode("latin1"), + b'<?xml version="1.0" encoding="latin1"?>\n<root/>') + + def test_large_xml_document(self): + """A large XML document should come out the same as it went in.""" + markup = (b'<?xml version="1.0" encoding="utf-8"?>\n<root>' + + b'0' * (2**12) + + b'</root>') + soup = self.soup(markup) + self.assertEqual(soup.encode("utf-8"), markup) + + + def test_tags_are_empty_element_if_and_only_if_they_are_empty(self): + self.assertSoupEquals("<p>", "<p/>") + self.assertSoupEquals("<p>foo</p>") + + def test_namespaces_are_preserved(self): + markup = '<root xmlns:a="http://example.com/" xmlns:b="http://example.net/"><a:foo>This tag is in the a namespace</a:foo><b:foo>This tag is in the b namespace</b:foo></root>' + soup = self.soup(markup) + root = soup.root + self.assertEqual("http://example.com/", root['xmlns:a']) + self.assertEqual("http://example.net/", root['xmlns:b']) + + def test_closing_namespaced_tag(self): + markup = '<p xmlns:dc="http://purl.org/dc/elements/1.1/"><dc:date>20010504</dc:date></p>' + soup = self.soup(markup) + self.assertEqual(str(soup.p), markup) + + def test_namespaced_attributes(self): + markup = '<foo xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"><bar 
xsi:schemaLocation="http://www.example.com"/></foo>' + soup = self.soup(markup) + self.assertEqual(str(soup.foo), markup) + + def test_namespaced_attributes_xml_namespace(self): + markup = '<foo xml:lang="fr">bar</foo>' + soup = self.soup(markup) + self.assertEqual(str(soup.foo), markup) + + def test_find_by_prefixed_name(self): + doc = """<?xml version="1.0" encoding="utf-8"?> +<Document xmlns="http://example.com/ns0" + xmlns:ns1="http://example.com/ns1" + xmlns:ns2="http://example.com/ns2" + <ns1:tag>foo</ns1:tag> + <ns1:tag>bar</ns1:tag> + <ns2:tag key="value">baz</ns2:tag> +</Document> +""" + soup = self.soup(doc) + + # There are three <tag> tags. + self.assertEqual(3, len(soup.find_all('tag'))) + + # But two of them are ns1:tag and one of them is ns2:tag. + self.assertEqual(2, len(soup.find_all('ns1:tag'))) + self.assertEqual(1, len(soup.find_all('ns2:tag'))) + + self.assertEqual(1, len(soup.find_all('ns2:tag', key='value'))) + self.assertEqual(3, len(soup.find_all(['ns1:tag', 'ns2:tag']))) + + def test_copy_tag_preserves_namespace(self): + xml = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?> +<w:document xmlns:w="http://example.com/ns0"/>""" + + soup = self.soup(xml) + tag = soup.document + duplicate = copy.copy(tag) + + # The two tags have the same namespace prefix. + self.assertEqual(tag.prefix, duplicate.prefix) + + def test_worst_case(self): + """Test the worst case (currently) for linking issues.""" + + soup = self.soup(BAD_DOCUMENT) + self.linkage_validator(soup) + + +class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest): + """Smoke test for a tree builder that supports HTML5.""" + + def test_real_xhtml_document(self): + # Since XHTML is not HTML5, HTML5 parsers are not tested to handle + # XHTML documents in any particular way. 
+ pass + + def test_html_tags_have_namespace(self): + markup = "<a>" + soup = self.soup(markup) + self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace) + + def test_svg_tags_have_namespace(self): + markup = '<svg><circle/></svg>' + soup = self.soup(markup) + namespace = "http://www.w3.org/2000/svg" + self.assertEqual(namespace, soup.svg.namespace) + self.assertEqual(namespace, soup.circle.namespace) + + + def test_mathml_tags_have_namespace(self): + markup = '<math><msqrt>5</msqrt></math>' + soup = self.soup(markup) + namespace = 'http://www.w3.org/1998/Math/MathML' + self.assertEqual(namespace, soup.math.namespace) + self.assertEqual(namespace, soup.msqrt.namespace) + + def test_xml_declaration_becomes_comment(self): + markup = '<?xml version="1.0" encoding="utf-8"?><html></html>' + soup = self.soup(markup) + self.assertTrue(isinstance(soup.contents[0], Comment)) + self.assertEqual(soup.contents[0], '?xml version="1.0" encoding="utf-8"?') + self.assertEqual("html", soup.contents[0].next_element.name) + +def skipIf(condition, reason): + def nothing(test, *args, **kwargs): + return None + + def decorator(test_item): + if condition: + return nothing + else: + return test_item + + return decorator diff --git a/libs/bs4/tests/__init__.py b/libs/bs4/tests/__init__.py new file mode 100644 index 000000000..142c8cc3f --- /dev/null +++ b/libs/bs4/tests/__init__.py @@ -0,0 +1 @@ +"The beautifulsoup tests." 
diff --git a/libs/bs4/tests/test_builder_registry.py b/libs/bs4/tests/test_builder_registry.py new file mode 100644 index 000000000..90cad8293 --- /dev/null +++ b/libs/bs4/tests/test_builder_registry.py @@ -0,0 +1,147 @@ +"""Tests of the builder registry.""" + +import unittest +import warnings + +from bs4 import BeautifulSoup +from bs4.builder import ( + builder_registry as registry, + HTMLParserTreeBuilder, + TreeBuilderRegistry, +) + +try: + from bs4.builder import HTML5TreeBuilder + HTML5LIB_PRESENT = True +except ImportError: + HTML5LIB_PRESENT = False + +try: + from bs4.builder import ( + LXMLTreeBuilderForXML, + LXMLTreeBuilder, + ) + LXML_PRESENT = True +except ImportError: + LXML_PRESENT = False + + +class BuiltInRegistryTest(unittest.TestCase): + """Test the built-in registry with the default builders registered.""" + + def test_combination(self): + if LXML_PRESENT: + self.assertEqual(registry.lookup('fast', 'html'), + LXMLTreeBuilder) + + if LXML_PRESENT: + self.assertEqual(registry.lookup('permissive', 'xml'), + LXMLTreeBuilderForXML) + self.assertEqual(registry.lookup('strict', 'html'), + HTMLParserTreeBuilder) + if HTML5LIB_PRESENT: + self.assertEqual(registry.lookup('html5lib', 'html'), + HTML5TreeBuilder) + + def test_lookup_by_markup_type(self): + if LXML_PRESENT: + self.assertEqual(registry.lookup('html'), LXMLTreeBuilder) + self.assertEqual(registry.lookup('xml'), LXMLTreeBuilderForXML) + else: + self.assertEqual(registry.lookup('xml'), None) + if HTML5LIB_PRESENT: + self.assertEqual(registry.lookup('html'), HTML5TreeBuilder) + else: + self.assertEqual(registry.lookup('html'), HTMLParserTreeBuilder) + + def test_named_library(self): + if LXML_PRESENT: + self.assertEqual(registry.lookup('lxml', 'xml'), + LXMLTreeBuilderForXML) + self.assertEqual(registry.lookup('lxml', 'html'), + LXMLTreeBuilder) + if HTML5LIB_PRESENT: + self.assertEqual(registry.lookup('html5lib'), + HTML5TreeBuilder) + + self.assertEqual(registry.lookup('html.parser'), + 
HTMLParserTreeBuilder) + + def test_beautifulsoup_constructor_does_lookup(self): + + with warnings.catch_warnings(record=True) as w: + # This will create a warning about not explicitly + # specifying a parser, but we'll ignore it. + + # You can pass in a string. + BeautifulSoup("", features="html") + # Or a list of strings. + BeautifulSoup("", features=["html", "fast"]) + + # You'll get an exception if BS can't find an appropriate + # builder. + self.assertRaises(ValueError, BeautifulSoup, + "", features="no-such-feature") + +class RegistryTest(unittest.TestCase): + """Test the TreeBuilderRegistry class in general.""" + + def setUp(self): + self.registry = TreeBuilderRegistry() + + def builder_for_features(self, *feature_list): + cls = type('Builder_' + '_'.join(feature_list), + (object,), {'features' : feature_list}) + + self.registry.register(cls) + return cls + + def test_register_with_no_features(self): + builder = self.builder_for_features() + + # Since the builder advertises no features, you can't find it + # by looking up features. + self.assertEqual(self.registry.lookup('foo'), None) + + # But you can find it by doing a lookup with no features, if + # this happens to be the only registered builder. 
+ self.assertEqual(self.registry.lookup(), builder) + + def test_register_with_features_makes_lookup_succeed(self): + builder = self.builder_for_features('foo', 'bar') + self.assertEqual(self.registry.lookup('foo'), builder) + self.assertEqual(self.registry.lookup('bar'), builder) + + def test_lookup_fails_when_no_builder_implements_feature(self): + builder = self.builder_for_features('foo', 'bar') + self.assertEqual(self.registry.lookup('baz'), None) + + def test_lookup_gets_most_recent_registration_when_no_feature_specified(self): + builder1 = self.builder_for_features('foo') + builder2 = self.builder_for_features('bar') + self.assertEqual(self.registry.lookup(), builder2) + + def test_lookup_fails_when_no_tree_builders_registered(self): + self.assertEqual(self.registry.lookup(), None) + + def test_lookup_gets_most_recent_builder_supporting_all_features(self): + has_one = self.builder_for_features('foo') + has_the_other = self.builder_for_features('bar') + has_both_early = self.builder_for_features('foo', 'bar', 'baz') + has_both_late = self.builder_for_features('foo', 'bar', 'quux') + lacks_one = self.builder_for_features('bar') + has_the_other = self.builder_for_features('foo') + + # There are two builders featuring 'foo' and 'bar', but + # the one that also features 'quux' was registered later. + self.assertEqual(self.registry.lookup('foo', 'bar'), + has_both_late) + + # There is only one builder featuring 'foo', 'bar', and 'baz'. + self.assertEqual(self.registry.lookup('foo', 'bar', 'baz'), + has_both_early) + + def test_lookup_fails_when_cannot_reconcile_requested_features(self): + builder1 = self.builder_for_features('foo', 'bar') + builder2 = self.builder_for_features('foo', 'baz') + self.assertEqual(self.registry.lookup('bar', 'baz'), None) diff --git a/libs/bs4/tests/test_docs.py b/libs/bs4/tests/test_docs.py new file mode 100644 index 000000000..5b9f67709 --- /dev/null +++ b/libs/bs4/tests/test_docs.py @@ -0,0 +1,36 @@ +"Test harness for doctests." 
+ +# pylint: disable-msg=E0611,W0142 + +__metaclass__ = type +__all__ = [ + 'additional_tests', + ] + +import atexit +import doctest +import os +#from pkg_resources import ( +# resource_filename, resource_exists, resource_listdir, cleanup_resources) +import unittest + +DOCTEST_FLAGS = ( + doctest.ELLIPSIS | + doctest.NORMALIZE_WHITESPACE | + doctest.REPORT_NDIFF) + + +# def additional_tests(): +# "Run the doc tests (README.txt and docs/*, if any exist)" +# doctest_files = [ +# os.path.abspath(resource_filename('bs4', 'README.txt'))] +# if resource_exists('bs4', 'docs'): +# for name in resource_listdir('bs4', 'docs'): +# if name.endswith('.txt'): +# doctest_files.append( +# os.path.abspath( +# resource_filename('bs4', 'docs/%s' % name))) +# kwargs = dict(module_relative=False, optionflags=DOCTEST_FLAGS) +# atexit.register(cleanup_resources) +# return unittest.TestSuite(( +# doctest.DocFileSuite(*doctest_files, **kwargs))) diff --git a/libs/bs4/tests/test_html5lib.py b/libs/bs4/tests/test_html5lib.py new file mode 100644 index 000000000..96529b0b3 --- /dev/null +++ b/libs/bs4/tests/test_html5lib.py @@ -0,0 +1,170 @@ +"""Tests to ensure that the html5lib tree builder generates good trees.""" + +import warnings + +try: + from bs4.builder import HTML5TreeBuilder + HTML5LIB_PRESENT = True +except ImportError as e: + HTML5LIB_PRESENT = False +from bs4.element import SoupStrainer +from bs4.testing import ( + HTML5TreeBuilderSmokeTest, + SoupTest, + skipIf, +) + +@skipIf( + not HTML5LIB_PRESENT, + "html5lib seems not to be present, not testing its tree builder.") +class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest): + """See ``HTML5TreeBuilderSmokeTest``.""" + + @property + def default_builder(self): + return HTML5TreeBuilder + + def test_soupstrainer(self): + # The html5lib tree builder does not support SoupStrainers. 
+ strainer = SoupStrainer("b") + markup = "<p>A <b>bold</b> statement.</p>" + with warnings.catch_warnings(record=True) as w: + soup = self.soup(markup, parse_only=strainer) + self.assertEqual( + soup.decode(), self.document_for(markup)) + + self.assertTrue( + "the html5lib tree builder doesn't support parse_only" in + str(w[0].message)) + + def test_correctly_nested_tables(self): + """html5lib inserts <tbody> tags where other parsers don't.""" + markup = ('<table id="1">' + '<tr>' + "<td>Here's another table:" + '<table id="2">' + '<tr><td>foo</td></tr>' + '</table></td>') + + self.assertSoupEquals( + markup, + '<table id="1"><tbody><tr><td>Here\'s another table:' + '<table id="2"><tbody><tr><td>foo</td></tr></tbody></table>' + '</td></tr></tbody></table>') + + self.assertSoupEquals( + "<table><thead><tr><td>Foo</td></tr></thead>" + "<tbody><tr><td>Bar</td></tr></tbody>" + "<tfoot><tr><td>Baz</td></tr></tfoot></table>") + + def test_xml_declaration_followed_by_doctype(self): + markup = '''<?xml version="1.0" encoding="utf-8"?> +<!DOCTYPE html> +<html> + <head> + </head> + <body> + <p>foo</p> + </body> +</html>''' + soup = self.soup(markup) + # Verify that we can reach the <p> tag; this means the tree is connected. 
+ self.assertEqual(b"<p>foo</p>", soup.p.encode()) + + def test_reparented_markup(self): + markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>' + soup = self.soup(markup) + self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p></body>", soup.body.decode()) + self.assertEqual(2, len(soup.find_all('p'))) + + + def test_reparented_markup_ends_with_whitespace(self): + markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>\n' + soup = self.soup(markup) + self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p>\n</body>", soup.body.decode()) + self.assertEqual(2, len(soup.find_all('p'))) + + def test_reparented_markup_containing_identical_whitespace_nodes(self): + """Verify that we keep the two whitespace nodes in this + document distinct when reparenting the adjacent <tbody> tags. + """ + markup = '<table> <tbody><tbody><ims></tbody> </table>' + soup = self.soup(markup) + space1, space2 = soup.find_all(string=' ') + tbody1, tbody2 = soup.find_all('tbody') + assert space1.next_element is tbody1 + assert tbody2.next_element is space2 + + def test_reparented_markup_containing_children(self): + markup = '<div><a>aftermath<p><noscript>target</noscript>aftermath</a></p></div>' + soup = self.soup(markup) + noscript = soup.noscript + self.assertEqual("target", noscript.next_element) + target = soup.find(string='target') + + # The 'aftermath' string was duplicated; we want the second one. + final_aftermath = soup.find_all(string='aftermath')[-1] + + # The <noscript> tag was moved beneath a copy of the <a> tag, + # but the 'target' string within is still connected to the + # (second) 'aftermath' string. 
+ self.assertEqual(final_aftermath, target.next_element) + self.assertEqual(target, final_aftermath.previous_element) + + def test_processing_instruction(self): + """Processing instructions become comments.""" + markup = b"""<?PITarget PIContent?>""" + soup = self.soup(markup) + assert str(soup).startswith("<!--?PITarget PIContent?-->") + + def test_cloned_multivalue_node(self): + markup = b"""<a class="my_class"><p></a>""" + soup = self.soup(markup) + a1, a2 = soup.find_all('a') + self.assertEqual(a1, a2) + assert a1 is not a2 + + def test_foster_parenting(self): + markup = b"""<table><td></tbody>A""" + soup = self.soup(markup) + self.assertEqual("<body>A<table><tbody><tr><td></td></tr></tbody></table></body>", soup.body.decode()) + + def test_extraction(self): + """ + Test that extraction does not destroy the tree. + + https://bugs.launchpad.net/beautifulsoup/+bug/1782928 + """ + + markup = """ +<html><head></head> +<style> +</style><script></script><body><p>hello</p></body></html> +""" + soup = self.soup(markup) + [s.extract() for s in soup('script')] + [s.extract() for s in soup('style')] + + self.assertEqual(len(soup.find_all("p")), 1) + + def test_empty_comment(self): + """ + Test that empty comment does not break structure. 
+ + https://bugs.launchpad.net/beautifulsoup/+bug/1806598 + """ + + markup = """ +<html> +<body> +<form> +<!----><input type="text"> +</form> +</body> +</html> +""" + soup = self.soup(markup) + inputs = [] + for form in soup.find_all('form'): + inputs.extend(form.find_all('input')) + self.assertEqual(len(inputs), 1) diff --git a/libs/bs4/tests/test_htmlparser.py b/libs/bs4/tests/test_htmlparser.py new file mode 100644 index 000000000..790489aa1 --- /dev/null +++ b/libs/bs4/tests/test_htmlparser.py @@ -0,0 +1,47 @@ +"""Tests to ensure that the html.parser tree builder generates good +trees.""" + +from pdb import set_trace +import pickle +from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest +from bs4.builder import HTMLParserTreeBuilder +from bs4.builder._htmlparser import BeautifulSoupHTMLParser + +class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest): + + default_builder = HTMLParserTreeBuilder + + def test_namespaced_system_doctype(self): + # html.parser can't handle namespaced doctypes, so skip this one. + pass + + def test_namespaced_public_doctype(self): + # html.parser can't handle namespaced doctypes, so skip this one. + pass + + def test_builder_is_pickled(self): + """Unlike most tree builders, HTMLParserTreeBuilder and will + be restored after pickling. + """ + tree = self.soup("<a><b>foo</a>") + dumped = pickle.dumps(tree, 2) + loaded = pickle.loads(dumped) + self.assertTrue(isinstance(loaded.builder, type(tree.builder))) + + def test_redundant_empty_element_closing_tags(self): + self.assertSoupEquals('<br></br><br></br><br></br>', "<br/><br/><br/>") + self.assertSoupEquals('</br></br></br>', "") + + def test_empty_element(self): + # This verifies that any buffered data present when the parser + # finishes working is handled. 
+ self.assertSoupEquals("foo &# bar", "foo &# bar") + + +class TestHTMLParserSubclass(SoupTest): + def test_error(self): + """Verify that our HTMLParser subclass implements error() in a way + that doesn't cause a crash. + """ + parser = BeautifulSoupHTMLParser() + parser.error("don't crash") diff --git a/libs/bs4/tests/test_lxml.py b/libs/bs4/tests/test_lxml.py new file mode 100644 index 000000000..29da71149 --- /dev/null +++ b/libs/bs4/tests/test_lxml.py @@ -0,0 +1,100 @@ +"""Tests to ensure that the lxml tree builder generates good trees.""" + +import re +import warnings + +try: + import lxml.etree + LXML_PRESENT = True + LXML_VERSION = lxml.etree.LXML_VERSION +except ImportError as e: + LXML_PRESENT = False + LXML_VERSION = (0,) + +if LXML_PRESENT: + from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML + +from bs4 import ( + BeautifulSoup, + BeautifulStoneSoup, + ) +from bs4.element import Comment, Doctype, SoupStrainer +from bs4.testing import skipIf +from bs4.tests import test_htmlparser +from bs4.testing import ( + HTMLTreeBuilderSmokeTest, + XMLTreeBuilderSmokeTest, + SoupTest, + skipIf, +) + +@skipIf( + not LXML_PRESENT, + "lxml seems not to be present, not testing its tree builder.") +class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest): + """See ``HTMLTreeBuilderSmokeTest``.""" + + @property + def default_builder(self): + return LXMLTreeBuilder + + def test_out_of_range_entity(self): + self.assertSoupEquals( + "<p>foo�bar</p>", "<p>foobar</p>") + self.assertSoupEquals( + "<p>foo�bar</p>", "<p>foobar</p>") + self.assertSoupEquals( + "<p>foo�bar</p>", "<p>foobar</p>") + + def test_entities_in_foreign_document_encoding(self): + # We can't implement this case correctly because by the time we + # hear about markup like "“", it's been (incorrectly) converted into + # a string like u'\x93' + pass + + # In lxml < 2.3.5, an empty doctype causes a segfault. Skip this + # test if an old version of lxml is installed. 
+ + @skipIf( + not LXML_PRESENT or LXML_VERSION < (2,3,5,0), + "Skipping doctype test for old version of lxml to avoid segfault.") + def test_empty_doctype(self): + soup = self.soup("<!DOCTYPE>") + doctype = soup.contents[0] + self.assertEqual("", doctype.strip()) + + def test_beautifulstonesoup_is_xml_parser(self): + # Make sure that the deprecated BSS class uses an xml builder + # if one is installed. + with warnings.catch_warnings(record=True) as w: + soup = BeautifulStoneSoup("<b />") + self.assertEqual("<b/>", str(soup.b)) + self.assertTrue("BeautifulStoneSoup class is deprecated" in str(w[0].message)) + +@skipIf( + not LXML_PRESENT, + "lxml seems not to be present, not testing its XML tree builder.") +class LXMLXMLTreeBuilderSmokeTest(SoupTest, XMLTreeBuilderSmokeTest): + """See ``HTMLTreeBuilderSmokeTest``.""" + + @property + def default_builder(self): + return LXMLTreeBuilderForXML + + def test_namespace_indexing(self): + # We should not track un-prefixed namespaces as we can only hold one + # and it will be recognized as the default namespace by soupsieve, + # which may be confusing in some situations. When no namespace is provided + # for a selector, the default namespace (if defined) is assumed. 
+ + soup = self.soup( + '<?xml version="1.1"?>\n' + '<root>' + '<tag xmlns="http://unprefixed-namespace.com">content</tag>' + '<prefix:tag xmlns:prefix="http://prefixed-namespace.com">content</tag>' + '</root>' + ) + self.assertEqual( + soup._namespaces, + {'xml': 'http://www.w3.org/XML/1998/namespace', 'prefix': 'http://prefixed-namespace.com'} + ) diff --git a/libs/bs4/tests/test_soup.py b/libs/bs4/tests/test_soup.py new file mode 100644 index 000000000..1eda9484b --- /dev/null +++ b/libs/bs4/tests/test_soup.py @@ -0,0 +1,567 @@ +# -*- coding: utf-8 -*- +"""Tests of Beautiful Soup as a whole.""" + +from pdb import set_trace +import logging +import unittest +import sys +import tempfile + +from bs4 import ( + BeautifulSoup, + BeautifulStoneSoup, +) +from bs4.element import ( + CharsetMetaAttributeValue, + ContentMetaAttributeValue, + SoupStrainer, + NamespacedAttribute, + ) +import bs4.dammit +from bs4.dammit import ( + EntitySubstitution, + UnicodeDammit, + EncodingDetector, +) +from bs4.testing import ( + default_builder, + SoupTest, + skipIf, +) +import warnings + +try: + from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML + LXML_PRESENT = True +except ImportError as e: + LXML_PRESENT = False + +PYTHON_3_PRE_3_2 = (sys.version_info[0] == 3 and sys.version_info < (3,2)) + +class TestConstructor(SoupTest): + + def test_short_unicode_input(self): + data = "<h1>éé</h1>" + soup = self.soup(data) + self.assertEqual("éé", soup.h1.string) + + def test_embedded_null(self): + data = "<h1>foo\0bar</h1>" + soup = self.soup(data) + self.assertEqual("foo\0bar", soup.h1.string) + + def test_exclude_encodings(self): + utf8_data = "Räksmörgås".encode("utf-8") + soup = self.soup(utf8_data, exclude_encodings=["utf-8"]) + self.assertEqual("windows-1252", soup.original_encoding) + + def test_custom_builder_class(self): + # Verify that you can pass in a custom Builder class and + # it'll be instantiated with the appropriate keyword arguments. 
+ class Mock(object): + def __init__(self, **kwargs): + self.called_with = kwargs + self.is_xml = True + def initialize_soup(self, soup): + pass + def prepare_markup(self, *args, **kwargs): + return '' + + kwargs = dict( + var="value", + # This is a deprecated BS3-era keyword argument, which + # will be stripped out. + convertEntities=True, + ) + with warnings.catch_warnings(record=True): + soup = BeautifulSoup('', builder=Mock, **kwargs) + assert isinstance(soup.builder, Mock) + self.assertEqual(dict(var="value"), soup.builder.called_with) + + # You can also instantiate the TreeBuilder yourself. In this + # case, that specific object is used and any keyword arguments + # to the BeautifulSoup constructor are ignored. + builder = Mock(**kwargs) + with warnings.catch_warnings(record=True) as w: + soup = BeautifulSoup( + '', builder=builder, ignored_value=True, + ) + msg = str(w[0].message) + assert msg.startswith("Keyword arguments to the BeautifulSoup constructor will be ignored.") + self.assertEqual(builder, soup.builder) + self.assertEqual(kwargs, builder.called_with) + + def test_cdata_list_attributes(self): + # Most attribute values are represented as scalars, but the + # HTML standard says that some attributes, like 'class' have + # space-separated lists as values. + markup = '<a id=" an id " class=" a class "></a>' + soup = self.soup(markup) + + # Note that the spaces are stripped for 'class' but not for 'id'. + a = soup.a + self.assertEqual(" an id ", a['id']) + self.assertEqual(["a", "class"], a['class']) + + # TreeBuilder takes an argument called 'multi_valued_attributes' which lets + # you customize or disable this. As always, you can customize the TreeBuilder + # by passing in a keyword argument to the BeautifulSoup constructor. 
+ soup = self.soup(markup, builder=default_builder, multi_valued_attributes=None) + self.assertEqual(" a class ", soup.a['class']) + + # Here are two ways of saying that `id` is a multi-valued + # attribute in this context, but 'class' is not. + for switcheroo in ({'*': 'id'}, {'a': 'id'}): + with warnings.catch_warnings(record=True) as w: + # This will create a warning about not explicitly + # specifying a parser, but we'll ignore it. + soup = self.soup(markup, builder=None, multi_valued_attributes=switcheroo) + a = soup.a + self.assertEqual(["an", "id"], a['id']) + self.assertEqual(" a class ", a['class']) + + +class TestWarnings(SoupTest): + + def _no_parser_specified(self, s, is_there=True): + v = s.startswith(BeautifulSoup.NO_PARSER_SPECIFIED_WARNING[:80]) + self.assertTrue(v) + + def test_warning_if_no_parser_specified(self): + with warnings.catch_warnings(record=True) as w: + soup = self.soup("<a><b></b></a>") + msg = str(w[0].message) + self._assert_no_parser_specified(msg) + + def test_warning_if_parser_specified_too_vague(self): + with warnings.catch_warnings(record=True) as w: + soup = self.soup("<a><b></b></a>", "html") + msg = str(w[0].message) + self._assert_no_parser_specified(msg) + + def test_no_warning_if_explicit_parser_specified(self): + with warnings.catch_warnings(record=True) as w: + soup = self.soup("<a><b></b></a>", "html.parser") + self.assertEqual([], w) + + def test_parseOnlyThese_renamed_to_parse_only(self): + with warnings.catch_warnings(record=True) as w: + soup = self.soup("<a><b></b></a>", parseOnlyThese=SoupStrainer("b")) + msg = str(w[0].message) + self.assertTrue("parseOnlyThese" in msg) + self.assertTrue("parse_only" in msg) + self.assertEqual(b"<b></b>", soup.encode()) + + def test_fromEncoding_renamed_to_from_encoding(self): + with warnings.catch_warnings(record=True) as w: + utf8 = b"\xc3\xa9" + soup = self.soup(utf8, fromEncoding="utf8") + msg = str(w[0].message) + self.assertTrue("fromEncoding" in msg) + 
self.assertTrue("from_encoding" in msg) + self.assertEqual("utf8", soup.original_encoding) + + def test_unrecognized_keyword_argument(self): + self.assertRaises( + TypeError, self.soup, "<a>", no_such_argument=True) + +class TestWarnings(SoupTest): + + def test_disk_file_warning(self): + filehandle = tempfile.NamedTemporaryFile() + filename = filehandle.name + try: + with warnings.catch_warnings(record=True) as w: + soup = self.soup(filename) + msg = str(w[0].message) + self.assertTrue("looks like a filename" in msg) + finally: + filehandle.close() + + # The file no longer exists, so Beautiful Soup will no longer issue the warning. + with warnings.catch_warnings(record=True) as w: + soup = self.soup(filename) + self.assertEqual(0, len(w)) + + def test_url_warning_with_bytes_url(self): + with warnings.catch_warnings(record=True) as warning_list: + soup = self.soup(b"http://www.crummybytes.com/") + # Be aware this isn't the only warning that can be raised during + # execution.. + self.assertTrue(any("looks like a URL" in str(w.message) + for w in warning_list)) + + def test_url_warning_with_unicode_url(self): + with warnings.catch_warnings(record=True) as warning_list: + # note - this url must differ from the bytes one otherwise + # python's warnings system swallows the second warning + soup = self.soup("http://www.crummyunicode.com/") + self.assertTrue(any("looks like a URL" in str(w.message) + for w in warning_list)) + + def test_url_warning_with_bytes_and_space(self): + with warnings.catch_warnings(record=True) as warning_list: + soup = self.soup(b"http://www.crummybytes.com/ is great") + self.assertFalse(any("looks like a URL" in str(w.message) + for w in warning_list)) + + def test_url_warning_with_unicode_and_space(self): + with warnings.catch_warnings(record=True) as warning_list: + soup = self.soup("http://www.crummyuncode.com/ is great") + self.assertFalse(any("looks like a URL" in str(w.message) + for w in warning_list)) + + +class 
TestSelectiveParsing(SoupTest): + + def test_parse_with_soupstrainer(self): + markup = "No<b>Yes</b><a>No<b>Yes <c>Yes</c></b>" + strainer = SoupStrainer("b") + soup = self.soup(markup, parse_only=strainer) + self.assertEqual(soup.encode(), b"<b>Yes</b><b>Yes <c>Yes</c></b>") + + +class TestEntitySubstitution(unittest.TestCase): + """Standalone tests of the EntitySubstitution class.""" + def setUp(self): + self.sub = EntitySubstitution + + def test_simple_html_substitution(self): + # Unicode characters corresponding to named HTML entites + # are substituted, and no others. + s = "foo\u2200\N{SNOWMAN}\u00f5bar" + self.assertEqual(self.sub.substitute_html(s), + "foo∀\N{SNOWMAN}õbar") + + def test_smart_quote_substitution(self): + # MS smart quotes are a common source of frustration, so we + # give them a special test. + quotes = b"\x91\x92foo\x93\x94" + dammit = UnicodeDammit(quotes) + self.assertEqual(self.sub.substitute_html(dammit.markup), + "‘’foo“”") + + def test_xml_converstion_includes_no_quotes_if_make_quoted_attribute_is_false(self): + s = 'Welcome to "my bar"' + self.assertEqual(self.sub.substitute_xml(s, False), s) + + def test_xml_attribute_quoting_normally_uses_double_quotes(self): + self.assertEqual(self.sub.substitute_xml("Welcome", True), + '"Welcome"') + self.assertEqual(self.sub.substitute_xml("Bob's Bar", True), + '"Bob\'s Bar"') + + def test_xml_attribute_quoting_uses_single_quotes_when_value_contains_double_quotes(self): + s = 'Welcome to "my bar"' + self.assertEqual(self.sub.substitute_xml(s, True), + "'Welcome to \"my bar\"'") + + def test_xml_attribute_quoting_escapes_single_quotes_when_value_contains_both_single_and_double_quotes(self): + s = 'Welcome to "Bob\'s Bar"' + self.assertEqual( + self.sub.substitute_xml(s, True), + '"Welcome to "Bob\'s Bar""') + + def test_xml_quotes_arent_escaped_when_value_is_not_being_quoted(self): + quoted = 'Welcome to "Bob\'s Bar"' + self.assertEqual(self.sub.substitute_xml(quoted), quoted) + + def 
test_xml_quoting_handles_angle_brackets(self): + self.assertEqual( + self.sub.substitute_xml("foo<bar>"), + "foo<bar>") + + def test_xml_quoting_handles_ampersands(self): + self.assertEqual(self.sub.substitute_xml("AT&T"), "AT&T") + + def test_xml_quoting_including_ampersands_when_they_are_part_of_an_entity(self): + self.assertEqual( + self.sub.substitute_xml("ÁT&T"), + "&Aacute;T&T") + + def test_xml_quoting_ignoring_ampersands_when_they_are_part_of_an_entity(self): + self.assertEqual( + self.sub.substitute_xml_containing_entities("ÁT&T"), + "ÁT&T") + + def test_quotes_not_html_substituted(self): + """There's no need to do this except inside attribute values.""" + text = 'Bob\'s "bar"' + self.assertEqual(self.sub.substitute_html(text), text) + + +class TestEncodingConversion(SoupTest): + # Test Beautiful Soup's ability to decode and encode from various + # encodings. + + def setUp(self): + super(TestEncodingConversion, self).setUp() + self.unicode_data = '<html><head><meta charset="utf-8"/></head><body><foo>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</foo></body></html>' + self.utf8_data = self.unicode_data.encode("utf-8") + # Just so you know what it looks like. + self.assertEqual( + self.utf8_data, + b'<html><head><meta charset="utf-8"/></head><body><foo>Sacr\xc3\xa9 bleu!</foo></body></html>') + + def test_ascii_in_unicode_out(self): + # ASCII input is converted to Unicode. The original_encoding + # attribute is set to 'utf-8', a superset of ASCII. + chardet = bs4.dammit.chardet_dammit + logging.disable(logging.WARNING) + try: + def noop(str): + return None + # Disable chardet, which will realize that the ASCII is ASCII. 
+ bs4.dammit.chardet_dammit = noop + ascii = b"<foo>a</foo>" + soup_from_ascii = self.soup(ascii) + unicode_output = soup_from_ascii.decode() + self.assertTrue(isinstance(unicode_output, str)) + self.assertEqual(unicode_output, self.document_for(ascii.decode())) + self.assertEqual(soup_from_ascii.original_encoding.lower(), "utf-8") + finally: + logging.disable(logging.NOTSET) + bs4.dammit.chardet_dammit = chardet + + def test_unicode_in_unicode_out(self): + # Unicode input is left alone. The original_encoding attribute + # is not set. + soup_from_unicode = self.soup(self.unicode_data) + self.assertEqual(soup_from_unicode.decode(), self.unicode_data) + self.assertEqual(soup_from_unicode.foo.string, 'Sacr\xe9 bleu!') + self.assertEqual(soup_from_unicode.original_encoding, None) + + def test_utf8_in_unicode_out(self): + # UTF-8 input is converted to Unicode. The original_encoding + # attribute is set. + soup_from_utf8 = self.soup(self.utf8_data) + self.assertEqual(soup_from_utf8.decode(), self.unicode_data) + self.assertEqual(soup_from_utf8.foo.string, 'Sacr\xe9 bleu!') + + def test_utf8_out(self): + # The internal data structures can be encoded as UTF-8. + soup_from_unicode = self.soup(self.unicode_data) + self.assertEqual(soup_from_unicode.encode('utf-8'), self.utf8_data) + + @skipIf( + PYTHON_3_PRE_3_2, + "Bad HTMLParser detected; skipping test of non-ASCII characters in attribute name.") + def test_attribute_name_containing_unicode_characters(self): + markup = '<div><a \N{SNOWMAN}="snowman"></a></div>' + self.assertEqual(self.soup(markup).div.encode("utf8"), markup.encode("utf8")) + +class TestUnicodeDammit(unittest.TestCase): + """Standalone tests of UnicodeDammit.""" + + def test_unicode_input(self): + markup = "I'm already Unicode! 
\N{SNOWMAN}" + dammit = UnicodeDammit(markup) + self.assertEqual(dammit.unicode_markup, markup) + + def test_smart_quotes_to_unicode(self): + markup = b"<foo>\x91\x92\x93\x94</foo>" + dammit = UnicodeDammit(markup) + self.assertEqual( + dammit.unicode_markup, "<foo>\u2018\u2019\u201c\u201d</foo>") + + def test_smart_quotes_to_xml_entities(self): + markup = b"<foo>\x91\x92\x93\x94</foo>" + dammit = UnicodeDammit(markup, smart_quotes_to="xml") + self.assertEqual( + dammit.unicode_markup, "<foo>‘’“”</foo>") + + def test_smart_quotes_to_html_entities(self): + markup = b"<foo>\x91\x92\x93\x94</foo>" + dammit = UnicodeDammit(markup, smart_quotes_to="html") + self.assertEqual( + dammit.unicode_markup, "<foo>‘’“”</foo>") + + def test_smart_quotes_to_ascii(self): + markup = b"<foo>\x91\x92\x93\x94</foo>" + dammit = UnicodeDammit(markup, smart_quotes_to="ascii") + self.assertEqual( + dammit.unicode_markup, """<foo>''""</foo>""") + + def test_detect_utf8(self): + utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83" + dammit = UnicodeDammit(utf8) + self.assertEqual(dammit.original_encoding.lower(), 'utf-8') + self.assertEqual(dammit.unicode_markup, 'Sacr\xe9 bleu! 
\N{SNOWMAN}') + + + def test_convert_hebrew(self): + hebrew = b"\xed\xe5\xec\xf9" + dammit = UnicodeDammit(hebrew, ["iso-8859-8"]) + self.assertEqual(dammit.original_encoding.lower(), 'iso-8859-8') + self.assertEqual(dammit.unicode_markup, '\u05dd\u05d5\u05dc\u05e9') + + def test_dont_see_smart_quotes_where_there_are_none(self): + utf_8 = b"\343\202\261\343\203\274\343\202\277\343\202\244 Watch" + dammit = UnicodeDammit(utf_8) + self.assertEqual(dammit.original_encoding.lower(), 'utf-8') + self.assertEqual(dammit.unicode_markup.encode("utf-8"), utf_8) + + def test_ignore_inappropriate_codecs(self): + utf8_data = "Räksmörgås".encode("utf-8") + dammit = UnicodeDammit(utf8_data, ["iso-8859-8"]) + self.assertEqual(dammit.original_encoding.lower(), 'utf-8') + + def test_ignore_invalid_codecs(self): + utf8_data = "Räksmörgås".encode("utf-8") + for bad_encoding in ['.utf8', '...', 'utF---16.!']: + dammit = UnicodeDammit(utf8_data, [bad_encoding]) + self.assertEqual(dammit.original_encoding.lower(), 'utf-8') + + def test_exclude_encodings(self): + # This is UTF-8. + utf8_data = "Räksmörgås".encode("utf-8") + + # But if we exclude UTF-8 from consideration, the guess is + # Windows-1252. + dammit = UnicodeDammit(utf8_data, exclude_encodings=["utf-8"]) + self.assertEqual(dammit.original_encoding.lower(), 'windows-1252') + + # And if we exclude that, there is no valid guess at all. 
+ dammit = UnicodeDammit( + utf8_data, exclude_encodings=["utf-8", "windows-1252"]) + self.assertEqual(dammit.original_encoding, None) + + def test_encoding_detector_replaces_junk_in_encoding_name_with_replacement_character(self): + detected = EncodingDetector( + b'<?xml version="1.0" encoding="UTF-\xdb" ?>') + encodings = list(detected.encodings) + assert 'utf-\N{REPLACEMENT CHARACTER}' in encodings + + def test_detect_html5_style_meta_tag(self): + + for data in ( + b'<html><meta charset="euc-jp" /></html>', + b"<html><meta charset='euc-jp' /></html>", + b"<html><meta charset=euc-jp /></html>", + b"<html><meta charset=euc-jp/></html>"): + dammit = UnicodeDammit(data, is_html=True) + self.assertEqual( + "euc-jp", dammit.original_encoding) + + def test_last_ditch_entity_replacement(self): + # This is a UTF-8 document that contains bytestrings + # completely incompatible with UTF-8 (ie. encoded with some other + # encoding). + # + # Since there is no consistent encoding for the document, + # Unicode, Dammit will eventually encode the document as UTF-8 + # and encode the incompatible characters as REPLACEMENT + # CHARACTER. + # + # If chardet is installed, it will detect that the document + # can be converted into ISO-8859-1 without errors. This happens + # to be the wrong encoding, but it is a consistent encoding, so the + # code we're testing here won't run. + # + # So we temporarily disable chardet if it's present. 
+ doc = b"""\357\273\277<?xml version="1.0" encoding="UTF-8"?> +<html><b>\330\250\330\252\330\261</b> +<i>\310\322\321\220\312\321\355\344</i></html>""" + chardet = bs4.dammit.chardet_dammit + logging.disable(logging.WARNING) + try: + def noop(str): + return None + bs4.dammit.chardet_dammit = noop + dammit = UnicodeDammit(doc) + self.assertEqual(True, dammit.contains_replacement_characters) + self.assertTrue("\ufffd" in dammit.unicode_markup) + + soup = BeautifulSoup(doc, "html.parser") + self.assertTrue(soup.contains_replacement_characters) + finally: + logging.disable(logging.NOTSET) + bs4.dammit.chardet_dammit = chardet + + def test_byte_order_mark_removed(self): + # A document written in UTF-16LE will have its byte order marker stripped. + data = b'\xff\xfe<\x00a\x00>\x00\xe1\x00\xe9\x00<\x00/\x00a\x00>\x00' + dammit = UnicodeDammit(data) + self.assertEqual("<a>áé</a>", dammit.unicode_markup) + self.assertEqual("utf-16le", dammit.original_encoding) + + def test_detwingle(self): + # Here's a UTF8 document. + utf8 = ("\N{SNOWMAN}" * 3).encode("utf8") + + # Here's a Windows-1252 document. + windows_1252 = ( + "\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!" + "\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252") + + # Through some unholy alchemy, they've been stuck together. + doc = utf8 + windows_1252 + utf8 + + # The document can't be turned into UTF-8: + self.assertRaises(UnicodeDecodeError, doc.decode, "utf8") + + # Unicode, Dammit thinks the whole document is Windows-1252, + # and decodes it into "☃☃☃“Hi, I like Windows!”☃☃☃" + + # But if we run it through fix_embedded_windows_1252, it's fixed: + + fixed = UnicodeDammit.detwingle(doc) + self.assertEqual( + "☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8")) + + def test_detwingle_ignores_multibyte_characters(self): + # Each of these characters has a UTF-8 representation ending + # in \x93. \x93 is a smart quote if interpreted as + # Windows-1252. 
But our code knows to skip over multibyte + # UTF-8 characters, so they'll survive the process unscathed. + for tricky_unicode_char in ( + "\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93' + "\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93' + "\xf0\x90\x90\x93", # This is a CJK character, not sure which one. + ): + input = tricky_unicode_char.encode("utf8") + self.assertTrue(input.endswith(b'\x93')) + output = UnicodeDammit.detwingle(input) + self.assertEqual(output, input) + +class TestNamedspacedAttribute(SoupTest): + + def test_name_may_be_none(self): + a = NamespacedAttribute("xmlns", None) + self.assertEqual(a, "xmlns") + + def test_attribute_is_equivalent_to_colon_separated_string(self): + a = NamespacedAttribute("a", "b") + self.assertEqual("a:b", a) + + def test_attributes_are_equivalent_if_prefix_and_name_identical(self): + a = NamespacedAttribute("a", "b", "c") + b = NamespacedAttribute("a", "b", "c") + self.assertEqual(a, b) + + # The actual namespace is not considered. + c = NamespacedAttribute("a", "b", None) + self.assertEqual(a, c) + + # But name and prefix are important. 
+ d = NamespacedAttribute("a", "z", "c") + self.assertNotEqual(a, d) + + e = NamespacedAttribute("z", "b", "c") + self.assertNotEqual(a, e) + + +class TestAttributeValueWithCharsetSubstitution(unittest.TestCase): + + def test_content_meta_attribute_value(self): + value = CharsetMetaAttributeValue("euc-jp") + self.assertEqual("euc-jp", value) + self.assertEqual("euc-jp", value.original_value) + self.assertEqual("utf8", value.encode("utf8")) + + + def test_content_meta_attribute_value(self): + value = ContentMetaAttributeValue("text/html; charset=euc-jp") + self.assertEqual("text/html; charset=euc-jp", value) + self.assertEqual("text/html; charset=euc-jp", value.original_value) + self.assertEqual("text/html; charset=utf8", value.encode("utf8")) diff --git a/libs/bs4/tests/test_tree.py b/libs/bs4/tests/test_tree.py new file mode 100644 index 000000000..3b4beeb8f --- /dev/null +++ b/libs/bs4/tests/test_tree.py @@ -0,0 +1,2205 @@ +# -*- coding: utf-8 -*- +"""Tests for Beautiful Soup's tree traversal methods. + +The tree traversal methods are the main advantage of using Beautiful +Soup over just using a parser. + +Different parsers will build different Beautiful Soup trees given the +same markup, but all Beautiful Soup trees can be traversed with the +methods tested here. +""" + +from pdb import set_trace +import copy +import pickle +import re +import warnings +from bs4 import BeautifulSoup +from bs4.builder import ( + builder_registry, + HTMLParserTreeBuilder, +) +from bs4.element import ( + PY3K, + CData, + Comment, + Declaration, + Doctype, + Formatter, + NavigableString, + SoupStrainer, + Tag, +) +from bs4.testing import ( + SoupTest, + skipIf, +) + +XML_BUILDER_PRESENT = (builder_registry.lookup("xml") is not None) +LXML_PRESENT = (builder_registry.lookup("lxml") is not None) + +class TreeTest(SoupTest): + + def assertSelects(self, tags, should_match): + """Make sure that the given tags have the correct text. 
+ + This is used in tests that define a bunch of tags, each + containing a single string, and then select certain strings by + some mechanism. + """ + self.assertEqual([tag.string for tag in tags], should_match) + + def assertSelectsIDs(self, tags, should_match): + """Make sure that the given tags have the correct IDs. + + This is used in tests that define a bunch of tags, each + containing a single string, and then select certain strings by + some mechanism. + """ + self.assertEqual([tag['id'] for tag in tags], should_match) + + +class TestFind(TreeTest): + """Basic tests of the find() method. + + find() just calls find_all() with limit=1, so it's not tested all + that thoroughly here. + """ + + def test_find_tag(self): + soup = self.soup("<a>1</a><b>2</b><a>3</a><b>4</b>") + self.assertEqual(soup.find("b").string, "2") + + def test_unicode_text_find(self): + soup = self.soup('<h1>Räksmörgås</h1>') + self.assertEqual(soup.find(string='Räksmörgås'), 'Räksmörgås') + + def test_unicode_attribute_find(self): + soup = self.soup('<h1 id="Räksmörgås">here it is</h1>') + str(soup) + self.assertEqual("here it is", soup.find(id='Räksmörgås').text) + + + def test_find_everything(self): + """Test an optimization that finds all tags.""" + soup = self.soup("<a>foo</a><b>bar</b>") + self.assertEqual(2, len(soup.find_all())) + + def test_find_everything_with_name(self): + """Test an optimization that finds all tags with a given name.""" + soup = self.soup("<a>foo</a><b>bar</b><a>baz</a>") + self.assertEqual(2, len(soup.find_all('a'))) + +class TestFindAll(TreeTest): + """Basic tests of the find_all() method.""" + + def test_find_all_text_nodes(self): + """You can search the tree for text nodes.""" + soup = self.soup("<html>Foo<b>bar</b>\xbb</html>") + # Exact match. + self.assertEqual(soup.find_all(string="bar"), ["bar"]) + self.assertEqual(soup.find_all(text="bar"), ["bar"]) + # Match any of a number of strings. 
+ self.assertEqual( + soup.find_all(text=["Foo", "bar"]), ["Foo", "bar"]) + # Match a regular expression. + self.assertEqual(soup.find_all(text=re.compile('.*')), + ["Foo", "bar", '\xbb']) + # Match anything. + self.assertEqual(soup.find_all(text=True), + ["Foo", "bar", '\xbb']) + + def test_find_all_limit(self): + """You can limit the number of items returned by find_all.""" + soup = self.soup("<a>1</a><a>2</a><a>3</a><a>4</a><a>5</a>") + self.assertSelects(soup.find_all('a', limit=3), ["1", "2", "3"]) + self.assertSelects(soup.find_all('a', limit=1), ["1"]) + self.assertSelects( + soup.find_all('a', limit=10), ["1", "2", "3", "4", "5"]) + + # A limit of 0 means no limit. + self.assertSelects( + soup.find_all('a', limit=0), ["1", "2", "3", "4", "5"]) + + def test_calling_a_tag_is_calling_findall(self): + soup = self.soup("<a>1</a><b>2<a id='foo'>3</a></b>") + self.assertSelects(soup('a', limit=1), ["1"]) + self.assertSelects(soup.b(id="foo"), ["3"]) + + def test_find_all_with_self_referential_data_structure_does_not_cause_infinite_recursion(self): + soup = self.soup("<a></a>") + # Create a self-referential list. + l = [] + l.append(l) + + # Without special code in _normalize_search_value, this would cause infinite + # recursion. 
+ self.assertEqual([], soup.find_all(l)) + + def test_find_all_resultset(self): + """All find_all calls return a ResultSet""" + soup = self.soup("<a></a>") + result = soup.find_all("a") + self.assertTrue(hasattr(result, "source")) + + result = soup.find_all(True) + self.assertTrue(hasattr(result, "source")) + + result = soup.find_all(text="foo") + self.assertTrue(hasattr(result, "source")) + + +class TestFindAllBasicNamespaces(TreeTest): + + def test_find_by_namespaced_name(self): + soup = self.soup('<mathml:msqrt>4</mathml:msqrt><a svg:fill="red">') + self.assertEqual("4", soup.find("mathml:msqrt").string) + self.assertEqual("a", soup.find(attrs= { "svg:fill" : "red" }).name) + + +class TestFindAllByName(TreeTest): + """Test ways of finding tags by tag name.""" + + def setUp(self): + super(TreeTest, self).setUp() + self.tree = self.soup("""<a>First tag.</a> + <b>Second tag.</b> + <c>Third <a>Nested tag.</a> tag.</c>""") + + def test_find_all_by_tag_name(self): + # Find all the <a> tags. + self.assertSelects( + self.tree.find_all('a'), ['First tag.', 'Nested tag.']) + + def test_find_all_by_name_and_text(self): + self.assertSelects( + self.tree.find_all('a', text='First tag.'), ['First tag.']) + + self.assertSelects( + self.tree.find_all('a', text=True), ['First tag.', 'Nested tag.']) + + self.assertSelects( + self.tree.find_all('a', text=re.compile("tag")), + ['First tag.', 'Nested tag.']) + + + def test_find_all_on_non_root_element(self): + # You can call find_all on any node, not just the root. 
+ self.assertSelects(self.tree.c.find_all('a'), ['Nested tag.']) + + def test_calling_element_invokes_find_all(self): + self.assertSelects(self.tree('a'), ['First tag.', 'Nested tag.']) + + def test_find_all_by_tag_strainer(self): + self.assertSelects( + self.tree.find_all(SoupStrainer('a')), + ['First tag.', 'Nested tag.']) + + def test_find_all_by_tag_names(self): + self.assertSelects( + self.tree.find_all(['a', 'b']), + ['First tag.', 'Second tag.', 'Nested tag.']) + + def test_find_all_by_tag_dict(self): + self.assertSelects( + self.tree.find_all({'a' : True, 'b' : True}), + ['First tag.', 'Second tag.', 'Nested tag.']) + + def test_find_all_by_tag_re(self): + self.assertSelects( + self.tree.find_all(re.compile('^[ab]$')), + ['First tag.', 'Second tag.', 'Nested tag.']) + + def test_find_all_with_tags_matching_method(self): + # You can define an oracle method that determines whether + # a tag matches the search. + def id_matches_name(tag): + return tag.name == tag.get('id') + + tree = self.soup("""<a id="a">Match 1.</a> + <a id="1">Does not match.</a> + <b id="b">Match 2.</a>""") + + self.assertSelects( + tree.find_all(id_matches_name), ["Match 1.", "Match 2."]) + + def test_find_with_multi_valued_attribute(self): + soup = self.soup( + "<div class='a b'>1</div><div class='a c'>2</div><div class='a d'>3</div>" + ) + r1 = soup.find('div', 'a d'); + r2 = soup.find('div', re.compile(r'a d')); + r3, r4 = soup.find_all('div', ['a b', 'a d']); + self.assertEqual('3', r1.string) + self.assertEqual('3', r2.string) + self.assertEqual('1', r3.string) + self.assertEqual('3', r4.string) + + +class TestFindAllByAttribute(TreeTest): + + def test_find_all_by_attribute_name(self): + # You can pass in keyword arguments to find_all to search by + # attribute. + tree = self.soup(""" + <a id="first">Matching a.</a> + <a id="second"> + Non-matching <b id="first">Matching b.</b>a. 
+ </a>""") + self.assertSelects(tree.find_all(id='first'), + ["Matching a.", "Matching b."]) + + def test_find_all_by_utf8_attribute_value(self): + peace = "םולש".encode("utf8") + data = '<a title="םולש"></a>'.encode("utf8") + soup = self.soup(data) + self.assertEqual([soup.a], soup.find_all(title=peace)) + self.assertEqual([soup.a], soup.find_all(title=peace.decode("utf8"))) + self.assertEqual([soup.a], soup.find_all(title=[peace, "something else"])) + + def test_find_all_by_attribute_dict(self): + # You can pass in a dictionary as the argument 'attrs'. This + # lets you search for attributes like 'name' (a fixed argument + # to find_all) and 'class' (a reserved word in Python.) + tree = self.soup(""" + <a name="name1" class="class1">Name match.</a> + <a name="name2" class="class2">Class match.</a> + <a name="name3" class="class3">Non-match.</a> + <name1>A tag called 'name1'.</name1> + """) + + # This doesn't do what you want. + self.assertSelects(tree.find_all(name='name1'), + ["A tag called 'name1'."]) + # This does what you want. + self.assertSelects(tree.find_all(attrs={'name' : 'name1'}), + ["Name match."]) + + self.assertSelects(tree.find_all(attrs={'class' : 'class2'}), + ["Class match."]) + + def test_find_all_by_class(self): + tree = self.soup(""" + <a class="1">Class 1.</a> + <a class="2">Class 2.</a> + <b class="1">Class 1.</b> + <c class="3 4">Class 3 and 4.</c> + """) + + # Passing in the class_ keyword argument will search against + # the 'class' attribute. + self.assertSelects(tree.find_all('a', class_='1'), ['Class 1.']) + self.assertSelects(tree.find_all('c', class_='3'), ['Class 3 and 4.']) + self.assertSelects(tree.find_all('c', class_='4'), ['Class 3 and 4.']) + + # Passing in a string to 'attrs' will also search the CSS class. 
+ self.assertSelects(tree.find_all('a', '1'), ['Class 1.']) + self.assertSelects(tree.find_all(attrs='1'), ['Class 1.', 'Class 1.']) + self.assertSelects(tree.find_all('c', '3'), ['Class 3 and 4.']) + self.assertSelects(tree.find_all('c', '4'), ['Class 3 and 4.']) + + def test_find_by_class_when_multiple_classes_present(self): + tree = self.soup("<gar class='foo bar'>Found it</gar>") + + f = tree.find_all("gar", class_=re.compile("o")) + self.assertSelects(f, ["Found it"]) + + f = tree.find_all("gar", class_=re.compile("a")) + self.assertSelects(f, ["Found it"]) + + # If the search fails to match the individual strings "foo" and "bar", + # it will be tried against the combined string "foo bar". + f = tree.find_all("gar", class_=re.compile("o b")) + self.assertSelects(f, ["Found it"]) + + def test_find_all_with_non_dictionary_for_attrs_finds_by_class(self): + soup = self.soup("<a class='bar'>Found it</a>") + + self.assertSelects(soup.find_all("a", re.compile("ba")), ["Found it"]) + + def big_attribute_value(value): + return len(value) > 3 + + self.assertSelects(soup.find_all("a", big_attribute_value), []) + + def small_attribute_value(value): + return len(value) <= 3 + + self.assertSelects( + soup.find_all("a", small_attribute_value), ["Found it"]) + + def test_find_all_with_string_for_attrs_finds_multiple_classes(self): + soup = self.soup('<a class="foo bar"></a><a class="foo"></a>') + a, a2 = soup.find_all("a") + self.assertEqual([a, a2], soup.find_all("a", "foo")) + self.assertEqual([a], soup.find_all("a", "bar")) + + # If you specify the class as a string that contains a + # space, only that specific value will be found. 
+ self.assertEqual([a], soup.find_all("a", class_="foo bar")) + self.assertEqual([a], soup.find_all("a", "foo bar")) + self.assertEqual([], soup.find_all("a", "bar foo")) + + def test_find_all_by_attribute_soupstrainer(self): + tree = self.soup(""" + <a id="first">Match.</a> + <a id="second">Non-match.</a>""") + + strainer = SoupStrainer(attrs={'id' : 'first'}) + self.assertSelects(tree.find_all(strainer), ['Match.']) + + def test_find_all_with_missing_attribute(self): + # You can pass in None as the value of an attribute to find_all. + # This will match tags that do not have that attribute set. + tree = self.soup("""<a id="1">ID present.</a> + <a>No ID present.</a> + <a id="">ID is empty.</a>""") + self.assertSelects(tree.find_all('a', id=None), ["No ID present."]) + + def test_find_all_with_defined_attribute(self): + # You can pass in None as the value of an attribute to find_all. + # This will match tags that have that attribute set to any value. + tree = self.soup("""<a id="1">ID present.</a> + <a>No ID present.</a> + <a id="">ID is empty.</a>""") + self.assertSelects( + tree.find_all(id=True), ["ID present.", "ID is empty."]) + + def test_find_all_with_numeric_attribute(self): + # If you search for a number, it's treated as a string. + tree = self.soup("""<a id=1>Unquoted attribute.</a> + <a id="1">Quoted attribute.</a>""") + + expected = ["Unquoted attribute.", "Quoted attribute."] + self.assertSelects(tree.find_all(id=1), expected) + self.assertSelects(tree.find_all(id="1"), expected) + + def test_find_all_with_list_attribute_values(self): + # You can pass a list of attribute values instead of just one, + # and you'll get tags that match any of the values. 
+ tree = self.soup("""<a id="1">1</a> + <a id="2">2</a> + <a id="3">3</a> + <a>No ID.</a>""") + self.assertSelects(tree.find_all(id=["1", "3", "4"]), + ["1", "3"]) + + def test_find_all_with_regular_expression_attribute_value(self): + # You can pass a regular expression as an attribute value, and + # you'll get tags whose values for that attribute match the + # regular expression. + tree = self.soup("""<a id="a">One a.</a> + <a id="aa">Two as.</a> + <a id="ab">Mixed as and bs.</a> + <a id="b">One b.</a> + <a>No ID.</a>""") + + self.assertSelects(tree.find_all(id=re.compile("^a+$")), + ["One a.", "Two as."]) + + def test_find_by_name_and_containing_string(self): + soup = self.soup("<b>foo</b><b>bar</b><a>foo</a>") + a = soup.a + + self.assertEqual([a], soup.find_all("a", text="foo")) + self.assertEqual([], soup.find_all("a", text="bar")) + self.assertEqual([], soup.find_all("a", text="bar")) + + def test_find_by_name_and_containing_string_when_string_is_buried(self): + soup = self.soup("<a>foo</a><a><b><c>foo</c></b></a>") + self.assertEqual(soup.find_all("a"), soup.find_all("a", text="foo")) + + def test_find_by_attribute_and_containing_string(self): + soup = self.soup('<b id="1">foo</b><a id="2">foo</a>') + a = soup.a + + self.assertEqual([a], soup.find_all(id=2, text="foo")) + self.assertEqual([], soup.find_all(id=1, text="bar")) + + +class TestSmooth(TreeTest): + """Test Tag.smooth.""" + + def test_smooth(self): + soup = self.soup("<div>a</div>") + div = soup.div + div.append("b") + div.append("c") + div.append(Comment("Comment 1")) + div.append(Comment("Comment 2")) + div.append("d") + builder = self.default_builder() + span = Tag(soup, builder, 'span') + span.append('1') + span.append('2') + div.append(span) + + # At this point the tree has a bunch of adjacent + # NavigableStrings. This is normal, but it has no meaning in + # terms of HTML, so we may want to smooth things out for + # output. + + # Since the <span> tag has two children, its .string is None. 
+ self.assertEqual(None, div.span.string) + + self.assertEqual(7, len(div.contents)) + div.smooth() + self.assertEqual(5, len(div.contents)) + + # The three strings at the beginning of div.contents have been + # merged into on string. + # + self.assertEqual('abc', div.contents[0]) + + # The call is recursive -- the <span> tag was also smoothed. + self.assertEqual('12', div.span.string) + + # The two comments have _not_ been merged, even though + # comments are strings. Merging comments would change the + # meaning of the HTML. + self.assertEqual('Comment 1', div.contents[1]) + self.assertEqual('Comment 2', div.contents[2]) + + +class TestIndex(TreeTest): + """Test Tag.index""" + def test_index(self): + tree = self.soup("""<div> + <a>Identical</a> + <b>Not identical</b> + <a>Identical</a> + + <c><d>Identical with child</d></c> + <b>Also not identical</b> + <c><d>Identical with child</d></c> + </div>""") + div = tree.div + for i, element in enumerate(div.contents): + self.assertEqual(i, div.index(element)) + self.assertRaises(ValueError, tree.index, 1) + + +class TestParentOperations(TreeTest): + """Test navigation and searching through an element's parents.""" + + def setUp(self): + super(TestParentOperations, self).setUp() + self.tree = self.soup('''<ul id="empty"></ul> + <ul id="top"> + <ul id="middle"> + <ul id="bottom"> + <b>Start here</b> + </ul> + </ul>''') + self.start = self.tree.b + + + def test_parent(self): + self.assertEqual(self.start.parent['id'], 'bottom') + self.assertEqual(self.start.parent.parent['id'], 'middle') + self.assertEqual(self.start.parent.parent.parent['id'], 'top') + + def test_parent_of_top_tag_is_soup_object(self): + top_tag = self.tree.contents[0] + self.assertEqual(top_tag.parent, self.tree) + + def test_soup_object_has_no_parent(self): + self.assertEqual(None, self.tree.parent) + + def test_find_parents(self): + self.assertSelectsIDs( + self.start.find_parents('ul'), ['bottom', 'middle', 'top']) + self.assertSelectsIDs( + 
self.start.find_parents('ul', id="middle"), ['middle']) + + def test_find_parent(self): + self.assertEqual(self.start.find_parent('ul')['id'], 'bottom') + self.assertEqual(self.start.find_parent('ul', id='top')['id'], 'top') + + def test_parent_of_text_element(self): + text = self.tree.find(text="Start here") + self.assertEqual(text.parent.name, 'b') + + def test_text_element_find_parent(self): + text = self.tree.find(text="Start here") + self.assertEqual(text.find_parent('ul')['id'], 'bottom') + + def test_parent_generator(self): + parents = [parent['id'] for parent in self.start.parents + if parent is not None and 'id' in parent.attrs] + self.assertEqual(parents, ['bottom', 'middle', 'top']) + + +class ProximityTest(TreeTest): + + def setUp(self): + super(TreeTest, self).setUp() + self.tree = self.soup( + '<html id="start"><head></head><body><b id="1">One</b><b id="2">Two</b><b id="3">Three</b></body></html>') + + +class TestNextOperations(ProximityTest): + + def setUp(self): + super(TestNextOperations, self).setUp() + self.start = self.tree.b + + def test_next(self): + self.assertEqual(self.start.next_element, "One") + self.assertEqual(self.start.next_element.next_element['id'], "2") + + def test_next_of_last_item_is_none(self): + last = self.tree.find(text="Three") + self.assertEqual(last.next_element, None) + + def test_next_of_root_is_none(self): + # The document root is outside the next/previous chain. 
+ self.assertEqual(self.tree.next_element, None) + + def test_find_all_next(self): + self.assertSelects(self.start.find_all_next('b'), ["Two", "Three"]) + self.start.find_all_next(id=3) + self.assertSelects(self.start.find_all_next(id=3), ["Three"]) + + def test_find_next(self): + self.assertEqual(self.start.find_next('b')['id'], '2') + self.assertEqual(self.start.find_next(text="Three"), "Three") + + def test_find_next_for_text_element(self): + text = self.tree.find(text="One") + self.assertEqual(text.find_next("b").string, "Two") + self.assertSelects(text.find_all_next("b"), ["Two", "Three"]) + + def test_next_generator(self): + start = self.tree.find(text="Two") + successors = [node for node in start.next_elements] + # There are two successors: the final <b> tag and its text contents. + tag, contents = successors + self.assertEqual(tag['id'], '3') + self.assertEqual(contents, "Three") + +class TestPreviousOperations(ProximityTest): + + def setUp(self): + super(TestPreviousOperations, self).setUp() + self.end = self.tree.find(text="Three") + + def test_previous(self): + self.assertEqual(self.end.previous_element['id'], "3") + self.assertEqual(self.end.previous_element.previous_element, "Two") + + def test_previous_of_first_item_is_none(self): + first = self.tree.find('html') + self.assertEqual(first.previous_element, None) + + def test_previous_of_root_is_none(self): + # The document root is outside the next/previous chain. + # XXX This is broken! + #self.assertEqual(self.tree.previous_element, None) + pass + + def test_find_all_previous(self): + # The <b> tag containing the "Three" node is the predecessor + # of the "Three" node itself, which is why "Three" shows up + # here. 
+ self.assertSelects( + self.end.find_all_previous('b'), ["Three", "Two", "One"]) + self.assertSelects(self.end.find_all_previous(id=1), ["One"]) + + def test_find_previous(self): + self.assertEqual(self.end.find_previous('b')['id'], '3') + self.assertEqual(self.end.find_previous(text="One"), "One") + + def test_find_previous_for_text_element(self): + text = self.tree.find(text="Three") + self.assertEqual(text.find_previous("b").string, "Three") + self.assertSelects( + text.find_all_previous("b"), ["Three", "Two", "One"]) + + def test_previous_generator(self): + start = self.tree.find(text="One") + predecessors = [node for node in start.previous_elements] + + # There are four predecessors: the <b> tag containing "One" + # the <body> tag, the <head> tag, and the <html> tag. + b, body, head, html = predecessors + self.assertEqual(b['id'], '1') + self.assertEqual(body.name, "body") + self.assertEqual(head.name, "head") + self.assertEqual(html.name, "html") + + +class SiblingTest(TreeTest): + + def setUp(self): + super(SiblingTest, self).setUp() + markup = '''<html> + <span id="1"> + <span id="1.1"></span> + </span> + <span id="2"> + <span id="2.1"></span> + </span> + <span id="3"> + <span id="3.1"></span> + </span> + <span id="4"></span> + </html>''' + # All that whitespace looks good but makes the tests more + # difficult. Get rid of it. + markup = re.compile(r"\n\s*").sub("", markup) + self.tree = self.soup(markup) + + +class TestNextSibling(SiblingTest): + + def setUp(self): + super(TestNextSibling, self).setUp() + self.start = self.tree.find(id="1") + + def test_next_sibling_of_root_is_none(self): + self.assertEqual(self.tree.next_sibling, None) + + def test_next_sibling(self): + self.assertEqual(self.start.next_sibling['id'], '2') + self.assertEqual(self.start.next_sibling.next_sibling['id'], '3') + + # Note the difference between next_sibling and next_element. 
+ self.assertEqual(self.start.next_element['id'], '1.1') + + def test_next_sibling_may_not_exist(self): + self.assertEqual(self.tree.html.next_sibling, None) + + nested_span = self.tree.find(id="1.1") + self.assertEqual(nested_span.next_sibling, None) + + last_span = self.tree.find(id="4") + self.assertEqual(last_span.next_sibling, None) + + def test_find_next_sibling(self): + self.assertEqual(self.start.find_next_sibling('span')['id'], '2') + + def test_next_siblings(self): + self.assertSelectsIDs(self.start.find_next_siblings("span"), + ['2', '3', '4']) + + self.assertSelectsIDs(self.start.find_next_siblings(id='3'), ['3']) + + def test_next_sibling_for_text_element(self): + soup = self.soup("Foo<b>bar</b>baz") + start = soup.find(text="Foo") + self.assertEqual(start.next_sibling.name, 'b') + self.assertEqual(start.next_sibling.next_sibling, 'baz') + + self.assertSelects(start.find_next_siblings('b'), ['bar']) + self.assertEqual(start.find_next_sibling(text="baz"), "baz") + self.assertEqual(start.find_next_sibling(text="nonesuch"), None) + + +class TestPreviousSibling(SiblingTest): + + def setUp(self): + super(TestPreviousSibling, self).setUp() + self.end = self.tree.find(id="4") + + def test_previous_sibling_of_root_is_none(self): + self.assertEqual(self.tree.previous_sibling, None) + + def test_previous_sibling(self): + self.assertEqual(self.end.previous_sibling['id'], '3') + self.assertEqual(self.end.previous_sibling.previous_sibling['id'], '2') + + # Note the difference between previous_sibling and previous_element. 
+ self.assertEqual(self.end.previous_element['id'], '3.1') + + def test_previous_sibling_may_not_exist(self): + self.assertEqual(self.tree.html.previous_sibling, None) + + nested_span = self.tree.find(id="1.1") + self.assertEqual(nested_span.previous_sibling, None) + + first_span = self.tree.find(id="1") + self.assertEqual(first_span.previous_sibling, None) + + def test_find_previous_sibling(self): + self.assertEqual(self.end.find_previous_sibling('span')['id'], '3') + + def test_previous_siblings(self): + self.assertSelectsIDs(self.end.find_previous_siblings("span"), + ['3', '2', '1']) + + self.assertSelectsIDs(self.end.find_previous_siblings(id='1'), ['1']) + + def test_previous_sibling_for_text_element(self): + soup = self.soup("Foo<b>bar</b>baz") + start = soup.find(text="baz") + self.assertEqual(start.previous_sibling.name, 'b') + self.assertEqual(start.previous_sibling.previous_sibling, 'Foo') + + self.assertSelects(start.find_previous_siblings('b'), ['bar']) + self.assertEqual(start.find_previous_sibling(text="Foo"), "Foo") + self.assertEqual(start.find_previous_sibling(text="nonesuch"), None) + + +class TestTagCreation(SoupTest): + """Test the ability to create new tags.""" + def test_new_tag(self): + soup = self.soup("") + new_tag = soup.new_tag("foo", bar="baz", attrs={"name": "a name"}) + self.assertTrue(isinstance(new_tag, Tag)) + self.assertEqual("foo", new_tag.name) + self.assertEqual(dict(bar="baz", name="a name"), new_tag.attrs) + self.assertEqual(None, new_tag.parent) + + def test_tag_inherits_self_closing_rules_from_builder(self): + if XML_BUILDER_PRESENT: + xml_soup = BeautifulSoup("", "lxml-xml") + xml_br = xml_soup.new_tag("br") + xml_p = xml_soup.new_tag("p") + + # Both the <br> and <p> tag are empty-element, just because + # they have no contents. 
+ self.assertEqual(b"<br/>", xml_br.encode()) + self.assertEqual(b"<p/>", xml_p.encode()) + + html_soup = BeautifulSoup("", "html.parser") + html_br = html_soup.new_tag("br") + html_p = html_soup.new_tag("p") + + # The HTML builder users HTML's rules about which tags are + # empty-element tags, and the new tags reflect these rules. + self.assertEqual(b"<br/>", html_br.encode()) + self.assertEqual(b"<p></p>", html_p.encode()) + + def test_new_string_creates_navigablestring(self): + soup = self.soup("") + s = soup.new_string("foo") + self.assertEqual("foo", s) + self.assertTrue(isinstance(s, NavigableString)) + + def test_new_string_can_create_navigablestring_subclass(self): + soup = self.soup("") + s = soup.new_string("foo", Comment) + self.assertEqual("foo", s) + self.assertTrue(isinstance(s, Comment)) + +class TestTreeModification(SoupTest): + + def test_attribute_modification(self): + soup = self.soup('<a id="1"></a>') + soup.a['id'] = 2 + self.assertEqual(soup.decode(), self.document_for('<a id="2"></a>')) + del(soup.a['id']) + self.assertEqual(soup.decode(), self.document_for('<a></a>')) + soup.a['id2'] = 'foo' + self.assertEqual(soup.decode(), self.document_for('<a id2="foo"></a>')) + + def test_new_tag_creation(self): + builder = builder_registry.lookup('html')() + soup = self.soup("<body></body>", builder=builder) + a = Tag(soup, builder, 'a') + ol = Tag(soup, builder, 'ol') + a['href'] = 'http://foo.com/' + soup.body.insert(0, a) + soup.body.insert(1, ol) + self.assertEqual( + soup.body.encode(), + b'<body><a href="http://foo.com/"></a><ol></ol></body>') + + def test_append_to_contents_moves_tag(self): + doc = """<p id="1">Don't leave me <b>here</b>.</p> + <p id="2">Don\'t leave!</p>""" + soup = self.soup(doc) + second_para = soup.find(id='2') + bold = soup.b + + # Move the <b> tag to the end of the second paragraph. + soup.find(id='2').append(soup.b) + + # The <b> tag is now a child of the second paragraph. 
+ self.assertEqual(bold.parent, second_para) + + self.assertEqual( + soup.decode(), self.document_for( + '<p id="1">Don\'t leave me .</p>\n' + '<p id="2">Don\'t leave!<b>here</b></p>')) + + def test_replace_with_returns_thing_that_was_replaced(self): + text = "<a></a><b><c></c></b>" + soup = self.soup(text) + a = soup.a + new_a = a.replace_with(soup.c) + self.assertEqual(a, new_a) + + def test_unwrap_returns_thing_that_was_replaced(self): + text = "<a><b></b><c></c></a>" + soup = self.soup(text) + a = soup.a + new_a = a.unwrap() + self.assertEqual(a, new_a) + + def test_replace_with_and_unwrap_give_useful_exception_when_tag_has_no_parent(self): + soup = self.soup("<a><b>Foo</b></a><c>Bar</c>") + a = soup.a + a.extract() + self.assertEqual(None, a.parent) + self.assertRaises(ValueError, a.unwrap) + self.assertRaises(ValueError, a.replace_with, soup.c) + + def test_replace_tag_with_itself(self): + text = "<a><b></b><c>Foo<d></d></c></a><a><e></e></a>" + soup = self.soup(text) + c = soup.c + soup.c.replace_with(c) + self.assertEqual(soup.decode(), self.document_for(text)) + + def test_replace_tag_with_its_parent_raises_exception(self): + text = "<a><b></b></a>" + soup = self.soup(text) + self.assertRaises(ValueError, soup.b.replace_with, soup.a) + + def test_insert_tag_into_itself_raises_exception(self): + text = "<a><b></b></a>" + soup = self.soup(text) + self.assertRaises(ValueError, soup.a.insert, 0, soup.a) + + def test_insert_beautifulsoup_object_inserts_children(self): + """Inserting one BeautifulSoup object into another actually inserts all + of its children -- you'll never combine BeautifulSoup objects. 
+ """ + soup = self.soup("<p>And now, a word:</p><p>And we're back.</p>") + + text = "<p>p2</p><p>p3</p>" + to_insert = self.soup(text) + soup.insert(1, to_insert) + + for i in soup.descendants: + assert not isinstance(i, BeautifulSoup) + + p1, p2, p3, p4 = list(soup.children) + self.assertEqual("And now, a word:", p1.string) + self.assertEqual("p2", p2.string) + self.assertEqual("p3", p3.string) + self.assertEqual("And we're back.", p4.string) + + + def test_replace_with_maintains_next_element_throughout(self): + soup = self.soup('<p><a>one</a><b>three</b></p>') + a = soup.a + b = a.contents[0] + # Make it so the <a> tag has two text children. + a.insert(1, "two") + + # Now replace each one with the empty string. + left, right = a.contents + left.replaceWith('') + right.replaceWith('') + + # The <b> tag is still connected to the tree. + self.assertEqual("three", soup.b.string) + + def test_replace_final_node(self): + soup = self.soup("<b>Argh!</b>") + soup.find(text="Argh!").replace_with("Hooray!") + new_text = soup.find(text="Hooray!") + b = soup.b + self.assertEqual(new_text.previous_element, b) + self.assertEqual(new_text.parent, b) + self.assertEqual(new_text.previous_element.next_element, new_text) + self.assertEqual(new_text.next_element, None) + + def test_consecutive_text_nodes(self): + # A builder should never create two consecutive text nodes, + # but if you insert one next to another, Beautiful Soup will + # handle it correctly. 
+ soup = self.soup("<a><b>Argh!</b><c></c></a>") + soup.b.insert(1, "Hooray!") + + self.assertEqual( + soup.decode(), self.document_for( + "<a><b>Argh!Hooray!</b><c></c></a>")) + + new_text = soup.find(text="Hooray!") + self.assertEqual(new_text.previous_element, "Argh!") + self.assertEqual(new_text.previous_element.next_element, new_text) + + self.assertEqual(new_text.previous_sibling, "Argh!") + self.assertEqual(new_text.previous_sibling.next_sibling, new_text) + + self.assertEqual(new_text.next_sibling, None) + self.assertEqual(new_text.next_element, soup.c) + + def test_insert_string(self): + soup = self.soup("<a></a>") + soup.a.insert(0, "bar") + soup.a.insert(0, "foo") + # The string were added to the tag. + self.assertEqual(["foo", "bar"], soup.a.contents) + # And they were converted to NavigableStrings. + self.assertEqual(soup.a.contents[0].next_element, "bar") + + def test_insert_tag(self): + builder = self.default_builder() + soup = self.soup( + "<a><b>Find</b><c>lady!</c><d></d></a>", builder=builder) + magic_tag = Tag(soup, builder, 'magictag') + magic_tag.insert(0, "the") + soup.a.insert(1, magic_tag) + + self.assertEqual( + soup.decode(), self.document_for( + "<a><b>Find</b><magictag>the</magictag><c>lady!</c><d></d></a>")) + + # Make sure all the relationships are hooked up correctly. 
+ b_tag = soup.b + self.assertEqual(b_tag.next_sibling, magic_tag) + self.assertEqual(magic_tag.previous_sibling, b_tag) + + find = b_tag.find(text="Find") + self.assertEqual(find.next_element, magic_tag) + self.assertEqual(magic_tag.previous_element, find) + + c_tag = soup.c + self.assertEqual(magic_tag.next_sibling, c_tag) + self.assertEqual(c_tag.previous_sibling, magic_tag) + + the = magic_tag.find(text="the") + self.assertEqual(the.parent, magic_tag) + self.assertEqual(the.next_element, c_tag) + self.assertEqual(c_tag.previous_element, the) + + def test_append_child_thats_already_at_the_end(self): + data = "<a><b></b></a>" + soup = self.soup(data) + soup.a.append(soup.b) + self.assertEqual(data, soup.decode()) + + def test_extend(self): + data = "<a><b><c><d><e><f><g></g></f></e></d></c></b></a>" + soup = self.soup(data) + l = [soup.g, soup.f, soup.e, soup.d, soup.c, soup.b] + soup.a.extend(l) + self.assertEqual("<a><g></g><f></f><e></e><d></d><c></c><b></b></a>", soup.decode()) + + def test_move_tag_to_beginning_of_parent(self): + data = "<a><b></b><c></c><d></d></a>" + soup = self.soup(data) + soup.a.insert(0, soup.d) + self.assertEqual("<a><d></d><b></b><c></c></a>", soup.decode()) + + def test_insert_works_on_empty_element_tag(self): + # This is a little strange, since most HTML parsers don't allow + # markup like this to come through. But in general, we don't + # know what the parser would or wouldn't have allowed, so + # I'm letting this succeed for now. + soup = self.soup("<br/>") + soup.br.insert(1, "Contents") + self.assertEqual(str(soup.br), "<br>Contents</br>") + + def test_insert_before(self): + soup = self.soup("<a>foo</a><b>bar</b>") + soup.b.insert_before("BAZ") + soup.a.insert_before("QUUX") + self.assertEqual( + soup.decode(), self.document_for("QUUX<a>foo</a>BAZ<b>bar</b>")) + + soup.a.insert_before(soup.b) + self.assertEqual( + soup.decode(), self.document_for("QUUX<b>bar</b><a>foo</a>BAZ")) + + # Can't insert an element before itself. 
+ b = soup.b + self.assertRaises(ValueError, b.insert_before, b) + + # Can't insert before if an element has no parent. + b.extract() + self.assertRaises(ValueError, b.insert_before, "nope") + + # Can insert an identical element + soup = self.soup("<a>") + soup.a.insert_before(soup.new_tag("a")) + + def test_insert_multiple_before(self): + soup = self.soup("<a>foo</a><b>bar</b>") + soup.b.insert_before("BAZ", " ", "QUUX") + soup.a.insert_before("QUUX", " ", "BAZ") + self.assertEqual( + soup.decode(), self.document_for("QUUX BAZ<a>foo</a>BAZ QUUX<b>bar</b>")) + + soup.a.insert_before(soup.b, "FOO") + self.assertEqual( + soup.decode(), self.document_for("QUUX BAZ<b>bar</b>FOO<a>foo</a>BAZ QUUX")) + + def test_insert_after(self): + soup = self.soup("<a>foo</a><b>bar</b>") + soup.b.insert_after("BAZ") + soup.a.insert_after("QUUX") + self.assertEqual( + soup.decode(), self.document_for("<a>foo</a>QUUX<b>bar</b>BAZ")) + soup.b.insert_after(soup.a) + self.assertEqual( + soup.decode(), self.document_for("QUUX<b>bar</b><a>foo</a>BAZ")) + + # Can't insert an element after itself. + b = soup.b + self.assertRaises(ValueError, b.insert_after, b) + + # Can't insert after if an element has no parent. 
+ b.extract() + self.assertRaises(ValueError, b.insert_after, "nope") + + # Can insert an identical element + soup = self.soup("<a>") + soup.a.insert_before(soup.new_tag("a")) + + def test_insert_multiple_after(self): + soup = self.soup("<a>foo</a><b>bar</b>") + soup.b.insert_after("BAZ", " ", "QUUX") + soup.a.insert_after("QUUX", " ", "BAZ") + self.assertEqual( + soup.decode(), self.document_for("<a>foo</a>QUUX BAZ<b>bar</b>BAZ QUUX")) + soup.b.insert_after(soup.a, "FOO ") + self.assertEqual( + soup.decode(), self.document_for("QUUX BAZ<b>bar</b><a>foo</a>FOO BAZ QUUX")) + + def test_insert_after_raises_exception_if_after_has_no_meaning(self): + soup = self.soup("") + tag = soup.new_tag("a") + string = soup.new_string("") + self.assertRaises(ValueError, string.insert_after, tag) + self.assertRaises(NotImplementedError, soup.insert_after, tag) + self.assertRaises(ValueError, tag.insert_after, tag) + + def test_insert_before_raises_notimplementederror_if_before_has_no_meaning(self): + soup = self.soup("") + tag = soup.new_tag("a") + string = soup.new_string("") + self.assertRaises(ValueError, string.insert_before, tag) + self.assertRaises(NotImplementedError, soup.insert_before, tag) + self.assertRaises(ValueError, tag.insert_before, tag) + + def test_replace_with(self): + soup = self.soup( + "<p>There's <b>no</b> business like <b>show</b> business</p>") + no, show = soup.find_all('b') + show.replace_with(no) + self.assertEqual( + soup.decode(), + self.document_for( + "<p>There's business like <b>no</b> business</p>")) + + self.assertEqual(show.parent, None) + self.assertEqual(no.parent, soup.p) + self.assertEqual(no.next_element, "no") + self.assertEqual(no.next_sibling, " business") + + def test_replace_first_child(self): + data = "<a><b></b><c></c></a>" + soup = self.soup(data) + soup.b.replace_with(soup.c) + self.assertEqual("<a><c></c></a>", soup.decode()) + + def test_replace_last_child(self): + data = "<a><b></b><c></c></a>" + soup = self.soup(data) + 
soup.c.replace_with(soup.b) + self.assertEqual("<a><b></b></a>", soup.decode()) + + def test_nested_tag_replace_with(self): + soup = self.soup( + """<a>We<b>reserve<c>the</c><d>right</d></b></a><e>to<f>refuse</f><g>service</g></e>""") + + # Replace the entire <b> tag and its contents ("reserve the + # right") with the <f> tag ("refuse"). + remove_tag = soup.b + move_tag = soup.f + remove_tag.replace_with(move_tag) + + self.assertEqual( + soup.decode(), self.document_for( + "<a>We<f>refuse</f></a><e>to<g>service</g></e>")) + + # The <b> tag is now an orphan. + self.assertEqual(remove_tag.parent, None) + self.assertEqual(remove_tag.find(text="right").next_element, None) + self.assertEqual(remove_tag.previous_element, None) + self.assertEqual(remove_tag.next_sibling, None) + self.assertEqual(remove_tag.previous_sibling, None) + + # The <f> tag is now connected to the <a> tag. + self.assertEqual(move_tag.parent, soup.a) + self.assertEqual(move_tag.previous_element, "We") + self.assertEqual(move_tag.next_element.next_element, soup.e) + self.assertEqual(move_tag.next_sibling, None) + + # The gap where the <f> tag used to be has been mended, and + # the word "to" is now connected to the <g> tag. 
+ to_text = soup.find(text="to") + g_tag = soup.g + self.assertEqual(to_text.next_element, g_tag) + self.assertEqual(to_text.next_sibling, g_tag) + self.assertEqual(g_tag.previous_element, to_text) + self.assertEqual(g_tag.previous_sibling, to_text) + + def test_unwrap(self): + tree = self.soup(""" + <p>Unneeded <em>formatting</em> is unneeded</p> + """) + tree.em.unwrap() + self.assertEqual(tree.em, None) + self.assertEqual(tree.p.text, "Unneeded formatting is unneeded") + + def test_wrap(self): + soup = self.soup("I wish I was bold.") + value = soup.string.wrap(soup.new_tag("b")) + self.assertEqual(value.decode(), "<b>I wish I was bold.</b>") + self.assertEqual( + soup.decode(), self.document_for("<b>I wish I was bold.</b>")) + + def test_wrap_extracts_tag_from_elsewhere(self): + soup = self.soup("<b></b>I wish I was bold.") + soup.b.next_sibling.wrap(soup.b) + self.assertEqual( + soup.decode(), self.document_for("<b>I wish I was bold.</b>")) + + def test_wrap_puts_new_contents_at_the_end(self): + soup = self.soup("<b>I like being bold.</b>I wish I was bold.") + soup.b.next_sibling.wrap(soup.b) + self.assertEqual(2, len(soup.b.contents)) + self.assertEqual( + soup.decode(), self.document_for( + "<b>I like being bold.I wish I was bold.</b>")) + + def test_extract(self): + soup = self.soup( + '<html><body>Some content. <div id="nav">Nav crap</div> More content.</body></html>') + + self.assertEqual(len(soup.body.contents), 3) + extracted = soup.find(id="nav").extract() + + self.assertEqual( + soup.decode(), "<html><body>Some content. More content.</body></html>") + self.assertEqual(extracted.decode(), '<div id="nav">Nav crap</div>') + + # The extracted tag is now an orphan. + self.assertEqual(len(soup.body.contents), 2) + self.assertEqual(extracted.parent, None) + self.assertEqual(extracted.previous_element, None) + self.assertEqual(extracted.next_element.next_element, None) + + # The gap where the extracted tag used to be has been mended. 
+ content_1 = soup.find(text="Some content. ") + content_2 = soup.find(text=" More content.") + self.assertEqual(content_1.next_element, content_2) + self.assertEqual(content_1.next_sibling, content_2) + self.assertEqual(content_2.previous_element, content_1) + self.assertEqual(content_2.previous_sibling, content_1) + + def test_extract_distinguishes_between_identical_strings(self): + soup = self.soup("<a>foo</a><b>bar</b>") + foo_1 = soup.a.string + bar_1 = soup.b.string + foo_2 = soup.new_string("foo") + bar_2 = soup.new_string("bar") + soup.a.append(foo_2) + soup.b.append(bar_2) + + # Now there are two identical strings in the <a> tag, and two + # in the <b> tag. Let's remove the first "foo" and the second + # "bar". + foo_1.extract() + bar_2.extract() + self.assertEqual(foo_2, soup.a.string) + self.assertEqual(bar_2, soup.b.string) + + def test_extract_multiples_of_same_tag(self): + soup = self.soup(""" +<html> +<head> +<script>foo</script> +</head> +<body> + <script>bar</script> + <a></a> +</body> +<script>baz</script> +</html>""") + [soup.script.extract() for i in soup.find_all("script")] + self.assertEqual("<body>\n\n<a></a>\n</body>", str(soup.body)) + + + def test_extract_works_when_element_is_surrounded_by_identical_strings(self): + soup = self.soup( + '<html>\n' + '<body>hi</body>\n' + '</html>') + soup.find('body').extract() + self.assertEqual(None, soup.find('body')) + + + def test_clear(self): + """Tag.clear()""" + soup = self.soup("<p><a>String <em>Italicized</em></a> and another</p>") + # clear using extract() + a = soup.a + soup.p.clear() + self.assertEqual(len(soup.p.contents), 0) + self.assertTrue(hasattr(a, "contents")) + + # clear using decompose() + em = a.em + a.clear(decompose=True) + self.assertEqual(0, len(em.contents)) + + def test_string_set(self): + """Tag.string = 'string'""" + soup = self.soup("<a></a> <b><c></c></b>") + soup.a.string = "foo" + self.assertEqual(soup.a.contents, ["foo"]) + soup.b.string = "bar" + 
self.assertEqual(soup.b.contents, ["bar"]) + + def test_string_set_does_not_affect_original_string(self): + soup = self.soup("<a><b>foo</b><c>bar</c>") + soup.b.string = soup.c.string + self.assertEqual(soup.a.encode(), b"<a><b>bar</b><c>bar</c></a>") + + def test_set_string_preserves_class_of_string(self): + soup = self.soup("<a></a>") + cdata = CData("foo") + soup.a.string = cdata + self.assertTrue(isinstance(soup.a.string, CData)) + +class TestElementObjects(SoupTest): + """Test various features of element objects.""" + + def test_len(self): + """The length of an element is its number of children.""" + soup = self.soup("<top>1<b>2</b>3</top>") + + # The BeautifulSoup object itself contains one element: the + # <top> tag. + self.assertEqual(len(soup.contents), 1) + self.assertEqual(len(soup), 1) + + # The <top> tag contains three elements: the text node "1", the + # <b> tag, and the text node "3". + self.assertEqual(len(soup.top), 3) + self.assertEqual(len(soup.top.contents), 3) + + def test_member_access_invokes_find(self): + """Accessing a Python member .foo invokes find('foo')""" + soup = self.soup('<b><i></i></b>') + self.assertEqual(soup.b, soup.find('b')) + self.assertEqual(soup.b.i, soup.find('b').find('i')) + self.assertEqual(soup.a, None) + + def test_deprecated_member_access(self): + soup = self.soup('<b><i></i></b>') + with warnings.catch_warnings(record=True) as w: + tag = soup.bTag + self.assertEqual(soup.b, tag) + self.assertEqual( + '.bTag is deprecated, use .find("b") instead. If you really were looking for a tag called bTag, use .find("bTag")', + str(w[0].message)) + + def test_has_attr(self): + """has_attr() checks for the presence of an attribute. + + Please note note: has_attr() is different from + __in__. has_attr() checks the tag's attributes and __in__ + checks the tag's chidlren. 
+ """ + soup = self.soup("<foo attr='bar'>") + self.assertTrue(soup.foo.has_attr('attr')) + self.assertFalse(soup.foo.has_attr('attr2')) + + + def test_attributes_come_out_in_alphabetical_order(self): + markup = '<b a="1" z="5" m="3" f="2" y="4"></b>' + self.assertSoupEquals(markup, '<b a="1" f="2" m="3" y="4" z="5"></b>') + + def test_string(self): + # A tag that contains only a text node makes that node + # available as .string. + soup = self.soup("<b>foo</b>") + self.assertEqual(soup.b.string, 'foo') + + def test_empty_tag_has_no_string(self): + # A tag with no children has no .stirng. + soup = self.soup("<b></b>") + self.assertEqual(soup.b.string, None) + + def test_tag_with_multiple_children_has_no_string(self): + # A tag with no children has no .string. + soup = self.soup("<a>foo<b></b><b></b></b>") + self.assertEqual(soup.b.string, None) + + soup = self.soup("<a>foo<b></b>bar</b>") + self.assertEqual(soup.b.string, None) + + # Even if all the children are strings, due to trickery, + # it won't work--but this would be a good optimization. + soup = self.soup("<a>foo</b>") + soup.a.insert(1, "bar") + self.assertEqual(soup.a.string, None) + + def test_tag_with_recursive_string_has_string(self): + # A tag with a single child which has a .string inherits that + # .string. 
+ soup = self.soup("<a><b>foo</b></a>") + self.assertEqual(soup.a.string, "foo") + self.assertEqual(soup.string, "foo") + + def test_lack_of_string(self): + """Only a tag containing a single text node has a .string.""" + soup = self.soup("<b>f<i>e</i>o</b>") + self.assertFalse(soup.b.string) + + soup = self.soup("<b></b>") + self.assertFalse(soup.b.string) + + def test_all_text(self): + """Tag.text and Tag.get_text(sep=u"") -> all child text, concatenated""" + soup = self.soup("<a>a<b>r</b> <r> t </r></a>") + self.assertEqual(soup.a.text, "ar t ") + self.assertEqual(soup.a.get_text(strip=True), "art") + self.assertEqual(soup.a.get_text(","), "a,r, , t ") + self.assertEqual(soup.a.get_text(",", strip=True), "a,r,t") + + def test_get_text_ignores_comments(self): + soup = self.soup("foo<!--IGNORE-->bar") + self.assertEqual(soup.get_text(), "foobar") + + self.assertEqual( + soup.get_text(types=(NavigableString, Comment)), "fooIGNOREbar") + self.assertEqual( + soup.get_text(types=None), "fooIGNOREbar") + + def test_all_strings_ignores_comments(self): + soup = self.soup("foo<!--IGNORE-->bar") + self.assertEqual(['foo', 'bar'], list(soup.strings)) + +class TestCDAtaListAttributes(SoupTest): + + """Testing cdata-list attributes like 'class'. 
+ """ + def test_single_value_becomes_list(self): + soup = self.soup("<a class='foo'>") + self.assertEqual(["foo"],soup.a['class']) + + def test_multiple_values_becomes_list(self): + soup = self.soup("<a class='foo bar'>") + self.assertEqual(["foo", "bar"], soup.a['class']) + + def test_multiple_values_separated_by_weird_whitespace(self): + soup = self.soup("<a class='foo\tbar\nbaz'>") + self.assertEqual(["foo", "bar", "baz"],soup.a['class']) + + def test_attributes_joined_into_string_on_output(self): + soup = self.soup("<a class='foo\tbar'>") + self.assertEqual(b'<a class="foo bar"></a>', soup.a.encode()) + + def test_get_attribute_list(self): + soup = self.soup("<a id='abc def'>") + self.assertEqual(['abc def'], soup.a.get_attribute_list('id')) + + def test_accept_charset(self): + soup = self.soup('<form accept-charset="ISO-8859-1 UTF-8">') + self.assertEqual(['ISO-8859-1', 'UTF-8'], soup.form['accept-charset']) + + def test_cdata_attribute_applying_only_to_one_tag(self): + data = '<a accept-charset="ISO-8859-1 UTF-8"></a>' + soup = self.soup(data) + # We saw in another test that accept-charset is a cdata-list + # attribute for the <form> tag. But it's not a cdata-list + # attribute for any other tag. + self.assertEqual('ISO-8859-1 UTF-8', soup.a['accept-charset']) + + def test_string_has_immutable_name_property(self): + string = self.soup("s").string + self.assertEqual(None, string.name) + def t(): + string.name = 'foo' + self.assertRaises(AttributeError, t) + +class TestPersistence(SoupTest): + "Testing features like pickle and deepcopy." 
+ + def setUp(self): + super(TestPersistence, self).setUp() + self.page = """<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0 Transitional//EN" +"http://www.w3.org/TR/REC-html40/transitional.dtd"> +<html> +<head> +<meta http-equiv="Content-Type" content="text/html; charset=utf-8"> +<title>Beautiful Soup: We called him Tortoise because he taught us.</title> +<link rev="made" href="mailto:[email protected]"> +<meta name="Description" content="Beautiful Soup: an HTML parser optimized for screen-scraping."> +<meta name="generator" content="Markov Approximation 1.4 (module: leonardr)"> +<meta name="author" content="Leonard Richardson"> +</head> +<body> +<a href="foo">foo</a> +<a href="foo"><b>bar</b></a> +</body> +</html>""" + self.tree = self.soup(self.page) + + def test_pickle_and_unpickle_identity(self): + # Pickling a tree, then unpickling it, yields a tree identical + # to the original. + dumped = pickle.dumps(self.tree, 2) + loaded = pickle.loads(dumped) + self.assertEqual(loaded.__class__, BeautifulSoup) + self.assertEqual(loaded.decode(), self.tree.decode()) + + def test_deepcopy_identity(self): + # Making a deepcopy of a tree yields an identical tree. + copied = copy.deepcopy(self.tree) + self.assertEqual(copied.decode(), self.tree.decode()) + + def test_copy_preserves_encoding(self): + soup = BeautifulSoup(b'<p> </p>', 'html.parser') + encoding = soup.original_encoding + copy = soup.__copy__() + self.assertEqual("<p> </p>", str(copy)) + self.assertEqual(encoding, copy.original_encoding) + + def test_unicode_pickle(self): + # A tree containing Unicode characters can be pickled. 
+ html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + dumped = pickle.dumps(soup, pickle.HIGHEST_PROTOCOL) + loaded = pickle.loads(dumped) + self.assertEqual(loaded.decode(), soup.decode()) + + def test_copy_navigablestring_is_not_attached_to_tree(self): + html = "<b>Foo<a></a></b><b>Bar</b>" + soup = self.soup(html) + s1 = soup.find(string="Foo") + s2 = copy.copy(s1) + self.assertEqual(s1, s2) + self.assertEqual(None, s2.parent) + self.assertEqual(None, s2.next_element) + self.assertNotEqual(None, s1.next_sibling) + self.assertEqual(None, s2.next_sibling) + self.assertEqual(None, s2.previous_element) + + def test_copy_navigablestring_subclass_has_same_type(self): + html = "<b><!--Foo--></b>" + soup = self.soup(html) + s1 = soup.string + s2 = copy.copy(s1) + self.assertEqual(s1, s2) + self.assertTrue(isinstance(s2, Comment)) + + def test_copy_entire_soup(self): + html = "<div><b>Foo<a></a></b><b>Bar</b></div>end" + soup = self.soup(html) + soup_copy = copy.copy(soup) + self.assertEqual(soup, soup_copy) + + def test_copy_tag_copies_contents(self): + html = "<div><b>Foo<a></a></b><b>Bar</b></div>end" + soup = self.soup(html) + div = soup.div + div_copy = copy.copy(div) + + # The two tags look the same, and evaluate to equal. + self.assertEqual(str(div), str(div_copy)) + self.assertEqual(div, div_copy) + + # But they're not the same object. + self.assertFalse(div is div_copy) + + # And they don't have the same relation to the parse tree. The + # copy is not associated with a parse tree at all. 
+ self.assertEqual(None, div_copy.parent) + self.assertEqual(None, div_copy.previous_element) + self.assertEqual(None, div_copy.find(string='Bar').next_element) + self.assertNotEqual(None, div.find(string='Bar').next_element) + +class TestSubstitutions(SoupTest): + + def test_default_formatter_is_minimal(self): + markup = "<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>" + soup = self.soup(markup) + decoded = soup.decode(formatter="minimal") + # The < is converted back into < but the e-with-acute is left alone. + self.assertEqual( + decoded, + self.document_for( + "<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>")) + + def test_formatter_html(self): + markup = "<br><b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>" + soup = self.soup(markup) + decoded = soup.decode(formatter="html") + self.assertEqual( + decoded, + self.document_for("<br/><b><<Sacré bleu!>></b>")) + + def test_formatter_html5(self): + markup = "<br><b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>" + soup = self.soup(markup) + decoded = soup.decode(formatter="html5") + self.assertEqual( + decoded, + self.document_for("<br><b><<Sacré bleu!>></b>")) + + def test_formatter_minimal(self): + markup = "<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>" + soup = self.soup(markup) + decoded = soup.decode(formatter="minimal") + # The < is converted back into < but the e-with-acute is left alone. + self.assertEqual( + decoded, + self.document_for( + "<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>")) + + def test_formatter_null(self): + markup = "<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>" + soup = self.soup(markup) + decoded = soup.decode(formatter=None) + # Neither the angle brackets nor the e-with-acute are converted. + # This is not valid HTML, but it's what the user wanted. 
+ self.assertEqual(decoded, + self.document_for("<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>")) + + def test_formatter_custom(self): + markup = "<b><foo></b><b>bar</b><br/>" + soup = self.soup(markup) + decoded = soup.decode(formatter = lambda x: x.upper()) + # Instead of normal entity conversion code, the custom + # callable is called on every string. + self.assertEqual( + decoded, + self.document_for("<b><FOO></b><b>BAR</b><br/>")) + + def test_formatter_is_run_on_attribute_values(self): + markup = '<a href="http://a.com?a=b&c=é">e</a>' + soup = self.soup(markup) + a = soup.a + + expect_minimal = '<a href="http://a.com?a=b&c=é">e</a>' + + self.assertEqual(expect_minimal, a.decode()) + self.assertEqual(expect_minimal, a.decode(formatter="minimal")) + + expect_html = '<a href="http://a.com?a=b&c=é">e</a>' + self.assertEqual(expect_html, a.decode(formatter="html")) + + self.assertEqual(markup, a.decode(formatter=None)) + expect_upper = '<a href="HTTP://A.COM?A=B&C=É">E</a>' + self.assertEqual(expect_upper, a.decode(formatter=lambda x: x.upper())) + + def test_formatter_skips_script_tag_for_html_documents(self): + doc = """ + <script type="text/javascript"> + console.log("< < hey > > "); + </script> +""" + encoded = BeautifulSoup(doc, 'html.parser').encode() + self.assertTrue(b"< < hey > >" in encoded) + + def test_formatter_skips_style_tag_for_html_documents(self): + doc = """ + <style type="text/css"> + console.log("< < hey > > "); + </style> +""" + encoded = BeautifulSoup(doc, 'html.parser').encode() + self.assertTrue(b"< < hey > >" in encoded) + + def test_prettify_leaves_preformatted_text_alone(self): + soup = self.soup("<div> foo <pre> \tbar\n \n </pre> baz <textarea> eee\nfff\t</textarea></div>") + # Everything outside the <pre> tag is reformatted, but everything + # inside is left alone. 
+ self.assertEqual( + '<div>\n foo\n <pre> \tbar\n \n </pre>\n baz\n <textarea> eee\nfff\t</textarea>\n</div>', + soup.div.prettify()) + + def test_prettify_accepts_formatter_function(self): + soup = BeautifulSoup("<html><body>foo</body></html>", 'html.parser') + pretty = soup.prettify(formatter = lambda x: x.upper()) + self.assertTrue("FOO" in pretty) + + def test_prettify_outputs_unicode_by_default(self): + soup = self.soup("<a></a>") + self.assertEqual(str, type(soup.prettify())) + + def test_prettify_can_encode_data(self): + soup = self.soup("<a></a>") + self.assertEqual(bytes, type(soup.prettify("utf-8"))) + + def test_html_entity_substitution_off_by_default(self): + markup = "<b>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</b>" + soup = self.soup(markup) + encoded = soup.b.encode("utf-8") + self.assertEqual(encoded, markup.encode('utf-8')) + + def test_encoding_substitution(self): + # Here's the <meta> tag saying that a document is + # encoded in Shift-JIS. + meta_tag = ('<meta content="text/html; charset=x-sjis" ' + 'http-equiv="Content-type"/>') + soup = self.soup(meta_tag) + + # Parse the document, and the charset apprears unchanged. + self.assertEqual(soup.meta['content'], 'text/html; charset=x-sjis') + + # Encode the document into some encoding, and the encoding is + # substituted into the meta tag. + utf_8 = soup.encode("utf-8") + self.assertTrue(b"charset=utf-8" in utf_8) + + euc_jp = soup.encode("euc_jp") + self.assertTrue(b"charset=euc_jp" in euc_jp) + + shift_jis = soup.encode("shift-jis") + self.assertTrue(b"charset=shift-jis" in shift_jis) + + utf_16_u = soup.encode("utf-16").decode("utf-16") + self.assertTrue("charset=utf-16" in utf_16_u) + + def test_encoding_substitution_doesnt_happen_if_tag_is_strained(self): + markup = ('<head><meta content="text/html; charset=x-sjis" ' + 'http-equiv="Content-type"/></head><pre>foo</pre>') + + # Beautiful Soup used to try to rewrite the meta tag even if the + # meta tag got filtered out by the strainer. 
This test makes + # sure that doesn't happen. + strainer = SoupStrainer('pre') + soup = self.soup(markup, parse_only=strainer) + self.assertEqual(soup.contents[0].name, 'pre') + +class TestEncoding(SoupTest): + """Test the ability to encode objects into strings.""" + + def test_unicode_string_can_be_encoded(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + self.assertEqual(soup.b.string.encode("utf-8"), + "\N{SNOWMAN}".encode("utf-8")) + + def test_tag_containing_unicode_string_can_be_encoded(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + self.assertEqual( + soup.b.encode("utf-8"), html.encode("utf-8")) + + def test_encoding_substitutes_unrecognized_characters_by_default(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + self.assertEqual(soup.b.encode("ascii"), b"<b>☃</b>") + + def test_encoding_can_be_made_strict(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + self.assertRaises( + UnicodeEncodeError, soup.encode, "ascii", errors="strict") + + def test_decode_contents(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + self.assertEqual("\N{SNOWMAN}", soup.b.decode_contents()) + + def test_encode_contents(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + self.assertEqual( + "\N{SNOWMAN}".encode("utf8"), soup.b.encode_contents( + encoding="utf8")) + + def test_deprecated_renderContents(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + self.assertEqual( + "\N{SNOWMAN}".encode("utf8"), soup.b.renderContents()) + + def test_repr(self): + html = "<b>\N{SNOWMAN}</b>" + soup = self.soup(html) + if PY3K: + self.assertEqual(html, repr(soup)) + else: + self.assertEqual(b'<b>\\u2603</b>', repr(soup)) + +class TestFormatter(SoupTest): + + def test_sort_attributes(self): + # Test the ability to override Formatter.attributes() to, + # e.g., disable the normal sorting of attributes. 
+ class UnsortedFormatter(Formatter): + def attributes(self, tag): + self.called_with = tag + for k, v in sorted(tag.attrs.items()): + if k == 'ignore': + continue + yield k,v + + soup = self.soup('<p cval="1" aval="2" ignore="ignored"></p>') + formatter = UnsortedFormatter() + decoded = soup.decode(formatter=formatter) + + # attributes() was called on the <p> tag. It filtered out one + # attribute and sorted the other two. + self.assertEqual(formatter.called_with, soup.p) + self.assertEqual('<p aval="2" cval="1"></p>', decoded) + + +class TestNavigableStringSubclasses(SoupTest): + + def test_cdata(self): + # None of the current builders turn CDATA sections into CData + # objects, but you can create them manually. + soup = self.soup("") + cdata = CData("foo") + soup.insert(1, cdata) + self.assertEqual(str(soup), "<![CDATA[foo]]>") + self.assertEqual(soup.find(text="foo"), "foo") + self.assertEqual(soup.contents[0], "foo") + + def test_cdata_is_never_formatted(self): + """Text inside a CData object is passed into the formatter. + + But the return value is ignored. + """ + + self.count = 0 + def increment(*args): + self.count += 1 + return "BITTER FAILURE" + + soup = self.soup("") + cdata = CData("<><><>") + soup.insert(1, cdata) + self.assertEqual( + b"<![CDATA[<><><>]]>", soup.encode(formatter=increment)) + self.assertEqual(1, self.count) + + def test_doctype_ends_in_newline(self): + # Unlike other NavigableString subclasses, a DOCTYPE always ends + # in a newline. 
+ doctype = Doctype("foo") + soup = self.soup("") + soup.insert(1, doctype) + self.assertEqual(soup.encode(), b"<!DOCTYPE foo>\n") + + def test_declaration(self): + d = Declaration("foo") + self.assertEqual("<?foo?>", d.output_ready()) + +class TestSoupSelector(TreeTest): + + HTML = """ +<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01//EN" +"http://www.w3.org/TR/html4/strict.dtd"> +<html> +<head> +<title>The title</title> +<link rel="stylesheet" href="blah.css" type="text/css" id="l1"> +</head> +<body> +<custom-dashed-tag class="dashed" id="dash1">Hello there.</custom-dashed-tag> +<div id="main" class="fancy"> +<div id="inner"> +<h1 id="header1">An H1</h1> +<p>Some text</p> +<p class="onep" id="p1">Some more text</p> +<h2 id="header2">An H2</h2> +<p class="class1 class2 class3" id="pmulti">Another</p> +<a href="http://bob.example.org/" rel="friend met" id="bob">Bob</a> +<h2 id="header3">Another H2</h2> +<a id="me" href="http://simonwillison.net/" rel="me">me</a> +<span class="s1"> +<a href="#" id="s1a1">span1a1</a> +<a href="#" id="s1a2">span1a2 <span id="s1a2s1">test</span></a> +<span class="span2"> +<a href="#" id="s2a1">span2a1</a> +</span> +<span class="span3"></span> +<custom-dashed-tag class="dashed" id="dash2"/> +<div data-tag="dashedvalue" id="data1"/> +</span> +</div> +<x id="xid"> +<z id="zida"/> +<z id="zidab"/> +<z id="zidac"/> +</x> +<y id="yid"> +<z id="zidb"/> +</y> +<p lang="en" id="lang-en">English</p> +<p lang="en-gb" id="lang-en-gb">English UK</p> +<p lang="en-us" id="lang-en-us">English US</p> +<p lang="fr" id="lang-fr">French</p> +</div> + +<div id="footer"> +</div> +""" + + def setUp(self): + self.soup = BeautifulSoup(self.HTML, 'html.parser') + + def assertSelects(self, selector, expected_ids, **kwargs): + el_ids = [el['id'] for el in self.soup.select(selector, **kwargs)] + el_ids.sort() + expected_ids.sort() + self.assertEqual(expected_ids, el_ids, + "Selector %s, expected [%s], got [%s]" % ( + selector, ', '.join(expected_ids), ', 
'.join(el_ids) + ) + ) + + assertSelect = assertSelects + + def assertSelectMultiple(self, *tests): + for selector, expected_ids in tests: + self.assertSelect(selector, expected_ids) + + def test_one_tag_one(self): + els = self.soup.select('title') + self.assertEqual(len(els), 1) + self.assertEqual(els[0].name, 'title') + self.assertEqual(els[0].contents, ['The title']) + + def test_one_tag_many(self): + els = self.soup.select('div') + self.assertEqual(len(els), 4) + for div in els: + self.assertEqual(div.name, 'div') + + el = self.soup.select_one('div') + self.assertEqual('main', el['id']) + + def test_select_one_returns_none_if_no_match(self): + match = self.soup.select_one('nonexistenttag') + self.assertEqual(None, match) + + + def test_tag_in_tag_one(self): + els = self.soup.select('div div') + self.assertSelects('div div', ['inner', 'data1']) + + def test_tag_in_tag_many(self): + for selector in ('html div', 'html body div', 'body div'): + self.assertSelects(selector, ['data1', 'main', 'inner', 'footer']) + + + def test_limit(self): + self.assertSelects('html div', ['main'], limit=1) + self.assertSelects('html body div', ['inner', 'main'], limit=2) + self.assertSelects('body div', ['data1', 'main', 'inner', 'footer'], + limit=10) + + def test_tag_no_match(self): + self.assertEqual(len(self.soup.select('del')), 0) + + def test_invalid_tag(self): + self.assertRaises(SyntaxError, self.soup.select, 'tag%t') + + def test_select_dashed_tag_ids(self): + self.assertSelects('custom-dashed-tag', ['dash1', 'dash2']) + + def test_select_dashed_by_id(self): + dashed = self.soup.select('custom-dashed-tag[id=\"dash2\"]') + self.assertEqual(dashed[0].name, 'custom-dashed-tag') + self.assertEqual(dashed[0]['id'], 'dash2') + + def test_dashed_tag_text(self): + self.assertEqual(self.soup.select('body > custom-dashed-tag')[0].text, 'Hello there.') + + def test_select_dashed_matches_find_all(self): + self.assertEqual(self.soup.select('custom-dashed-tag'), 
self.soup.find_all('custom-dashed-tag')) + + def test_header_tags(self): + self.assertSelectMultiple( + ('h1', ['header1']), + ('h2', ['header2', 'header3']), + ) + + def test_class_one(self): + for selector in ('.onep', 'p.onep', 'html p.onep'): + els = self.soup.select(selector) + self.assertEqual(len(els), 1) + self.assertEqual(els[0].name, 'p') + self.assertEqual(els[0]['class'], ['onep']) + + def test_class_mismatched_tag(self): + els = self.soup.select('div.onep') + self.assertEqual(len(els), 0) + + def test_one_id(self): + for selector in ('div#inner', '#inner', 'div div#inner'): + self.assertSelects(selector, ['inner']) + + def test_bad_id(self): + els = self.soup.select('#doesnotexist') + self.assertEqual(len(els), 0) + + def test_items_in_id(self): + els = self.soup.select('div#inner p') + self.assertEqual(len(els), 3) + for el in els: + self.assertEqual(el.name, 'p') + self.assertEqual(els[1]['class'], ['onep']) + self.assertFalse(els[0].has_attr('class')) + + def test_a_bunch_of_emptys(self): + for selector in ('div#main del', 'div#main div.oops', 'div div#main'): + self.assertEqual(len(self.soup.select(selector)), 0) + + def test_multi_class_support(self): + for selector in ('.class1', 'p.class1', '.class2', 'p.class2', + '.class3', 'p.class3', 'html p.class2', 'div#inner .class2'): + self.assertSelects(selector, ['pmulti']) + + def test_multi_class_selection(self): + for selector in ('.class1.class3', '.class3.class2', + '.class1.class2.class3'): + self.assertSelects(selector, ['pmulti']) + + def test_child_selector(self): + self.assertSelects('.s1 > a', ['s1a1', 's1a2']) + self.assertSelects('.s1 > a span', ['s1a2s1']) + + def test_child_selector_id(self): + self.assertSelects('.s1 > a#s1a2 span', ['s1a2s1']) + + def test_attribute_equals(self): + self.assertSelectMultiple( + ('p[class="onep"]', ['p1']), + ('p[id="p1"]', ['p1']), + ('[class="onep"]', ['p1']), + ('[id="p1"]', ['p1']), + ('link[rel="stylesheet"]', ['l1']), + ('link[type="text/css"]', 
['l1']), + ('link[href="blah.css"]', ['l1']), + ('link[href="no-blah.css"]', []), + ('[rel="stylesheet"]', ['l1']), + ('[type="text/css"]', ['l1']), + ('[href="blah.css"]', ['l1']), + ('[href="no-blah.css"]', []), + ('p[href="no-blah.css"]', []), + ('[href="no-blah.css"]', []), + ) + + def test_attribute_tilde(self): + self.assertSelectMultiple( + ('p[class~="class1"]', ['pmulti']), + ('p[class~="class2"]', ['pmulti']), + ('p[class~="class3"]', ['pmulti']), + ('[class~="class1"]', ['pmulti']), + ('[class~="class2"]', ['pmulti']), + ('[class~="class3"]', ['pmulti']), + ('a[rel~="friend"]', ['bob']), + ('a[rel~="met"]', ['bob']), + ('[rel~="friend"]', ['bob']), + ('[rel~="met"]', ['bob']), + ) + + def test_attribute_startswith(self): + self.assertSelectMultiple( + ('[rel^="style"]', ['l1']), + ('link[rel^="style"]', ['l1']), + ('notlink[rel^="notstyle"]', []), + ('[rel^="notstyle"]', []), + ('link[rel^="notstyle"]', []), + ('link[href^="bla"]', ['l1']), + ('a[href^="http://"]', ['bob', 'me']), + ('[href^="http://"]', ['bob', 'me']), + ('[id^="p"]', ['pmulti', 'p1']), + ('[id^="m"]', ['me', 'main']), + ('div[id^="m"]', ['main']), + ('a[id^="m"]', ['me']), + ('div[data-tag^="dashed"]', ['data1']) + ) + + def test_attribute_endswith(self): + self.assertSelectMultiple( + ('[href$=".css"]', ['l1']), + ('link[href$=".css"]', ['l1']), + ('link[id$="1"]', ['l1']), + ('[id$="1"]', ['data1', 'l1', 'p1', 'header1', 's1a1', 's2a1', 's1a2s1', 'dash1']), + ('div[id$="1"]', ['data1']), + ('[id$="noending"]', []), + ) + + def test_attribute_contains(self): + self.assertSelectMultiple( + # From test_attribute_startswith + ('[rel*="style"]', ['l1']), + ('link[rel*="style"]', ['l1']), + ('notlink[rel*="notstyle"]', []), + ('[rel*="notstyle"]', []), + ('link[rel*="notstyle"]', []), + ('link[href*="bla"]', ['l1']), + ('[href*="http://"]', ['bob', 'me']), + ('[id*="p"]', ['pmulti', 'p1']), + ('div[id*="m"]', ['main']), + ('a[id*="m"]', ['me']), + # From test_attribute_endswith + 
('[href*=".css"]', ['l1']), + ('link[href*=".css"]', ['l1']), + ('link[id*="1"]', ['l1']), + ('[id*="1"]', ['data1', 'l1', 'p1', 'header1', 's1a1', 's1a2', 's2a1', 's1a2s1', 'dash1']), + ('div[id*="1"]', ['data1']), + ('[id*="noending"]', []), + # New for this test + ('[href*="."]', ['bob', 'me', 'l1']), + ('a[href*="."]', ['bob', 'me']), + ('link[href*="."]', ['l1']), + ('div[id*="n"]', ['main', 'inner']), + ('div[id*="nn"]', ['inner']), + ('div[data-tag*="edval"]', ['data1']) + ) + + def test_attribute_exact_or_hypen(self): + self.assertSelectMultiple( + ('p[lang|="en"]', ['lang-en', 'lang-en-gb', 'lang-en-us']), + ('[lang|="en"]', ['lang-en', 'lang-en-gb', 'lang-en-us']), + ('p[lang|="fr"]', ['lang-fr']), + ('p[lang|="gb"]', []), + ) + + def test_attribute_exists(self): + self.assertSelectMultiple( + ('[rel]', ['l1', 'bob', 'me']), + ('link[rel]', ['l1']), + ('a[rel]', ['bob', 'me']), + ('[lang]', ['lang-en', 'lang-en-gb', 'lang-en-us', 'lang-fr']), + ('p[class]', ['p1', 'pmulti']), + ('[blah]', []), + ('p[blah]', []), + ('div[data-tag]', ['data1']) + ) + + def test_quoted_space_in_selector_name(self): + html = """<div style="display: wrong">nope</div> + <div style="display: right">yes</div> + """ + soup = BeautifulSoup(html, 'html.parser') + [chosen] = soup.select('div[style="display: right"]') + self.assertEqual("yes", chosen.string) + + def test_unsupported_pseudoclass(self): + self.assertRaises( + NotImplementedError, self.soup.select, "a:no-such-pseudoclass") + + self.assertRaises( + SyntaxError, self.soup.select, "a:nth-of-type(a)") + + def test_nth_of_type(self): + # Try to select first paragraph + els = self.soup.select('div#inner p:nth-of-type(1)') + self.assertEqual(len(els), 1) + self.assertEqual(els[0].string, 'Some text') + + # Try to select third paragraph + els = self.soup.select('div#inner p:nth-of-type(3)') + self.assertEqual(len(els), 1) + self.assertEqual(els[0].string, 'Another') + + # Try to select (non-existent!) 
fourth paragraph + els = self.soup.select('div#inner p:nth-of-type(4)') + self.assertEqual(len(els), 0) + + # Zero will select no tags. + els = self.soup.select('div p:nth-of-type(0)') + self.assertEqual(len(els), 0) + + def test_nth_of_type_direct_descendant(self): + els = self.soup.select('div#inner > p:nth-of-type(1)') + self.assertEqual(len(els), 1) + self.assertEqual(els[0].string, 'Some text') + + def test_id_child_selector_nth_of_type(self): + self.assertSelects('#inner > p:nth-of-type(2)', ['p1']) + + def test_select_on_element(self): + # Other tests operate on the tree; this operates on an element + # within the tree. + inner = self.soup.find("div", id="main") + selected = inner.select("div") + # The <div id="inner"> tag was selected. The <div id="footer"> + # tag was not. + self.assertSelectsIDs(selected, ['inner', 'data1']) + + def test_overspecified_child_id(self): + self.assertSelects(".fancy #inner", ['inner']) + self.assertSelects(".normal #inner", []) + + def test_adjacent_sibling_selector(self): + self.assertSelects('#p1 + h2', ['header2']) + self.assertSelects('#p1 + h2 + p', ['pmulti']) + self.assertSelects('#p1 + #header2 + .class1', ['pmulti']) + self.assertEqual([], self.soup.select('#p1 + p')) + + def test_general_sibling_selector(self): + self.assertSelects('#p1 ~ h2', ['header2', 'header3']) + self.assertSelects('#p1 ~ #header2', ['header2']) + self.assertSelects('#p1 ~ h2 + a', ['me']) + self.assertSelects('#p1 ~ h2 + [rel="me"]', ['me']) + self.assertEqual([], self.soup.select('#inner ~ h2')) + + def test_dangling_combinator(self): + self.assertRaises(SyntaxError, self.soup.select, 'h1 >') + + def test_sibling_combinator_wont_select_same_tag_twice(self): + self.assertSelects('p[lang] ~ p', ['lang-en-gb', 'lang-en-us', 'lang-fr']) + + # Test the selector grouping operator (the comma) + def test_multiple_select(self): + self.assertSelects('x, y', ['xid', 'yid']) + + def test_multiple_select_with_no_space(self): + self.assertSelects('x,y', 
['xid', 'yid']) + + def test_multiple_select_with_more_space(self): + self.assertSelects('x, y', ['xid', 'yid']) + + def test_multiple_select_duplicated(self): + self.assertSelects('x, x', ['xid']) + + def test_multiple_select_sibling(self): + self.assertSelects('x, y ~ p[lang=fr]', ['xid', 'lang-fr']) + + def test_multiple_select_tag_and_direct_descendant(self): + self.assertSelects('x, y > z', ['xid', 'zidb']) + + def test_multiple_select_direct_descendant_and_tags(self): + self.assertSelects('div > x, y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac']) + + def test_multiple_select_indirect_descendant(self): + self.assertSelects('div x,y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac']) + + def test_invalid_multiple_select(self): + self.assertRaises(SyntaxError, self.soup.select, ',x, y') + self.assertRaises(SyntaxError, self.soup.select, 'x,,y') + + def test_multiple_select_attrs(self): + self.assertSelects('p[lang=en], p[lang=en-gb]', ['lang-en', 'lang-en-gb']) + + def test_multiple_select_ids(self): + self.assertSelects('x, y > z[id=zida], z[id=zidab], z[id=zidb]', ['xid', 'zidb', 'zidab']) + + def test_multiple_select_nested(self): + self.assertSelects('body > div > x, y > z', ['xid', 'zidb']) + + def test_select_duplicate_elements(self): + # When markup contains duplicate elements, a multiple select + # will find all of them. + markup = '<div class="c1"/><div class="c2"/><div class="c1"/>' + soup = BeautifulSoup(markup, 'html.parser') + selected = soup.select(".c1, .c2") + self.assertEqual(3, len(selected)) + + # Verify that find_all finds the same elements, though because + # of an implementation detail it finds them in a different + # order. 
+ for element in soup.find_all(class_=['c1', 'c2']): + assert element in selected diff --git a/libs/engineio/__init__.py b/libs/engineio/__init__.py new file mode 100644 index 000000000..f2c5b774c --- /dev/null +++ b/libs/engineio/__init__.py @@ -0,0 +1,25 @@ +import sys + +from .client import Client +from .middleware import WSGIApp, Middleware +from .server import Server +if sys.version_info >= (3, 5): # pragma: no cover + from .asyncio_server import AsyncServer + from .asyncio_client import AsyncClient + from .async_drivers.asgi import ASGIApp + try: + from .async_drivers.tornado import get_tornado_handler + except ImportError: + get_tornado_handler = None +else: # pragma: no cover + AsyncServer = None + AsyncClient = None + get_tornado_handler = None + ASGIApp = None + +__version__ = '3.11.2' + +__all__ = ['__version__', 'Server', 'WSGIApp', 'Middleware', 'Client'] +if AsyncServer is not None: # pragma: no cover + __all__ += ['AsyncServer', 'ASGIApp', 'get_tornado_handler', + 'AsyncClient'], diff --git a/libs/engineio/async_drivers/__init__.py b/libs/engineio/async_drivers/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/libs/engineio/async_drivers/__init__.py diff --git a/libs/engineio/async_drivers/aiohttp.py b/libs/engineio/async_drivers/aiohttp.py new file mode 100644 index 000000000..ad6987649 --- /dev/null +++ b/libs/engineio/async_drivers/aiohttp.py @@ -0,0 +1,128 @@ +import asyncio +import sys +from urllib.parse import urlsplit + +from aiohttp.web import Response, WebSocketResponse +import six + + +def create_route(app, engineio_server, engineio_endpoint): + """This function sets up the engine.io endpoint as a route for the + application. + + Note that both GET and POST requests must be hooked up on the engine.io + endpoint. 
+ """ + app.router.add_get(engineio_endpoint, engineio_server.handle_request) + app.router.add_post(engineio_endpoint, engineio_server.handle_request) + app.router.add_route('OPTIONS', engineio_endpoint, + engineio_server.handle_request) + + +def translate_request(request): + """This function takes the arguments passed to the request handler and + uses them to generate a WSGI compatible environ dictionary. + """ + message = request._message + payload = request._payload + + uri_parts = urlsplit(message.path) + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': 'aiohttp', + 'REQUEST_METHOD': message.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': message.path, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version, + 'REMOTE_ADDR': '127.0.0.1', + 'REMOTE_PORT': '0', + 'SERVER_NAME': 'aiohttp', + 'SERVER_PORT': '0', + 'aiohttp.request': request + } + + for hdr_name, hdr_value in message.headers.items(): + hdr_name = hdr_name.upper() + if hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http') + + path_info = uri_parts.path + + environ['PATH_INFO'] = path_info + environ['SCRIPT_NAME'] = '' + + return environ + + +def make_response(status, headers, payload, environ): + """This function generates an appropriate response object for this async + mode. 
+ """ + return Response(body=payload, status=int(status.split()[0]), + headers=headers) + + +class WebSocket(object): # pragma: no cover + """ + This wrapper class provides a aiohttp WebSocket interface that is + somewhat compatible with eventlet's implementation. + """ + def __init__(self, handler): + self.handler = handler + self._sock = None + + async def __call__(self, environ): + request = environ['aiohttp.request'] + self._sock = WebSocketResponse() + await self._sock.prepare(request) + + self.environ = environ + await self.handler(self) + return self._sock + + async def close(self): + await self._sock.close() + + async def send(self, message): + if isinstance(message, bytes): + f = self._sock.send_bytes + else: + f = self._sock.send_str + if asyncio.iscoroutinefunction(f): + await f(message) + else: + f(message) + + async def wait(self): + msg = await self._sock.receive() + if not isinstance(msg.data, six.binary_type) and \ + not isinstance(msg.data, six.text_type): + raise IOError() + return msg.data + + +_async = { + 'asyncio': True, + 'create_route': create_route, + 'translate_request': translate_request, + 'make_response': make_response, + 'websocket': WebSocket, +} diff --git a/libs/engineio/async_drivers/asgi.py b/libs/engineio/async_drivers/asgi.py new file mode 100644 index 000000000..9f14ef05f --- /dev/null +++ b/libs/engineio/async_drivers/asgi.py @@ -0,0 +1,214 @@ +import os +import sys + +from engineio.static_files import get_static_file + + +class ASGIApp: + """ASGI application middleware for Engine.IO. + + This middleware dispatches traffic to an Engine.IO application. It can + also serve a list of static files to the client, or forward unrelated + HTTP traffic to another ASGI application. + + :param engineio_server: The Engine.IO server. Must be an instance of the + ``engineio.AsyncServer`` class. + :param static_files: A dictionary with static file mapping rules. See the + documentation for details on this argument. 
+ :param other_asgi_app: A separate ASGI app that receives all other traffic. + :param engineio_path: The endpoint where the Engine.IO application should + be installed. The default value is appropriate for + most cases. + + Example usage:: + + import engineio + import uvicorn + + eio = engineio.AsyncServer() + app = engineio.ASGIApp(eio, static_files={ + '/': {'content_type': 'text/html', 'filename': 'index.html'}, + '/index.html': {'content_type': 'text/html', + 'filename': 'index.html'}, + }) + uvicorn.run(app, '127.0.0.1', 5000) + """ + def __init__(self, engineio_server, other_asgi_app=None, + static_files=None, engineio_path='engine.io'): + self.engineio_server = engineio_server + self.other_asgi_app = other_asgi_app + self.engineio_path = engineio_path.strip('/') + self.static_files = static_files or {} + + async def __call__(self, scope, receive, send): + if scope['type'] in ['http', 'websocket'] and \ + scope['path'].startswith('/{0}/'.format(self.engineio_path)): + await self.engineio_server.handle_request(scope, receive, send) + else: + static_file = get_static_file(scope['path'], self.static_files) \ + if scope['type'] == 'http' and self.static_files else None + if static_file: + await self.serve_static_file(static_file, receive, send) + elif self.other_asgi_app is not None: + await self.other_asgi_app(scope, receive, send) + elif scope['type'] == 'lifespan': + await self.lifespan(receive, send) + else: + await self.not_found(receive, send) + + async def serve_static_file(self, static_file, receive, + send): # pragma: no cover + event = await receive() + if event['type'] == 'http.request': + if os.path.exists(static_file['filename']): + with open(static_file['filename'], 'rb') as f: + payload = f.read() + await send({'type': 'http.response.start', + 'status': 200, + 'headers': [(b'Content-Type', static_file[ + 'content_type'].encode('utf-8'))]}) + await send({'type': 'http.response.body', + 'body': payload}) + else: + await self.not_found(receive, send) 
+ + async def lifespan(self, receive, send): + event = await receive() + if event['type'] == 'lifespan.startup': + await send({'type': 'lifespan.startup.complete'}) + elif event['type'] == 'lifespan.shutdown': + await send({'type': 'lifespan.shutdown.complete'}) + + async def not_found(self, receive, send): + """Return a 404 Not Found error to the client.""" + await send({'type': 'http.response.start', + 'status': 404, + 'headers': [(b'Content-Type', b'text/plain')]}) + await send({'type': 'http.response.body', + 'body': b'Not Found'}) + + +async def translate_request(scope, receive, send): + class AwaitablePayload(object): # pragma: no cover + def __init__(self, payload): + self.payload = payload or b'' + + async def read(self, length=None): + if length is None: + r = self.payload + self.payload = b'' + else: + r = self.payload[:length] + self.payload = self.payload[length:] + return r + + event = await receive() + payload = b'' + if event['type'] == 'http.request': + payload += event.get('body') or b'' + while event.get('more_body'): + event = await receive() + if event['type'] == 'http.request': + payload += event.get('body') or b'' + elif event['type'] == 'websocket.connect': + await send({'type': 'websocket.accept'}) + else: + return {} + + raw_uri = scope['path'].encode('utf-8') + if 'query_string' in scope and scope['query_string']: + raw_uri += b'?' 
+ scope['query_string'] + environ = { + 'wsgi.input': AwaitablePayload(payload), + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': 'asgi', + 'REQUEST_METHOD': scope.get('method', 'GET'), + 'PATH_INFO': scope['path'], + 'QUERY_STRING': scope.get('query_string', b'').decode('utf-8'), + 'RAW_URI': raw_uri.decode('utf-8'), + 'SCRIPT_NAME': '', + 'SERVER_PROTOCOL': 'HTTP/1.1', + 'REMOTE_ADDR': '127.0.0.1', + 'REMOTE_PORT': '0', + 'SERVER_NAME': 'asgi', + 'SERVER_PORT': '0', + 'asgi.receive': receive, + 'asgi.send': send, + } + + for hdr_name, hdr_value in scope['headers']: + hdr_name = hdr_name.upper().decode('utf-8') + hdr_value = hdr_value.decode('utf-8') + if hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http') + return environ + + +async def make_response(status, headers, payload, environ): + headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers] + await environ['asgi.send']({'type': 'http.response.start', + 'status': int(status.split(' ')[0]), + 'headers': headers}) + await environ['asgi.send']({'type': 'http.response.body', + 'body': payload}) + + +class WebSocket(object): # pragma: no cover + """ + This wrapper class provides an asgi WebSocket interface that is + somewhat compatible with eventlet's implementation. 
+ """ + def __init__(self, handler): + self.handler = handler + self.asgi_receive = None + self.asgi_send = None + + async def __call__(self, environ): + self.asgi_receive = environ['asgi.receive'] + self.asgi_send = environ['asgi.send'] + await self.handler(self) + + async def close(self): + await self.asgi_send({'type': 'websocket.close'}) + + async def send(self, message): + msg_bytes = None + msg_text = None + if isinstance(message, bytes): + msg_bytes = message + else: + msg_text = message + await self.asgi_send({'type': 'websocket.send', + 'bytes': msg_bytes, + 'text': msg_text}) + + async def wait(self): + event = await self.asgi_receive() + if event['type'] != 'websocket.receive': + raise IOError() + return event.get('bytes') or event.get('text') + + +_async = { + 'asyncio': True, + 'translate_request': translate_request, + 'make_response': make_response, + 'websocket': WebSocket, +} diff --git a/libs/engineio/async_drivers/eventlet.py b/libs/engineio/async_drivers/eventlet.py new file mode 100644 index 000000000..9be3797cd --- /dev/null +++ b/libs/engineio/async_drivers/eventlet.py @@ -0,0 +1,30 @@ +from __future__ import absolute_import + +from eventlet.green.threading import Thread, Event +from eventlet import queue +from eventlet import sleep +from eventlet.websocket import WebSocketWSGI as _WebSocketWSGI + + +class WebSocketWSGI(_WebSocketWSGI): + def __init__(self, *args, **kwargs): + super(WebSocketWSGI, self).__init__(*args, **kwargs) + self._sock = None + + def __call__(self, environ, start_response): + if 'eventlet.input' not in environ: + raise RuntimeError('You need to use the eventlet server. 
' + 'See the Deployment section of the ' + 'documentation for more information.') + self._sock = environ['eventlet.input'].get_socket() + return super(WebSocketWSGI, self).__call__(environ, start_response) + + +_async = { + 'thread': Thread, + 'queue': queue.Queue, + 'queue_empty': queue.Empty, + 'event': Event, + 'websocket': WebSocketWSGI, + 'sleep': sleep, +} diff --git a/libs/engineio/async_drivers/gevent.py b/libs/engineio/async_drivers/gevent.py new file mode 100644 index 000000000..024dd0aad --- /dev/null +++ b/libs/engineio/async_drivers/gevent.py @@ -0,0 +1,63 @@ +from __future__ import absolute_import + +import gevent +from gevent import queue +from gevent.event import Event +try: + import geventwebsocket # noqa + _websocket_available = True +except ImportError: + _websocket_available = False + + +class Thread(gevent.Greenlet): # pragma: no cover + """ + This wrapper class provides gevent Greenlet interface that is compatible + with the standard library's Thread class. + """ + def __init__(self, target, args=[], kwargs={}): + super(Thread, self).__init__(target, *args, **kwargs) + + def _run(self): + return self.run() + + +class WebSocketWSGI(object): # pragma: no cover + """ + This wrapper class provides a gevent WebSocket interface that is + compatible with eventlet's implementation. + """ + def __init__(self, app): + self.app = app + + def __call__(self, environ, start_response): + if 'wsgi.websocket' not in environ: + raise RuntimeError('You need to use the gevent-websocket server. 
' + 'See the Deployment section of the ' + 'documentation for more information.') + self._sock = environ['wsgi.websocket'] + self.environ = environ + self.version = self._sock.version + self.path = self._sock.path + self.origin = self._sock.origin + self.protocol = self._sock.protocol + return self.app(self) + + def close(self): + return self._sock.close() + + def send(self, message): + return self._sock.send(message) + + def wait(self): + return self._sock.receive() + + +_async = { + 'thread': Thread, + 'queue': queue.JoinableQueue, + 'queue_empty': queue.Empty, + 'event': Event, + 'websocket': WebSocketWSGI if _websocket_available else None, + 'sleep': gevent.sleep, +} diff --git a/libs/engineio/async_drivers/gevent_uwsgi.py b/libs/engineio/async_drivers/gevent_uwsgi.py new file mode 100644 index 000000000..07fa2a79d --- /dev/null +++ b/libs/engineio/async_drivers/gevent_uwsgi.py @@ -0,0 +1,156 @@ +from __future__ import absolute_import + +import six + +import gevent +from gevent import queue +from gevent.event import Event +import uwsgi +_websocket_available = hasattr(uwsgi, 'websocket_handshake') + + +class Thread(gevent.Greenlet): # pragma: no cover + """ + This wrapper class provides gevent Greenlet interface that is compatible + with the standard library's Thread class. + """ + def __init__(self, target, args=[], kwargs={}): + super(Thread, self).__init__(target, *args, **kwargs) + + def _run(self): + return self.run() + + +class uWSGIWebSocket(object): # pragma: no cover + """ + This wrapper class provides a uWSGI WebSocket interface that is + compatible with eventlet's implementation. 
+ """ + def __init__(self, app): + self.app = app + self._sock = None + + def __call__(self, environ, start_response): + self._sock = uwsgi.connection_fd() + self.environ = environ + + uwsgi.websocket_handshake() + + self._req_ctx = None + if hasattr(uwsgi, 'request_context'): + # uWSGI >= 2.1.x with support for api access across-greenlets + self._req_ctx = uwsgi.request_context() + else: + # use event and queue for sending messages + from gevent.event import Event + from gevent.queue import Queue + from gevent.select import select + self._event = Event() + self._send_queue = Queue() + + # spawn a select greenlet + def select_greenlet_runner(fd, event): + """Sets event when data becomes available to read on fd.""" + while True: + event.set() + try: + select([fd], [], [])[0] + except ValueError: + break + self._select_greenlet = gevent.spawn( + select_greenlet_runner, + self._sock, + self._event) + + self.app(self) + + def close(self): + """Disconnects uWSGI from the client.""" + uwsgi.disconnect() + if self._req_ctx is None: + # better kill it here in case wait() is not called again + self._select_greenlet.kill() + self._event.set() + + def _send(self, msg): + """Transmits message either in binary or UTF-8 text mode, + depending on its type.""" + if isinstance(msg, six.binary_type): + method = uwsgi.websocket_send_binary + else: + method = uwsgi.websocket_send + if self._req_ctx is not None: + method(msg, request_context=self._req_ctx) + else: + method(msg) + + def _decode_received(self, msg): + """Returns either bytes or str, depending on message type.""" + if not isinstance(msg, six.binary_type): + # already decoded - do nothing + return msg + # only decode from utf-8 if message is not binary data + type = six.byte2int(msg[0:1]) + if type >= 48: # no binary + return msg.decode('utf-8') + # binary message, don't try to decode + return msg + + def send(self, msg): + """Queues a message for sending. Real transmission is done in + wait method. 
+ Sends directly if uWSGI version is new enough.""" + if self._req_ctx is not None: + self._send(msg) + else: + self._send_queue.put(msg) + self._event.set() + + def wait(self): + """Waits and returns received messages. + If running in compatibility mode for older uWSGI versions, + it also sends messages that have been queued by send(). + A return value of None means that connection was closed. + This must be called repeatedly. For uWSGI < 2.1.x it must + be called from the main greenlet.""" + while True: + if self._req_ctx is not None: + try: + msg = uwsgi.websocket_recv(request_context=self._req_ctx) + except IOError: # connection closed + return None + return self._decode_received(msg) + else: + # we wake up at least every 3 seconds to let uWSGI + # do its ping/ponging + event_set = self._event.wait(timeout=3) + if event_set: + self._event.clear() + # maybe there is something to send + msgs = [] + while True: + try: + msgs.append(self._send_queue.get(block=False)) + except gevent.queue.Empty: + break + for msg in msgs: + self._send(msg) + # maybe there is something to receive, if not, at least + # ensure uWSGI does its ping/ponging + try: + msg = uwsgi.websocket_recv_nb() + except IOError: # connection closed + self._select_greenlet.kill() + return None + if msg: # message available + return self._decode_received(msg) + + +_async = { + 'thread': Thread, + 'queue': queue.JoinableQueue, + 'queue_empty': queue.Empty, + 'event': Event, + 'websocket': uWSGIWebSocket if _websocket_available else None, + 'sleep': gevent.sleep, +} diff --git a/libs/engineio/async_drivers/sanic.py b/libs/engineio/async_drivers/sanic.py new file mode 100644 index 000000000..6929654b9 --- /dev/null +++ b/libs/engineio/async_drivers/sanic.py @@ -0,0 +1,144 @@ +import sys +from urllib.parse import urlsplit + +from sanic.response import HTTPResponse +try: + from sanic.websocket import WebSocketProtocol +except ImportError: + # the installed version of sanic does not have websocket support + 
WebSocketProtocol = None +import six + + +def create_route(app, engineio_server, engineio_endpoint): + """This function sets up the engine.io endpoint as a route for the + application. + + Note that both GET and POST requests must be hooked up on the engine.io + endpoint. + """ + app.add_route(engineio_server.handle_request, engineio_endpoint, + methods=['GET', 'POST', 'OPTIONS']) + try: + app.enable_websocket() + except AttributeError: + # ignore, this version does not support websocket + pass + + +def translate_request(request): + """This function takes the arguments passed to the request handler and + uses them to generate a WSGI compatible environ dictionary. + """ + class AwaitablePayload(object): + def __init__(self, payload): + self.payload = payload or b'' + + async def read(self, length=None): + if length is None: + r = self.payload + self.payload = b'' + else: + r = self.payload[:length] + self.payload = self.payload[length:] + return r + + uri_parts = urlsplit(request.url) + environ = { + 'wsgi.input': AwaitablePayload(request.body), + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': 'sanic', + 'REQUEST_METHOD': request.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': request.url, + 'SERVER_PROTOCOL': 'HTTP/' + request.version, + 'REMOTE_ADDR': '127.0.0.1', + 'REMOTE_PORT': '0', + 'SERVER_NAME': 'sanic', + 'SERVER_PORT': '0', + 'sanic.request': request + } + + for hdr_name, hdr_value in request.headers.items(): + hdr_name = hdr_name.upper() + if hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + environ['wsgi.url_scheme'] = 
environ.get('HTTP_X_FORWARDED_PROTO', 'http') + + path_info = uri_parts.path + + environ['PATH_INFO'] = path_info + environ['SCRIPT_NAME'] = '' + + return environ + + +def make_response(status, headers, payload, environ): + """This function generates an appropriate response object for this async + mode. + """ + headers_dict = {} + content_type = None + for h in headers: + if h[0].lower() == 'content-type': + content_type = h[1] + else: + headers_dict[h[0]] = h[1] + return HTTPResponse(body_bytes=payload, content_type=content_type, + status=int(status.split()[0]), headers=headers_dict) + + +class WebSocket(object): # pragma: no cover + """ + This wrapper class provides a sanic WebSocket interface that is + somewhat compatible with eventlet's implementation. + """ + def __init__(self, handler): + self.handler = handler + self._sock = None + + async def __call__(self, environ): + request = environ['sanic.request'] + protocol = request.transport.get_protocol() + self._sock = await protocol.websocket_handshake(request) + + self.environ = environ + await self.handler(self) + + async def close(self): + await self._sock.close() + + async def send(self, message): + await self._sock.send(message) + + async def wait(self): + data = await self._sock.recv() + if not isinstance(data, six.binary_type) and \ + not isinstance(data, six.text_type): + raise IOError() + return data + + +_async = { + 'asyncio': True, + 'create_route': create_route, + 'translate_request': translate_request, + 'make_response': make_response, + 'websocket': WebSocket if WebSocketProtocol else None, +} diff --git a/libs/engineio/async_drivers/threading.py b/libs/engineio/async_drivers/threading.py new file mode 100644 index 000000000..9b5375668 --- /dev/null +++ b/libs/engineio/async_drivers/threading.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import +import threading +import time + +try: + import queue +except ImportError: # pragma: no cover + import Queue as queue + +_async = { + 'thread': 
threading.Thread, + 'queue': queue.Queue, + 'queue_empty': queue.Empty, + 'event': threading.Event, + 'websocket': None, + 'sleep': time.sleep, +} diff --git a/libs/engineio/async_drivers/tornado.py b/libs/engineio/async_drivers/tornado.py new file mode 100644 index 000000000..adfe18f5a --- /dev/null +++ b/libs/engineio/async_drivers/tornado.py @@ -0,0 +1,184 @@ +import asyncio +import sys +from urllib.parse import urlsplit +from .. import exceptions + +import tornado.web +import tornado.websocket +import six + + +def get_tornado_handler(engineio_server): + class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if isinstance(engineio_server.cors_allowed_origins, + six.string_types): + if engineio_server.cors_allowed_origins == '*': + self.allowed_origins = None + else: + self.allowed_origins = [ + engineio_server.cors_allowed_origins] + else: + self.allowed_origins = engineio_server.cors_allowed_origins + self.receive_queue = asyncio.Queue() + + async def get(self, *args, **kwargs): + if self.request.headers.get('Upgrade', '').lower() == 'websocket': + ret = super().get(*args, **kwargs) + if asyncio.iscoroutine(ret): + await ret + else: + await engineio_server.handle_request(self) + + async def open(self, *args, **kwargs): + # this is the handler for the websocket request + asyncio.ensure_future(engineio_server.handle_request(self)) + + async def post(self, *args, **kwargs): + await engineio_server.handle_request(self) + + async def options(self, *args, **kwargs): + await engineio_server.handle_request(self) + + async def on_message(self, message): + await self.receive_queue.put(message) + + async def get_next_message(self): + return await self.receive_queue.get() + + def on_close(self): + self.receive_queue.put_nowait(None) + + def check_origin(self, origin): + if self.allowed_origins is None or origin in self.allowed_origins: + return True + return 
super().check_origin(origin) + + def get_compression_options(self): + # enable compression + return {} + + return Handler + + +def translate_request(handler): + """This function takes the arguments passed to the request handler and + uses them to generate a WSGI compatible environ dictionary. + """ + class AwaitablePayload(object): + def __init__(self, payload): + self.payload = payload or b'' + + async def read(self, length=None): + if length is None: + r = self.payload + self.payload = b'' + else: + r = self.payload[:length] + self.payload = self.payload[length:] + return r + + payload = handler.request.body + + uri_parts = urlsplit(handler.request.path) + full_uri = handler.request.path + if handler.request.query: # pragma: no cover + full_uri += '?' + handler.request.query + environ = { + 'wsgi.input': AwaitablePayload(payload), + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': 'aiohttp', + 'REQUEST_METHOD': handler.request.method, + 'QUERY_STRING': handler.request.query or '', + 'RAW_URI': full_uri, + 'SERVER_PROTOCOL': 'HTTP/%s' % handler.request.version, + 'REMOTE_ADDR': '127.0.0.1', + 'REMOTE_PORT': '0', + 'SERVER_NAME': 'aiohttp', + 'SERVER_PORT': '0', + 'tornado.handler': handler + } + + for hdr_name, hdr_value in handler.request.headers.items(): + hdr_name = hdr_name.upper() + if hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + environ[key] = hdr_value + + environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http') + + path_info = uri_parts.path + + environ['PATH_INFO'] = path_info + environ['SCRIPT_NAME'] = '' + + return environ + + +def make_response(status, headers, payload, environ): + """This function generates an appropriate response 
object for this async + mode. + """ + tornado_handler = environ['tornado.handler'] + try: + tornado_handler.set_status(int(status.split()[0])) + except RuntimeError: # pragma: no cover + # for websocket connections Tornado does not accept a response, since + # it already emitted the 101 status code + return + for header, value in headers: + tornado_handler.set_header(header, value) + tornado_handler.write(payload) + tornado_handler.finish() + + +class WebSocket(object): # pragma: no cover + """ + This wrapper class provides a tornado WebSocket interface that is + somewhat compatible with eventlet's implementation. + """ + def __init__(self, handler): + self.handler = handler + self.tornado_handler = None + + async def __call__(self, environ): + self.tornado_handler = environ['tornado.handler'] + self.environ = environ + await self.handler(self) + + async def close(self): + self.tornado_handler.close() + + async def send(self, message): + try: + self.tornado_handler.write_message( + message, binary=isinstance(message, bytes)) + except tornado.websocket.WebSocketClosedError: + raise exceptions.EngineIOError() + + async def wait(self): + msg = await self.tornado_handler.get_next_message() + if not isinstance(msg, six.binary_type) and \ + not isinstance(msg, six.text_type): + raise IOError() + return msg + + +_async = { + 'asyncio': True, + 'translate_request': translate_request, + 'make_response': make_response, + 'websocket': WebSocket, +} diff --git a/libs/engineio/asyncio_client.py b/libs/engineio/asyncio_client.py new file mode 100644 index 000000000..049b4bd95 --- /dev/null +++ b/libs/engineio/asyncio_client.py @@ -0,0 +1,585 @@ +import asyncio +import ssl + +try: + import aiohttp +except ImportError: # pragma: no cover + aiohttp = None +import six + +from . import client +from . import exceptions +from . import packet +from . import payload + + +class AsyncClient(client.Client): + """An Engine.IO client for asyncio. 
+ + This class implements a fully compliant Engine.IO web client with support + for websocket and long-polling transports, compatible with the asyncio + framework on Python 3.5 or newer. + + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. The default is + ``False``. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. + :param request_timeout: A timeout in seconds for requests. The default is + 5 seconds. + :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to + skip SSL certificate verification, allowing + connections to servers with self signed certificates. + The default is ``True``. + """ + def is_asyncio_based(self): + return True + + async def connect(self, url, headers={}, transports=None, + engineio_path='engine.io'): + """Connect to an Engine.IO server. + + :param url: The URL of the Engine.IO server. It can include custom + query string parameters if required by the server. + :param headers: A dictionary with custom headers to send with the + connection request. + :param transports: The list of allowed transports. Valid transports + are ``'polling'`` and ``'websocket'``. If not + given, the polling transport is connected first, + then an upgrade to websocket is attempted. + :param engineio_path: The endpoint where the Engine.IO server is + installed. The default value is appropriate for + most cases. + + Note: this method is a coroutine. 
+ + Example usage:: + + eio = engineio.Client() + await eio.connect('http://localhost:5000') + """ + if self.state != 'disconnected': + raise ValueError('Client is not in a disconnected state') + valid_transports = ['polling', 'websocket'] + if transports is not None: + if isinstance(transports, six.text_type): + transports = [transports] + transports = [transport for transport in transports + if transport in valid_transports] + if not transports: + raise ValueError('No valid transports provided') + self.transports = transports or valid_transports + self.queue = self.create_queue() + return await getattr(self, '_connect_' + self.transports[0])( + url, headers, engineio_path) + + async def wait(self): + """Wait until the connection with the server ends. + + Client applications can use this function to block the main thread + during the life of the connection. + + Note: this method is a coroutine. + """ + if self.read_loop_task: + await self.read_loop_task + + async def send(self, data, binary=None): + """Send a message to a client. + + :param data: The data to send to the client. Data can be of type + ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` + or ``dict``, the data will be serialized as JSON. + :param binary: ``True`` to send packet as binary, ``False`` to send + as text. If not given, unicode (Python 2) and str + (Python 3) are sent as text, and str (Python 2) and + bytes (Python 3) are sent as binary. + + Note: this method is a coroutine. + """ + await self._send_packet(packet.Packet(packet.MESSAGE, data=data, + binary=binary)) + + async def disconnect(self, abort=False): + """Disconnect from the server. + + :param abort: If set to ``True``, do not wait for background tasks + associated with the connection to end. + + Note: this method is a coroutine. 
+ """ + if self.state == 'connected': + await self._send_packet(packet.Packet(packet.CLOSE)) + await self.queue.put(None) + self.state = 'disconnecting' + await self._trigger_event('disconnect', run_async=False) + if self.current_transport == 'websocket': + await self.ws.close() + if not abort: + await self.read_loop_task + self.state = 'disconnected' + try: + client.connected_clients.remove(self) + except ValueError: # pragma: no cover + pass + self._reset() + + def start_background_task(self, target, *args, **kwargs): + """Start a background task. + + This is a utility function that applications can use to start a + background task. + + :param target: the target function to execute. + :param args: arguments to pass to the function. + :param kwargs: keyword arguments to pass to the function. + + This function returns an object compatible with the `Thread` class in + the Python standard library. The `start()` method on this object is + already called by this function. + + Note: this method is a coroutine. + """ + return asyncio.ensure_future(target(*args, **kwargs)) + + async def sleep(self, seconds=0): + """Sleep for the requested amount of time. + + Note: this method is a coroutine. 
+ """ + return await asyncio.sleep(seconds) + + def create_queue(self): + """Create a queue object.""" + q = asyncio.Queue() + q.Empty = asyncio.QueueEmpty + return q + + def create_event(self): + """Create an event object.""" + return asyncio.Event() + + def _reset(self): + if self.http: # pragma: no cover + asyncio.ensure_future(self.http.close()) + super()._reset() + + async def _connect_polling(self, url, headers, engineio_path): + """Establish a long-polling connection to the Engine.IO server.""" + if aiohttp is None: # pragma: no cover + self.logger.error('aiohttp not installed -- cannot make HTTP ' + 'requests!') + return + self.base_url = self._get_engineio_url(url, engineio_path, 'polling') + self.logger.info('Attempting polling connection to ' + self.base_url) + r = await self._send_request( + 'GET', self.base_url + self._get_url_timestamp(), headers=headers, + timeout=self.request_timeout) + if r is None: + self._reset() + raise exceptions.ConnectionError( + 'Connection refused by the server') + if r.status < 200 or r.status >= 300: + raise exceptions.ConnectionError( + 'Unexpected status code {} in server response'.format( + r.status)) + try: + p = payload.Payload(encoded_payload=await r.read()) + except ValueError: + six.raise_from(exceptions.ConnectionError( + 'Unexpected response from server'), None) + open_packet = p.packets[0] + if open_packet.packet_type != packet.OPEN: + raise exceptions.ConnectionError( + 'OPEN packet not returned by server') + self.logger.info( + 'Polling connection accepted with ' + str(open_packet.data)) + self.sid = open_packet.data['sid'] + self.upgrades = open_packet.data['upgrades'] + self.ping_interval = open_packet.data['pingInterval'] / 1000.0 + self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.current_transport = 'polling' + self.base_url += '&sid=' + self.sid + + self.state = 'connected' + client.connected_clients.append(self) + await self._trigger_event('connect', run_async=False) + + for pkt in 
p.packets[1:]: + await self._receive_packet(pkt) + + if 'websocket' in self.upgrades and 'websocket' in self.transports: + # attempt to upgrade to websocket + if await self._connect_websocket(url, headers, engineio_path): + # upgrade to websocket succeeded, we're done here + return + + self.ping_loop_task = self.start_background_task(self._ping_loop) + self.write_loop_task = self.start_background_task(self._write_loop) + self.read_loop_task = self.start_background_task( + self._read_loop_polling) + + async def _connect_websocket(self, url, headers, engineio_path): + """Establish or upgrade to a WebSocket connection with the server.""" + if aiohttp is None: # pragma: no cover + self.logger.error('aiohttp package not installed') + return False + websocket_url = self._get_engineio_url(url, engineio_path, + 'websocket') + if self.sid: + self.logger.info( + 'Attempting WebSocket upgrade to ' + websocket_url) + upgrade = True + websocket_url += '&sid=' + self.sid + else: + upgrade = False + self.base_url = websocket_url + self.logger.info( + 'Attempting WebSocket connection to ' + websocket_url) + + if self.http is None or self.http.closed: # pragma: no cover + self.http = aiohttp.ClientSession() + + try: + if not self.ssl_verify: + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + ws = await self.http.ws_connect( + websocket_url + self._get_url_timestamp(), + headers=headers, ssl=ssl_context) + else: + ws = await self.http.ws_connect( + websocket_url + self._get_url_timestamp(), + headers=headers) + except (aiohttp.client_exceptions.WSServerHandshakeError, + aiohttp.client_exceptions.ServerConnectionError): + if upgrade: + self.logger.warning( + 'WebSocket upgrade failed: connection error') + return False + else: + raise exceptions.ConnectionError('Connection error') + if upgrade: + p = packet.Packet(packet.PING, data='probe').encode( + always_bytes=False) + try: + await ws.send_str(p) + except 
Exception as e: # pragma: no cover + self.logger.warning( + 'WebSocket upgrade failed: unexpected send exception: %s', + str(e)) + return False + try: + p = (await ws.receive()).data + except Exception as e: # pragma: no cover + self.logger.warning( + 'WebSocket upgrade failed: unexpected recv exception: %s', + str(e)) + return False + pkt = packet.Packet(encoded_packet=p) + if pkt.packet_type != packet.PONG or pkt.data != 'probe': + self.logger.warning( + 'WebSocket upgrade failed: no PONG packet') + return False + p = packet.Packet(packet.UPGRADE).encode(always_bytes=False) + try: + await ws.send_str(p) + except Exception as e: # pragma: no cover + self.logger.warning( + 'WebSocket upgrade failed: unexpected send exception: %s', + str(e)) + return False + self.current_transport = 'websocket' + self.logger.info('WebSocket upgrade was successful') + else: + try: + p = (await ws.receive()).data + except Exception as e: # pragma: no cover + raise exceptions.ConnectionError( + 'Unexpected recv exception: ' + str(e)) + open_packet = packet.Packet(encoded_packet=p) + if open_packet.packet_type != packet.OPEN: + raise exceptions.ConnectionError('no OPEN packet') + self.logger.info( + 'WebSocket connection accepted with ' + str(open_packet.data)) + self.sid = open_packet.data['sid'] + self.upgrades = open_packet.data['upgrades'] + self.ping_interval = open_packet.data['pingInterval'] / 1000.0 + self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.current_transport = 'websocket' + + self.state = 'connected' + client.connected_clients.append(self) + await self._trigger_event('connect', run_async=False) + + self.ws = ws + self.ping_loop_task = self.start_background_task(self._ping_loop) + self.write_loop_task = self.start_background_task(self._write_loop) + self.read_loop_task = self.start_background_task( + self._read_loop_websocket) + return True + + async def _receive_packet(self, pkt): + """Handle incoming packets from the server.""" + packet_name = 
packet.packet_names[pkt.packet_type] \ + if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN' + self.logger.info( + 'Received packet %s data %s', packet_name, + pkt.data if not isinstance(pkt.data, bytes) else '<binary>') + if pkt.packet_type == packet.MESSAGE: + await self._trigger_event('message', pkt.data, run_async=True) + elif pkt.packet_type == packet.PONG: + self.pong_received = True + elif pkt.packet_type == packet.CLOSE: + await self.disconnect(abort=True) + elif pkt.packet_type == packet.NOOP: + pass + else: + self.logger.error('Received unexpected packet of type %s', + pkt.packet_type) + + async def _send_packet(self, pkt): + """Queue a packet to be sent to the server.""" + if self.state != 'connected': + return + await self.queue.put(pkt) + self.logger.info( + 'Sending packet %s data %s', + packet.packet_names[pkt.packet_type], + pkt.data if not isinstance(pkt.data, bytes) else '<binary>') + + async def _send_request( + self, method, url, headers=None, body=None, + timeout=None): # pragma: no cover + if self.http is None or self.http.closed: + self.http = aiohttp.ClientSession() + http_method = getattr(self.http, method.lower()) + + try: + if not self.ssl_verify: + return await http_method( + url, headers=headers, data=body, + timeout=aiohttp.ClientTimeout(total=timeout), ssl=False) + else: + return await http_method( + url, headers=headers, data=body, + timeout=aiohttp.ClientTimeout(total=timeout)) + + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + self.logger.info('HTTP %s request to %s failed with error %s.', + method, url, exc) + + async def _trigger_event(self, event, *args, **kwargs): + """Invoke an event handler.""" + run_async = kwargs.pop('run_async', False) + ret = None + if event in self.handlers: + if asyncio.iscoroutinefunction(self.handlers[event]) is True: + if run_async: + return self.start_background_task(self.handlers[event], + *args) + else: + try: + ret = await self.handlers[event](*args) + except 
asyncio.CancelledError: # pragma: no cover + pass + except: + self.logger.exception(event + ' async handler error') + if event == 'connect': + # if connect handler raised error we reject the + # connection + return False + else: + if run_async: + async def async_handler(): + return self.handlers[event](*args) + + return self.start_background_task(async_handler) + else: + try: + ret = self.handlers[event](*args) + except: + self.logger.exception(event + ' handler error') + if event == 'connect': + # if connect handler raised error we reject the + # connection + return False + return ret + + async def _ping_loop(self): + """This background task sends a PING to the server at the requested + interval. + """ + self.pong_received = True + if self.ping_loop_event is None: + self.ping_loop_event = self.create_event() + else: + self.ping_loop_event.clear() + while self.state == 'connected': + if not self.pong_received: + self.logger.info( + 'PONG response has not been received, aborting') + if self.ws: + await self.ws.close() + await self.queue.put(None) + break + self.pong_received = False + await self._send_packet(packet.Packet(packet.PING)) + try: + await asyncio.wait_for(self.ping_loop_event.wait(), + self.ping_interval) + except (asyncio.TimeoutError, + asyncio.CancelledError): # pragma: no cover + pass + self.logger.info('Exiting ping task') + + async def _read_loop_polling(self): + """Read packets by polling the Engine.IO server.""" + while self.state == 'connected': + self.logger.info( + 'Sending polling GET request to ' + self.base_url) + r = await self._send_request( + 'GET', self.base_url + self._get_url_timestamp(), + timeout=max(self.ping_interval, self.ping_timeout) + 5) + if r is None: + self.logger.warning( + 'Connection refused by the server, aborting') + await self.queue.put(None) + break + if r.status < 200 or r.status >= 300: + self.logger.warning('Unexpected status code %s in server ' + 'response, aborting', r.status) + await self.queue.put(None) + 
break + try: + p = payload.Payload(encoded_payload=await r.read()) + except ValueError: + self.logger.warning( + 'Unexpected packet from server, aborting') + await self.queue.put(None) + break + for pkt in p.packets: + await self._receive_packet(pkt) + + self.logger.info('Waiting for write loop task to end') + await self.write_loop_task + self.logger.info('Waiting for ping loop task to end') + if self.ping_loop_event: # pragma: no cover + self.ping_loop_event.set() + await self.ping_loop_task + if self.state == 'connected': + await self._trigger_event('disconnect', run_async=False) + try: + client.connected_clients.remove(self) + except ValueError: # pragma: no cover + pass + self._reset() + self.logger.info('Exiting read loop task') + + async def _read_loop_websocket(self): + """Read packets from the Engine.IO WebSocket connection.""" + while self.state == 'connected': + p = None + try: + p = (await self.ws.receive()).data + if p is None: # pragma: no cover + raise RuntimeError('WebSocket read returned None') + except aiohttp.client_exceptions.ServerDisconnectedError: + self.logger.info( + 'Read loop: WebSocket connection was closed, aborting') + await self.queue.put(None) + break + except Exception as e: + self.logger.info( + 'Unexpected error "%s", aborting', str(e)) + await self.queue.put(None) + break + if isinstance(p, six.text_type): # pragma: no cover + p = p.encode('utf-8') + pkt = packet.Packet(encoded_packet=p) + await self._receive_packet(pkt) + + self.logger.info('Waiting for write loop task to end') + await self.write_loop_task + self.logger.info('Waiting for ping loop task to end') + if self.ping_loop_event: # pragma: no cover + self.ping_loop_event.set() + await self.ping_loop_task + if self.state == 'connected': + await self._trigger_event('disconnect', run_async=False) + try: + client.connected_clients.remove(self) + except ValueError: # pragma: no cover + pass + self._reset() + self.logger.info('Exiting read loop task') + + async def 
_write_loop(self): + """This background task sends packages to the server as they are + pushed to the send queue. + """ + while self.state == 'connected': + # to simplify the timeout handling, use the maximum of the + # ping interval and ping timeout as timeout, with an extra 5 + # seconds grace period + timeout = max(self.ping_interval, self.ping_timeout) + 5 + packets = None + try: + packets = [await asyncio.wait_for(self.queue.get(), timeout)] + except (self.queue.Empty, asyncio.TimeoutError, + asyncio.CancelledError): + self.logger.error('packet queue is empty, aborting') + break + if packets == [None]: + self.queue.task_done() + packets = [] + else: + while True: + try: + packets.append(self.queue.get_nowait()) + except self.queue.Empty: + break + if packets[-1] is None: + packets = packets[:-1] + self.queue.task_done() + break + if not packets: + # empty packet list returned -> connection closed + break + if self.current_transport == 'polling': + p = payload.Payload(packets=packets) + r = await self._send_request( + 'POST', self.base_url, body=p.encode(), + headers={'Content-Type': 'application/octet-stream'}, + timeout=self.request_timeout) + for pkt in packets: + self.queue.task_done() + if r is None: + self.logger.warning( + 'Connection refused by the server, aborting') + break + if r.status < 200 or r.status >= 300: + self.logger.warning('Unexpected status code %s in server ' + 'response, aborting', r.status) + self._reset() + break + else: + # websocket + try: + for pkt in packets: + if pkt.binary: + await self.ws.send_bytes(pkt.encode( + always_bytes=False)) + else: + await self.ws.send_str(pkt.encode( + always_bytes=False)) + self.queue.task_done() + except aiohttp.client_exceptions.ServerDisconnectedError: + self.logger.info( + 'Write loop: WebSocket connection was closed, ' + 'aborting') + break + self.logger.info('Exiting write loop task') diff --git a/libs/engineio/asyncio_server.py b/libs/engineio/asyncio_server.py new file mode 100644 index 
000000000..d52b556db --- /dev/null +++ b/libs/engineio/asyncio_server.py @@ -0,0 +1,472 @@ +import asyncio + +import six +from six.moves import urllib + +from . import exceptions +from . import packet +from . import server +from . import asyncio_socket + + +class AsyncServer(server.Server): + """An Engine.IO server for asyncio. + + This class implements a fully compliant Engine.IO web server with support + for websocket and long-polling transports, compatible with the asyncio + framework on Python 3.5 or newer. + + :param async_mode: The asynchronous model to use. See the Deployment + section in the documentation for a description of the + available options. Valid async modes are "aiohttp", + "sanic", "tornado" and "asgi". If this argument is not + given, "aiohttp" is tried first, followed by "sanic", + "tornado", and finally "asgi". The first async mode that + has all its dependencies installed is the one that is + chosen. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. + :param ping_interval: The interval in seconds at which the client pings + the server. The default is 25 seconds. For advanced + control, a two element tuple can be given, where + the first number is the ping interval and the second + is a grace period added by the server. The default + grace period is 5 seconds. + :param max_http_buffer_size: The maximum size of a message when using the + polling transport. + :param allow_upgrades: Whether to allow transport upgrades or not. + :param http_compression: Whether to compress packages when using the + polling transport. + :param compression_threshold: Only compress messages when their byte size + is greater than this value. + :param cookie: Name of the HTTP cookie that contains the client session + id. If set to ``None``, a cookie is not sent to the client. + :param cors_allowed_origins: Origin or list of origins that are allowed to + connect to this server. 
Only the same origin + is allowed by default. Set this argument to + ``'*'`` to allow all origins, or to ``[]`` to + disable CORS handling. + :param cors_credentials: Whether credentials (cookies, authentication) are + allowed in requests to this server. + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. + :param async_handlers: If set to ``True``, run message event handlers in + non-blocking threads. To run handlers synchronously, + set to ``False``. The default is ``True``. + :param kwargs: Reserved for future extensions, any additional parameters + given as keyword arguments will be silently ignored. + """ + def is_asyncio_based(self): + return True + + def async_modes(self): + return ['aiohttp', 'sanic', 'tornado', 'asgi'] + + def attach(self, app, engineio_path='engine.io'): + """Attach the Engine.IO server to an application.""" + engineio_path = engineio_path.strip('/') + self._async['create_route'](app, self, '/{}/'.format(engineio_path)) + + async def send(self, sid, data, binary=None): + """Send a message to a client. + + :param sid: The session id of the recipient client. + :param data: The data to send to the client. Data can be of type + ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` + or ``dict``, the data will be serialized as JSON. + :param binary: ``True`` to send packet as binary, ``False`` to send + as text. If not given, unicode (Python 2) and str + (Python 3) are sent as text, and str (Python 2) and + bytes (Python 3) are sent as binary. + + Note: this method is a coroutine. 
+ """ + try: + socket = self._get_socket(sid) + except KeyError: + # the socket is not available + self.logger.warning('Cannot send to sid %s', sid) + return + await socket.send(packet.Packet(packet.MESSAGE, data=data, + binary=binary)) + + async def get_session(self, sid): + """Return the user session for a client. + + :param sid: The session id of the client. + + The return value is a dictionary. Modifications made to this + dictionary are not guaranteed to be preserved. If you want to modify + the user session, use the ``session`` context manager instead. + """ + socket = self._get_socket(sid) + return socket.session + + async def save_session(self, sid, session): + """Store the user session for a client. + + :param sid: The session id of the client. + :param session: The session dictionary. + """ + socket = self._get_socket(sid) + socket.session = session + + def session(self, sid): + """Return the user session for a client with context manager syntax. + + :param sid: The session id of the client. + + This is a context manager that returns the user session dictionary for + the client. Any changes that are made to this dictionary inside the + context manager block are saved back to the session. 
Example usage:: + + @eio.on('connect') + def on_connect(sid, environ): + username = authenticate_user(environ) + if not username: + return False + with eio.session(sid) as session: + session['username'] = username + + @eio.on('message') + def on_message(sid, msg): + async with eio.session(sid) as session: + print('received message from ', session['username']) + """ + class _session_context_manager(object): + def __init__(self, server, sid): + self.server = server + self.sid = sid + self.session = None + + async def __aenter__(self): + self.session = await self.server.get_session(sid) + return self.session + + async def __aexit__(self, *args): + await self.server.save_session(sid, self.session) + + return _session_context_manager(self, sid) + + async def disconnect(self, sid=None): + """Disconnect a client. + + :param sid: The session id of the client to close. If this parameter + is not given, then all clients are closed. + + Note: this method is a coroutine. + """ + if sid is not None: + try: + socket = self._get_socket(sid) + except KeyError: # pragma: no cover + # the socket was already closed or gone + pass + else: + await socket.close() + if sid in self.sockets: # pragma: no cover + del self.sockets[sid] + else: + await asyncio.wait([client.close() + for client in six.itervalues(self.sockets)]) + self.sockets = {} + + async def handle_request(self, *args, **kwargs): + """Handle an HTTP request from the client. + + This is the entry point of the Engine.IO application. This function + returns the HTTP response to deliver to the client. + + Note: this method is a coroutine. 
+ """ + translate_request = self._async['translate_request'] + if asyncio.iscoroutinefunction(translate_request): + environ = await translate_request(*args, **kwargs) + else: + environ = translate_request(*args, **kwargs) + + if self.cors_allowed_origins != []: + # Validate the origin header if present + # This is important for WebSocket more than for HTTP, since + # browsers only apply CORS controls to HTTP. + origin = environ.get('HTTP_ORIGIN') + if origin: + allowed_origins = self._cors_allowed_origins(environ) + if allowed_origins is not None and origin not in \ + allowed_origins: + self.logger.info(origin + ' is not an accepted origin.') + r = self._bad_request() + make_response = self._async['make_response'] + if asyncio.iscoroutinefunction(make_response): + response = await make_response( + r['status'], r['headers'], r['response'], environ) + else: + response = make_response(r['status'], r['headers'], + r['response'], environ) + return response + + method = environ['REQUEST_METHOD'] + query = urllib.parse.parse_qs(environ.get('QUERY_STRING', '')) + + sid = query['sid'][0] if 'sid' in query else None + b64 = False + jsonp = False + jsonp_index = None + + if 'b64' in query: + if query['b64'][0] == "1" or query['b64'][0].lower() == "true": + b64 = True + if 'j' in query: + jsonp = True + try: + jsonp_index = int(query['j'][0]) + except (ValueError, KeyError, IndexError): + # Invalid JSONP index number + pass + + if jsonp and jsonp_index is None: + self.logger.warning('Invalid JSONP index number') + r = self._bad_request() + elif method == 'GET': + if sid is None: + transport = query.get('transport', ['polling'])[0] + if transport != 'polling' and transport != 'websocket': + self.logger.warning('Invalid transport %s', transport) + r = self._bad_request() + else: + r = await self._handle_connect(environ, transport, + b64, jsonp_index) + else: + if sid not in self.sockets: + self.logger.warning('Invalid session %s', sid) + r = self._bad_request() + else: + socket 
= self._get_socket(sid) + try: + packets = await socket.handle_get_request(environ) + if isinstance(packets, list): + r = self._ok(packets, b64=b64, + jsonp_index=jsonp_index) + else: + r = packets + except exceptions.EngineIOError: + if sid in self.sockets: # pragma: no cover + await self.disconnect(sid) + r = self._bad_request() + if sid in self.sockets and self.sockets[sid].closed: + del self.sockets[sid] + elif method == 'POST': + if sid is None or sid not in self.sockets: + self.logger.warning('Invalid session %s', sid) + r = self._bad_request() + else: + socket = self._get_socket(sid) + try: + await socket.handle_post_request(environ) + r = self._ok(jsonp_index=jsonp_index) + except exceptions.EngineIOError: + if sid in self.sockets: # pragma: no cover + await self.disconnect(sid) + r = self._bad_request() + except: # pragma: no cover + # for any other unexpected errors, we log the error + # and keep going + self.logger.exception('post request handler error') + r = self._ok(jsonp_index=jsonp_index) + elif method == 'OPTIONS': + r = self._ok() + else: + self.logger.warning('Method %s not supported', method) + r = self._method_not_found() + if not isinstance(r, dict): + return r + if self.http_compression and \ + len(r['response']) >= self.compression_threshold: + encodings = [e.split(';')[0].strip() for e in + environ.get('HTTP_ACCEPT_ENCODING', '').split(',')] + for encoding in encodings: + if encoding in self.compression_methods: + r['response'] = \ + getattr(self, '_' + encoding)(r['response']) + r['headers'] += [('Content-Encoding', encoding)] + break + cors_headers = self._cors_headers(environ) + make_response = self._async['make_response'] + if asyncio.iscoroutinefunction(make_response): + response = await make_response(r['status'], + r['headers'] + cors_headers, + r['response'], environ) + else: + response = make_response(r['status'], r['headers'] + cors_headers, + r['response'], environ) + return response + + def start_background_task(self, target, 
*args, **kwargs): + """Start a background task using the appropriate async model. + + This is a utility function that applications can use to start a + background task using the method that is compatible with the + selected async mode. + + :param target: the target function to execute. + :param args: arguments to pass to the function. + :param kwargs: keyword arguments to pass to the function. + + The return value is a ``asyncio.Task`` object. + """ + return asyncio.ensure_future(target(*args, **kwargs)) + + async def sleep(self, seconds=0): + """Sleep for the requested amount of time using the appropriate async + model. + + This is a utility function that applications can use to put a task to + sleep without having to worry about using the correct call for the + selected async mode. + + Note: this method is a coroutine. + """ + return await asyncio.sleep(seconds) + + def create_queue(self, *args, **kwargs): + """Create a queue object using the appropriate async model. + + This is a utility function that applications can use to create a queue + without having to worry about using the correct call for the selected + async mode. For asyncio based async modes, this returns an instance of + ``asyncio.Queue``. + """ + return asyncio.Queue(*args, **kwargs) + + def get_queue_empty_exception(self): + """Return the queue empty exception for the appropriate async model. + + This is a utility function that applications can use to work with a + queue without having to worry about using the correct call for the + selected async mode. For asyncio based async modes, this returns an + instance of ``asyncio.QueueEmpty``. + """ + return asyncio.QueueEmpty + + def create_event(self, *args, **kwargs): + """Create an event object using the appropriate async model. + + This is a utility function that applications can use to create an + event without having to worry about using the correct call for the + selected async mode. 
For asyncio based async modes, this returns + an instance of ``asyncio.Event``. + """ + return asyncio.Event(*args, **kwargs) + + async def _handle_connect(self, environ, transport, b64=False, + jsonp_index=None): + """Handle a client connection request.""" + if self.start_service_task: + # start the service task to monitor connected clients + self.start_service_task = False + self.start_background_task(self._service_task) + + sid = self._generate_id() + s = asyncio_socket.AsyncSocket(self, sid) + self.sockets[sid] = s + + pkt = packet.Packet( + packet.OPEN, {'sid': sid, + 'upgrades': self._upgrades(sid, transport), + 'pingTimeout': int(self.ping_timeout * 1000), + 'pingInterval': int(self.ping_interval * 1000)}) + await s.send(pkt) + + ret = await self._trigger_event('connect', sid, environ, + run_async=False) + if ret is False: + del self.sockets[sid] + self.logger.warning('Application rejected connection') + return self._unauthorized() + + if transport == 'websocket': + ret = await s.handle_get_request(environ) + if s.closed: + # websocket connection ended, so we are done + del self.sockets[sid] + return ret + else: + s.connected = True + headers = None + if self.cookie: + headers = [('Set-Cookie', self.cookie + '=' + sid)] + try: + return self._ok(await s.poll(), headers=headers, b64=b64, + jsonp_index=jsonp_index) + except exceptions.QueueEmpty: + return self._bad_request() + + async def _trigger_event(self, event, *args, **kwargs): + """Invoke an event handler.""" + run_async = kwargs.pop('run_async', False) + ret = None + if event in self.handlers: + if asyncio.iscoroutinefunction(self.handlers[event]) is True: + if run_async: + return self.start_background_task(self.handlers[event], + *args) + else: + try: + ret = await self.handlers[event](*args) + except asyncio.CancelledError: # pragma: no cover + pass + except: + self.logger.exception(event + ' async handler error') + if event == 'connect': + # if connect handler raised error we reject the + # 
connection + return False + else: + if run_async: + async def async_handler(): + return self.handlers[event](*args) + + return self.start_background_task(async_handler) + else: + try: + ret = self.handlers[event](*args) + except: + self.logger.exception(event + ' handler error') + if event == 'connect': + # if connect handler raised error we reject the + # connection + return False + return ret + + async def _service_task(self): # pragma: no cover + """Monitor connected clients and clean up those that time out.""" + while True: + if len(self.sockets) == 0: + # nothing to do + await self.sleep(self.ping_timeout) + continue + + # go through the entire client list in a ping interval cycle + sleep_interval = self.ping_timeout / len(self.sockets) + + try: + # iterate over the current clients + for socket in self.sockets.copy().values(): + if not socket.closing and not socket.closed: + await socket.check_ping_timeout() + await self.sleep(sleep_interval) + except (SystemExit, KeyboardInterrupt, asyncio.CancelledError): + self.logger.info('service task canceled') + break + except: + if asyncio.get_event_loop().is_closed(): + self.logger.info('event loop is closed, exiting service ' + 'task') + break + + # an unexpected exception has occurred, log it and continue + self.logger.exception('service task exception') diff --git a/libs/engineio/asyncio_socket.py b/libs/engineio/asyncio_socket.py new file mode 100644 index 000000000..7057a6cc3 --- /dev/null +++ b/libs/engineio/asyncio_socket.py @@ -0,0 +1,236 @@ +import asyncio +import six +import sys +import time + +from . import exceptions +from . import packet +from . import payload +from . 
import socket + + +class AsyncSocket(socket.Socket): + async def poll(self): + """Wait for packets to send to the client.""" + try: + packets = [await asyncio.wait_for(self.queue.get(), + self.server.ping_timeout)] + self.queue.task_done() + except (asyncio.TimeoutError, asyncio.CancelledError): + raise exceptions.QueueEmpty() + if packets == [None]: + return [] + try: + packets.append(self.queue.get_nowait()) + self.queue.task_done() + except asyncio.QueueEmpty: + pass + return packets + + async def receive(self, pkt): + """Receive packet from the client.""" + self.server.logger.info('%s: Received packet %s data %s', + self.sid, packet.packet_names[pkt.packet_type], + pkt.data if not isinstance(pkt.data, bytes) + else '<binary>') + if pkt.packet_type == packet.PING: + self.last_ping = time.time() + await self.send(packet.Packet(packet.PONG, pkt.data)) + elif pkt.packet_type == packet.MESSAGE: + await self.server._trigger_event( + 'message', self.sid, pkt.data, + run_async=self.server.async_handlers) + elif pkt.packet_type == packet.UPGRADE: + await self.send(packet.Packet(packet.NOOP)) + elif pkt.packet_type == packet.CLOSE: + await self.close(wait=False, abort=True) + else: + raise exceptions.UnknownPacketError() + + async def check_ping_timeout(self): + """Make sure the client is still sending pings. + + This helps detect disconnections for long-polling clients. + """ + if self.closed: + raise exceptions.SocketIsClosedError() + if time.time() - self.last_ping > self.server.ping_interval + \ + self.server.ping_interval_grace_period: + self.server.logger.info('%s: Client is gone, closing socket', + self.sid) + # Passing abort=False here will cause close() to write a + # CLOSE packet. 
This has the effect of updating half-open sockets + # to their correct state of disconnected + await self.close(wait=False, abort=False) + return False + return True + + async def send(self, pkt): + """Send a packet to the client.""" + if not await self.check_ping_timeout(): + return + if self.upgrading: + self.packet_backlog.append(pkt) + else: + await self.queue.put(pkt) + self.server.logger.info('%s: Sending packet %s data %s', + self.sid, packet.packet_names[pkt.packet_type], + pkt.data if not isinstance(pkt.data, bytes) + else '<binary>') + + async def handle_get_request(self, environ): + """Handle a long-polling GET request from the client.""" + connections = [ + s.strip() + for s in environ.get('HTTP_CONNECTION', '').lower().split(',')] + transport = environ.get('HTTP_UPGRADE', '').lower() + if 'upgrade' in connections and transport in self.upgrade_protocols: + self.server.logger.info('%s: Received request to upgrade to %s', + self.sid, transport) + return await getattr(self, '_upgrade_' + transport)(environ) + try: + packets = await self.poll() + except exceptions.QueueEmpty: + exc = sys.exc_info() + await self.close(wait=False) + six.reraise(*exc) + return packets + + async def handle_post_request(self, environ): + """Handle a long-polling POST request from the client.""" + length = int(environ.get('CONTENT_LENGTH', '0')) + if length > self.server.max_http_buffer_size: + raise exceptions.ContentTooLongError() + else: + body = await environ['wsgi.input'].read(length) + p = payload.Payload(encoded_payload=body) + for pkt in p.packets: + await self.receive(pkt) + + async def close(self, wait=True, abort=False): + """Close the socket connection.""" + if not self.closed and not self.closing: + self.closing = True + await self.server._trigger_event('disconnect', self.sid) + if not abort: + await self.send(packet.Packet(packet.CLOSE)) + self.closed = True + if wait: + await self.queue.join() + + async def _upgrade_websocket(self, environ): + """Upgrade the 
connection from polling to websocket.""" + if self.upgraded: + raise IOError('Socket has been upgraded already') + if self.server._async['websocket'] is None: + # the selected async mode does not support websocket + return self.server._bad_request() + ws = self.server._async['websocket'](self._websocket_handler) + return await ws(environ) + + async def _websocket_handler(self, ws): + """Engine.IO handler for websocket transport.""" + if self.connected: + # the socket was already connected, so this is an upgrade + self.upgrading = True # hold packet sends during the upgrade + + try: + pkt = await ws.wait() + except IOError: # pragma: no cover + return + decoded_pkt = packet.Packet(encoded_packet=pkt) + if decoded_pkt.packet_type != packet.PING or \ + decoded_pkt.data != 'probe': + self.server.logger.info( + '%s: Failed websocket upgrade, no PING packet', self.sid) + return + await ws.send(packet.Packet( + packet.PONG, + data=six.text_type('probe')).encode(always_bytes=False)) + await self.queue.put(packet.Packet(packet.NOOP)) # end poll + + try: + pkt = await ws.wait() + except IOError: # pragma: no cover + return + decoded_pkt = packet.Packet(encoded_packet=pkt) + if decoded_pkt.packet_type != packet.UPGRADE: + self.upgraded = False + self.server.logger.info( + ('%s: Failed websocket upgrade, expected UPGRADE packet, ' + 'received %s instead.'), + self.sid, pkt) + return + self.upgraded = True + + # flush any packets that were sent during the upgrade + for pkt in self.packet_backlog: + await self.queue.put(pkt) + self.packet_backlog = [] + self.upgrading = False + else: + self.connected = True + self.upgraded = True + + # start separate writer thread + async def writer(): + while True: + packets = None + try: + packets = await self.poll() + except exceptions.QueueEmpty: + break + if not packets: + # empty packet list returned -> connection closed + break + try: + for pkt in packets: + await ws.send(pkt.encode(always_bytes=False)) + except: + break + writer_task = 
asyncio.ensure_future(writer()) + + self.server.logger.info( + '%s: Upgrade to websocket successful', self.sid) + + while True: + p = None + wait_task = asyncio.ensure_future(ws.wait()) + try: + p = await asyncio.wait_for(wait_task, self.server.ping_timeout) + except asyncio.CancelledError: # pragma: no cover + # there is a bug (https://bugs.python.org/issue30508) in + # asyncio that causes a "Task exception never retrieved" error + # to appear when wait_task raises an exception before it gets + # cancelled. Calling wait_task.exception() prevents the error + # from being issued in Python 3.6, but causes other errors in + # other versions, so we run it with all errors suppressed and + # hope for the best. + try: + wait_task.exception() + except: + pass + break + except: + break + if p is None: + # connection closed by client + break + if isinstance(p, six.text_type): # pragma: no cover + p = p.encode('utf-8') + pkt = packet.Packet(encoded_packet=p) + try: + await self.receive(pkt) + except exceptions.UnknownPacketError: # pragma: no cover + pass + except exceptions.SocketIsClosedError: # pragma: no cover + self.server.logger.info('Receive error -- socket is closed') + break + except: # pragma: no cover + # if we get an unexpected exception we log the error and exit + # the connection properly + self.server.logger.exception('Unknown receive error') + + await self.queue.put(None) # unlock the writer task so it can exit + await asyncio.wait_for(writer_task, timeout=None) + await self.close(wait=False, abort=True) diff --git a/libs/engineio/client.py b/libs/engineio/client.py new file mode 100644 index 000000000..b5ab50377 --- /dev/null +++ b/libs/engineio/client.py @@ -0,0 +1,680 @@ +import logging +try: + import queue +except ImportError: # pragma: no cover + import Queue as queue +import signal +import ssl +import threading +import time + +import six +from six.moves import urllib +try: + import requests +except ImportError: # pragma: no cover + requests = None +try: 
+ import websocket +except ImportError: # pragma: no cover + websocket = None +from . import exceptions +from . import packet +from . import payload + +default_logger = logging.getLogger('engineio.client') +connected_clients = [] + +if six.PY2: # pragma: no cover + ConnectionError = OSError + + +def signal_handler(sig, frame): + """SIGINT handler. + + Disconnect all active clients and then invoke the original signal handler. + """ + for client in connected_clients[:]: + if client.is_asyncio_based(): + client.start_background_task(client.disconnect, abort=True) + else: + client.disconnect(abort=True) + if callable(original_signal_handler): + return original_signal_handler(sig, frame) + else: # pragma: no cover + # Handle case where no original SIGINT handler was present. + return signal.default_int_handler(sig, frame) + + +original_signal_handler = None + + +class Client(object): + """An Engine.IO client. + + This class implements a fully compliant Engine.IO web client with support + for websocket and long-polling transports. + + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. The default is + ``False``. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. + :param request_timeout: A timeout in seconds for requests. The default is + 5 seconds. + :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to + skip SSL certificate verification, allowing + connections to servers with self signed certificates. + The default is ``True``. 
+ """ + event_names = ['connect', 'disconnect', 'message'] + + def __init__(self, + logger=False, + json=None, + request_timeout=5, + ssl_verify=True): + global original_signal_handler + if original_signal_handler is None: + original_signal_handler = signal.signal(signal.SIGINT, + signal_handler) + self.handlers = {} + self.base_url = None + self.transports = None + self.current_transport = None + self.sid = None + self.upgrades = None + self.ping_interval = None + self.ping_timeout = None + self.pong_received = True + self.http = None + self.ws = None + self.read_loop_task = None + self.write_loop_task = None + self.ping_loop_task = None + self.ping_loop_event = None + self.queue = None + self.state = 'disconnected' + self.ssl_verify = ssl_verify + + if json is not None: + packet.Packet.json = json + if not isinstance(logger, bool): + self.logger = logger + else: + self.logger = default_logger + if not logging.root.handlers and \ + self.logger.level == logging.NOTSET: + if logger: + self.logger.setLevel(logging.INFO) + else: + self.logger.setLevel(logging.ERROR) + self.logger.addHandler(logging.StreamHandler()) + + self.request_timeout = request_timeout + + def is_asyncio_based(self): + return False + + def on(self, event, handler=None): + """Register an event handler. + + :param event: The event name. Can be ``'connect'``, ``'message'`` or + ``'disconnect'``. + :param handler: The function that should be invoked to handle the + event. When this parameter is not given, the method + acts as a decorator for the handler function. 
+ + Example usage:: + + # as a decorator: + @eio.on('connect') + def connect_handler(): + print('Connection request') + + # as a method: + def message_handler(msg): + print('Received message: ', msg) + eio.send('response') + eio.on('message', message_handler) + """ + if event not in self.event_names: + raise ValueError('Invalid event') + + def set_handler(handler): + self.handlers[event] = handler + return handler + + if handler is None: + return set_handler + set_handler(handler) + + def connect(self, url, headers={}, transports=None, + engineio_path='engine.io'): + """Connect to an Engine.IO server. + + :param url: The URL of the Engine.IO server. It can include custom + query string parameters if required by the server. + :param headers: A dictionary with custom headers to send with the + connection request. + :param transports: The list of allowed transports. Valid transports + are ``'polling'`` and ``'websocket'``. If not + given, the polling transport is connected first, + then an upgrade to websocket is attempted. + :param engineio_path: The endpoint where the Engine.IO server is + installed. The default value is appropriate for + most cases. + + Example usage:: + + eio = engineio.Client() + eio.connect('http://localhost:5000') + """ + if self.state != 'disconnected': + raise ValueError('Client is not in a disconnected state') + valid_transports = ['polling', 'websocket'] + if transports is not None: + if isinstance(transports, six.string_types): + transports = [transports] + transports = [transport for transport in transports + if transport in valid_transports] + if not transports: + raise ValueError('No valid transports provided') + self.transports = transports or valid_transports + self.queue = self.create_queue() + return getattr(self, '_connect_' + self.transports[0])( + url, headers, engineio_path) + + def wait(self): + """Wait until the connection with the server ends. 
+ + Client applications can use this function to block the main thread + during the life of the connection. + """ + if self.read_loop_task: + self.read_loop_task.join() + + def send(self, data, binary=None): + """Send a message to a client. + + :param data: The data to send to the client. Data can be of type + ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` + or ``dict``, the data will be serialized as JSON. + :param binary: ``True`` to send packet as binary, ``False`` to send + as text. If not given, unicode (Python 2) and str + (Python 3) are sent as text, and str (Python 2) and + bytes (Python 3) are sent as binary. + """ + self._send_packet(packet.Packet(packet.MESSAGE, data=data, + binary=binary)) + + def disconnect(self, abort=False): + """Disconnect from the server. + + :param abort: If set to ``True``, do not wait for background tasks + associated with the connection to end. + """ + if self.state == 'connected': + self._send_packet(packet.Packet(packet.CLOSE)) + self.queue.put(None) + self.state = 'disconnecting' + self._trigger_event('disconnect', run_async=False) + if self.current_transport == 'websocket': + self.ws.close() + if not abort: + self.read_loop_task.join() + self.state = 'disconnected' + try: + connected_clients.remove(self) + except ValueError: # pragma: no cover + pass + self._reset() + + def transport(self): + """Return the name of the transport currently in use. + + The possible values returned by this function are ``'polling'`` and + ``'websocket'``. + """ + return self.current_transport + + def start_background_task(self, target, *args, **kwargs): + """Start a background task. + + This is a utility function that applications can use to start a + background task. + + :param target: the target function to execute. + :param args: arguments to pass to the function. + :param kwargs: keyword arguments to pass to the function. + + This function returns an object compatible with the `Thread` class in + the Python standard library. 
The `start()` method on this object is + already called by this function. + """ + th = threading.Thread(target=target, args=args, kwargs=kwargs) + th.start() + return th + + def sleep(self, seconds=0): + """Sleep for the requested amount of time.""" + return time.sleep(seconds) + + def create_queue(self, *args, **kwargs): + """Create a queue object.""" + q = queue.Queue(*args, **kwargs) + q.Empty = queue.Empty + return q + + def create_event(self, *args, **kwargs): + """Create an event object.""" + return threading.Event(*args, **kwargs) + + def _reset(self): + self.state = 'disconnected' + self.sid = None + + def _connect_polling(self, url, headers, engineio_path): + """Establish a long-polling connection to the Engine.IO server.""" + if requests is None: # pragma: no cover + # not installed + self.logger.error('requests package is not installed -- cannot ' + 'send HTTP requests!') + return + self.base_url = self._get_engineio_url(url, engineio_path, 'polling') + self.logger.info('Attempting polling connection to ' + self.base_url) + r = self._send_request( + 'GET', self.base_url + self._get_url_timestamp(), headers=headers, + timeout=self.request_timeout) + if r is None: + self._reset() + raise exceptions.ConnectionError( + 'Connection refused by the server') + if r.status_code < 200 or r.status_code >= 300: + raise exceptions.ConnectionError( + 'Unexpected status code {} in server response'.format( + r.status_code)) + try: + p = payload.Payload(encoded_payload=r.content) + except ValueError: + six.raise_from(exceptions.ConnectionError( + 'Unexpected response from server'), None) + open_packet = p.packets[0] + if open_packet.packet_type != packet.OPEN: + raise exceptions.ConnectionError( + 'OPEN packet not returned by server') + self.logger.info( + 'Polling connection accepted with ' + str(open_packet.data)) + self.sid = open_packet.data['sid'] + self.upgrades = open_packet.data['upgrades'] + self.ping_interval = open_packet.data['pingInterval'] / 1000.0 + 
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.current_transport = 'polling' + self.base_url += '&sid=' + self.sid + + self.state = 'connected' + connected_clients.append(self) + self._trigger_event('connect', run_async=False) + + for pkt in p.packets[1:]: + self._receive_packet(pkt) + + if 'websocket' in self.upgrades and 'websocket' in self.transports: + # attempt to upgrade to websocket + if self._connect_websocket(url, headers, engineio_path): + # upgrade to websocket succeeded, we're done here + return + + # start background tasks associated with this client + self.ping_loop_task = self.start_background_task(self._ping_loop) + self.write_loop_task = self.start_background_task(self._write_loop) + self.read_loop_task = self.start_background_task( + self._read_loop_polling) + + def _connect_websocket(self, url, headers, engineio_path): + """Establish or upgrade to a WebSocket connection with the server.""" + if websocket is None: # pragma: no cover + # not installed + self.logger.warning('websocket-client package not installed, only ' + 'polling transport is available') + return False + websocket_url = self._get_engineio_url(url, engineio_path, 'websocket') + if self.sid: + self.logger.info( + 'Attempting WebSocket upgrade to ' + websocket_url) + upgrade = True + websocket_url += '&sid=' + self.sid + else: + upgrade = False + self.base_url = websocket_url + self.logger.info( + 'Attempting WebSocket connection to ' + websocket_url) + + # get the cookies from the long-polling connection so that they can + # also be sent the the WebSocket route + cookies = None + if self.http: + cookies = '; '.join(["{}={}".format(cookie.name, cookie.value) + for cookie in self.http.cookies]) + + try: + if not self.ssl_verify: + ws = websocket.create_connection( + websocket_url + self._get_url_timestamp(), header=headers, + cookie=cookies, sslopt={"cert_reqs": ssl.CERT_NONE}) + else: + ws = websocket.create_connection( + websocket_url + 
self._get_url_timestamp(), header=headers, + cookie=cookies) + except (ConnectionError, IOError, websocket.WebSocketException): + if upgrade: + self.logger.warning( + 'WebSocket upgrade failed: connection error') + return False + else: + raise exceptions.ConnectionError('Connection error') + if upgrade: + p = packet.Packet(packet.PING, + data=six.text_type('probe')).encode() + try: + ws.send(p) + except Exception as e: # pragma: no cover + self.logger.warning( + 'WebSocket upgrade failed: unexpected send exception: %s', + str(e)) + return False + try: + p = ws.recv() + except Exception as e: # pragma: no cover + self.logger.warning( + 'WebSocket upgrade failed: unexpected recv exception: %s', + str(e)) + return False + pkt = packet.Packet(encoded_packet=p) + if pkt.packet_type != packet.PONG or pkt.data != 'probe': + self.logger.warning( + 'WebSocket upgrade failed: no PONG packet') + return False + p = packet.Packet(packet.UPGRADE).encode() + try: + ws.send(p) + except Exception as e: # pragma: no cover + self.logger.warning( + 'WebSocket upgrade failed: unexpected send exception: %s', + str(e)) + return False + self.current_transport = 'websocket' + self.logger.info('WebSocket upgrade was successful') + else: + try: + p = ws.recv() + except Exception as e: # pragma: no cover + raise exceptions.ConnectionError( + 'Unexpected recv exception: ' + str(e)) + open_packet = packet.Packet(encoded_packet=p) + if open_packet.packet_type != packet.OPEN: + raise exceptions.ConnectionError('no OPEN packet') + self.logger.info( + 'WebSocket connection accepted with ' + str(open_packet.data)) + self.sid = open_packet.data['sid'] + self.upgrades = open_packet.data['upgrades'] + self.ping_interval = open_packet.data['pingInterval'] / 1000.0 + self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.current_transport = 'websocket' + + self.state = 'connected' + connected_clients.append(self) + self._trigger_event('connect', run_async=False) + self.ws = ws + + # start 
background tasks associated with this client + self.ping_loop_task = self.start_background_task(self._ping_loop) + self.write_loop_task = self.start_background_task(self._write_loop) + self.read_loop_task = self.start_background_task( + self._read_loop_websocket) + return True + + def _receive_packet(self, pkt): + """Handle incoming packets from the server.""" + packet_name = packet.packet_names[pkt.packet_type] \ + if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN' + self.logger.info( + 'Received packet %s data %s', packet_name, + pkt.data if not isinstance(pkt.data, bytes) else '<binary>') + if pkt.packet_type == packet.MESSAGE: + self._trigger_event('message', pkt.data, run_async=True) + elif pkt.packet_type == packet.PONG: + self.pong_received = True + elif pkt.packet_type == packet.CLOSE: + self.disconnect(abort=True) + elif pkt.packet_type == packet.NOOP: + pass + else: + self.logger.error('Received unexpected packet of type %s', + pkt.packet_type) + + def _send_packet(self, pkt): + """Queue a packet to be sent to the server.""" + if self.state != 'connected': + return + self.queue.put(pkt) + self.logger.info( + 'Sending packet %s data %s', + packet.packet_names[pkt.packet_type], + pkt.data if not isinstance(pkt.data, bytes) else '<binary>') + + def _send_request( + self, method, url, headers=None, body=None, + timeout=None): # pragma: no cover + if self.http is None: + self.http = requests.Session() + try: + return self.http.request(method, url, headers=headers, data=body, + timeout=timeout, verify=self.ssl_verify) + except requests.exceptions.RequestException as exc: + self.logger.info('HTTP %s request to %s failed with error %s.', + method, url, exc) + + def _trigger_event(self, event, *args, **kwargs): + """Invoke an event handler.""" + run_async = kwargs.pop('run_async', False) + if event in self.handlers: + if run_async: + return self.start_background_task(self.handlers[event], *args) + else: + try: + return self.handlers[event](*args) + 
except: + self.logger.exception(event + ' handler error') + + def _get_engineio_url(self, url, engineio_path, transport): + """Generate the Engine.IO connection URL.""" + engineio_path = engineio_path.strip('/') + parsed_url = urllib.parse.urlparse(url) + + if transport == 'polling': + scheme = 'http' + elif transport == 'websocket': + scheme = 'ws' + else: # pragma: no cover + raise ValueError('invalid transport') + if parsed_url.scheme in ['https', 'wss']: + scheme += 's' + + return ('{scheme}://{netloc}/{path}/?{query}' + '{sep}transport={transport}&EIO=3').format( + scheme=scheme, netloc=parsed_url.netloc, + path=engineio_path, query=parsed_url.query, + sep='&' if parsed_url.query else '', + transport=transport) + + def _get_url_timestamp(self): + """Generate the Engine.IO query string timestamp.""" + return '&t=' + str(time.time()) + + def _ping_loop(self): + """This background task sends a PING to the server at the requested + interval. + """ + self.pong_received = True + if self.ping_loop_event is None: + self.ping_loop_event = self.create_event() + else: + self.ping_loop_event.clear() + while self.state == 'connected': + if not self.pong_received: + self.logger.info( + 'PONG response has not been received, aborting') + if self.ws: + self.ws.close(timeout=0) + self.queue.put(None) + break + self.pong_received = False + self._send_packet(packet.Packet(packet.PING)) + self.ping_loop_event.wait(timeout=self.ping_interval) + self.logger.info('Exiting ping task') + + def _read_loop_polling(self): + """Read packets by polling the Engine.IO server.""" + while self.state == 'connected': + self.logger.info( + 'Sending polling GET request to ' + self.base_url) + r = self._send_request( + 'GET', self.base_url + self._get_url_timestamp(), + timeout=max(self.ping_interval, self.ping_timeout) + 5) + if r is None: + self.logger.warning( + 'Connection refused by the server, aborting') + self.queue.put(None) + break + if r.status_code < 200 or r.status_code >= 300: + 
self.logger.warning('Unexpected status code %s in server ' + 'response, aborting', r.status_code) + self.queue.put(None) + break + try: + p = payload.Payload(encoded_payload=r.content) + except ValueError: + self.logger.warning( + 'Unexpected packet from server, aborting') + self.queue.put(None) + break + for pkt in p.packets: + self._receive_packet(pkt) + + self.logger.info('Waiting for write loop task to end') + self.write_loop_task.join() + self.logger.info('Waiting for ping loop task to end') + if self.ping_loop_event: # pragma: no cover + self.ping_loop_event.set() + self.ping_loop_task.join() + if self.state == 'connected': + self._trigger_event('disconnect', run_async=False) + try: + connected_clients.remove(self) + except ValueError: # pragma: no cover + pass + self._reset() + self.logger.info('Exiting read loop task') + + def _read_loop_websocket(self): + """Read packets from the Engine.IO WebSocket connection.""" + while self.state == 'connected': + p = None + try: + p = self.ws.recv() + except websocket.WebSocketConnectionClosedException: + self.logger.warning( + 'WebSocket connection was closed, aborting') + self.queue.put(None) + break + except Exception as e: + self.logger.info( + 'Unexpected error "%s", aborting', str(e)) + self.queue.put(None) + break + if isinstance(p, six.text_type): # pragma: no cover + p = p.encode('utf-8') + pkt = packet.Packet(encoded_packet=p) + self._receive_packet(pkt) + + self.logger.info('Waiting for write loop task to end') + self.write_loop_task.join() + self.logger.info('Waiting for ping loop task to end') + if self.ping_loop_event: # pragma: no cover + self.ping_loop_event.set() + self.ping_loop_task.join() + if self.state == 'connected': + self._trigger_event('disconnect', run_async=False) + try: + connected_clients.remove(self) + except ValueError: # pragma: no cover + pass + self._reset() + self.logger.info('Exiting read loop task') + + def _write_loop(self): + """This background task sends packages to the server 
as they are + pushed to the send queue. + """ + while self.state == 'connected': + # to simplify the timeout handling, use the maximum of the + # ping interval and ping timeout as timeout, with an extra 5 + # seconds grace period + timeout = max(self.ping_interval, self.ping_timeout) + 5 + packets = None + try: + packets = [self.queue.get(timeout=timeout)] + except self.queue.Empty: + self.logger.error('packet queue is empty, aborting') + break + if packets == [None]: + self.queue.task_done() + packets = [] + else: + while True: + try: + packets.append(self.queue.get(block=False)) + except self.queue.Empty: + break + if packets[-1] is None: + packets = packets[:-1] + self.queue.task_done() + break + if not packets: + # empty packet list returned -> connection closed + break + if self.current_transport == 'polling': + p = payload.Payload(packets=packets) + r = self._send_request( + 'POST', self.base_url, body=p.encode(), + headers={'Content-Type': 'application/octet-stream'}, + timeout=self.request_timeout) + for pkt in packets: + self.queue.task_done() + if r is None: + self.logger.warning( + 'Connection refused by the server, aborting') + break + if r.status_code < 200 or r.status_code >= 300: + self.logger.warning('Unexpected status code %s in server ' + 'response, aborting', r.status_code) + self._reset() + break + else: + # websocket + try: + for pkt in packets: + encoded_packet = pkt.encode(always_bytes=False) + if pkt.binary: + self.ws.send_binary(encoded_packet) + else: + self.ws.send(encoded_packet) + self.queue.task_done() + except websocket.WebSocketConnectionClosedException: + self.logger.warning( + 'WebSocket connection was closed, aborting') + break + self.logger.info('Exiting write loop task') diff --git a/libs/engineio/exceptions.py b/libs/engineio/exceptions.py new file mode 100644 index 000000000..fb0b3e057 --- /dev/null +++ b/libs/engineio/exceptions.py @@ -0,0 +1,22 @@ +class EngineIOError(Exception): + pass + + +class 
ContentTooLongError(EngineIOError): + pass + + +class UnknownPacketError(EngineIOError): + pass + + +class QueueEmpty(EngineIOError): + pass + + +class SocketIsClosedError(EngineIOError): + pass + + +class ConnectionError(EngineIOError): + pass diff --git a/libs/engineio/middleware.py b/libs/engineio/middleware.py new file mode 100644 index 000000000..d0bdcc747 --- /dev/null +++ b/libs/engineio/middleware.py @@ -0,0 +1,87 @@ +import os +from engineio.static_files import get_static_file + + +class WSGIApp(object): + """WSGI application middleware for Engine.IO. + + This middleware dispatches traffic to an Engine.IO application. It can + also serve a list of static files to the client, or forward unrelated + HTTP traffic to another WSGI application. + + :param engineio_app: The Engine.IO server. Must be an instance of the + ``engineio.Server`` class. + :param wsgi_app: The WSGI app that receives all other traffic. + :param static_files: A dictionary with static file mapping rules. See the + documentation for details on this argument. + :param engineio_path: The endpoint where the Engine.IO application should + be installed. The default value is appropriate for + most cases. + + Example usage:: + + import engineio + import eventlet + + eio = engineio.Server() + app = engineio.WSGIApp(eio, static_files={ + '/': {'content_type': 'text/html', 'filename': 'index.html'}, + '/index.html': {'content_type': 'text/html', + 'filename': 'index.html'}, + }) + eventlet.wsgi.server(eventlet.listen(('', 8000)), app) + """ + def __init__(self, engineio_app, wsgi_app=None, static_files=None, + engineio_path='engine.io'): + self.engineio_app = engineio_app + self.wsgi_app = wsgi_app + self.engineio_path = engineio_path.strip('/') + self.static_files = static_files or {} + + def __call__(self, environ, start_response): + if 'gunicorn.socket' in environ: + # gunicorn saves the socket under environ['gunicorn.socket'], while + # eventlet saves it under environ['eventlet.input']. 
Eventlet also + # stores the socket inside a wrapper class, while gunicon writes it + # directly into the environment. To give eventlet's WebSocket + # module access to this socket when running under gunicorn, here we + # copy the socket to the eventlet format. + class Input(object): + def __init__(self, socket): + self.socket = socket + + def get_socket(self): + return self.socket + + environ['eventlet.input'] = Input(environ['gunicorn.socket']) + path = environ['PATH_INFO'] + if path is not None and \ + path.startswith('/{0}/'.format(self.engineio_path)): + return self.engineio_app.handle_request(environ, start_response) + else: + static_file = get_static_file(path, self.static_files) \ + if self.static_files else None + if static_file: + if os.path.exists(static_file['filename']): + start_response( + '200 OK', + [('Content-Type', static_file['content_type'])]) + with open(static_file['filename'], 'rb') as f: + return [f.read()] + else: + return self.not_found(start_response) + elif self.wsgi_app is not None: + return self.wsgi_app(environ, start_response) + return self.not_found(start_response) + + def not_found(self, start_response): + start_response("404 Not Found", [('Content-Type', 'text/plain')]) + return [b'Not Found'] + + +class Middleware(WSGIApp): + """This class has been renamed to ``WSGIApp`` and is now deprecated.""" + def __init__(self, engineio_app, wsgi_app=None, + engineio_path='engine.io'): + super(Middleware, self).__init__(engineio_app, wsgi_app, + engineio_path=engineio_path) diff --git a/libs/engineio/packet.py b/libs/engineio/packet.py new file mode 100644 index 000000000..a3aa6d476 --- /dev/null +++ b/libs/engineio/packet.py @@ -0,0 +1,92 @@ +import base64 +import json as _json + +import six + +(OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6) +packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP'] + +binary_types = (six.binary_type, bytearray) + + +class Packet(object): + """Engine.IO 
packet.""" + + json = _json + + def __init__(self, packet_type=NOOP, data=None, binary=None, + encoded_packet=None): + self.packet_type = packet_type + self.data = data + if binary is not None: + self.binary = binary + elif isinstance(data, six.text_type): + self.binary = False + elif isinstance(data, binary_types): + self.binary = True + else: + self.binary = False + if encoded_packet: + self.decode(encoded_packet) + + def encode(self, b64=False, always_bytes=True): + """Encode the packet for transmission.""" + if self.binary and not b64: + encoded_packet = six.int2byte(self.packet_type) + else: + encoded_packet = six.text_type(self.packet_type) + if self.binary and b64: + encoded_packet = 'b' + encoded_packet + if self.binary: + if b64: + encoded_packet += base64.b64encode(self.data).decode('utf-8') + else: + encoded_packet += self.data + elif isinstance(self.data, six.string_types): + encoded_packet += self.data + elif isinstance(self.data, dict) or isinstance(self.data, list): + encoded_packet += self.json.dumps(self.data, + separators=(',', ':')) + elif self.data is not None: + encoded_packet += str(self.data) + if always_bytes and not isinstance(encoded_packet, binary_types): + encoded_packet = encoded_packet.encode('utf-8') + return encoded_packet + + def decode(self, encoded_packet): + """Decode a transmitted package.""" + b64 = False + if not isinstance(encoded_packet, binary_types): + encoded_packet = encoded_packet.encode('utf-8') + elif not isinstance(encoded_packet, bytes): + encoded_packet = bytes(encoded_packet) + self.packet_type = six.byte2int(encoded_packet[0:1]) + if self.packet_type == 98: # 'b' --> binary base64 encoded packet + self.binary = True + encoded_packet = encoded_packet[1:] + self.packet_type = six.byte2int(encoded_packet[0:1]) + self.packet_type -= 48 + b64 = True + elif self.packet_type >= 48: + self.packet_type -= 48 + self.binary = False + else: + self.binary = True + self.data = None + if len(encoded_packet) > 1: + if 
self.binary: + if b64: + self.data = base64.b64decode(encoded_packet[1:]) + else: + self.data = encoded_packet[1:] + else: + try: + self.data = self.json.loads( + encoded_packet[1:].decode('utf-8')) + if isinstance(self.data, int): + # do not allow integer payloads, see + # github.com/miguelgrinberg/python-engineio/issues/75 + # for background on this decision + raise ValueError + except ValueError: + self.data = encoded_packet[1:].decode('utf-8') diff --git a/libs/engineio/payload.py b/libs/engineio/payload.py new file mode 100644 index 000000000..fbf9cbd27 --- /dev/null +++ b/libs/engineio/payload.py @@ -0,0 +1,81 @@ +import six + +from . import packet + +from six.moves import urllib + + +class Payload(object): + """Engine.IO payload.""" + max_decode_packets = 16 + + def __init__(self, packets=None, encoded_payload=None): + self.packets = packets or [] + if encoded_payload is not None: + self.decode(encoded_payload) + + def encode(self, b64=False, jsonp_index=None): + """Encode the payload for transmission.""" + encoded_payload = b'' + for pkt in self.packets: + encoded_packet = pkt.encode(b64=b64) + packet_len = len(encoded_packet) + if b64: + encoded_payload += str(packet_len).encode('utf-8') + b':' + \ + encoded_packet + else: + binary_len = b'' + while packet_len != 0: + binary_len = six.int2byte(packet_len % 10) + binary_len + packet_len = int(packet_len / 10) + if not pkt.binary: + encoded_payload += b'\0' + else: + encoded_payload += b'\1' + encoded_payload += binary_len + b'\xff' + encoded_packet + if jsonp_index is not None: + encoded_payload = b'___eio[' + \ + str(jsonp_index).encode() + \ + b']("' + \ + encoded_payload.replace(b'"', b'\\"') + \ + b'");' + return encoded_payload + + def decode(self, encoded_payload): + """Decode a transmitted payload.""" + self.packets = [] + + if len(encoded_payload) == 0: + return + + # JSONP POST payload starts with 'd=' + if encoded_payload.startswith(b'd='): + encoded_payload = urllib.parse.parse_qs( + 
encoded_payload)[b'd'][0] + + i = 0 + if six.byte2int(encoded_payload[0:1]) <= 1: + # binary encoding + while i < len(encoded_payload): + if len(self.packets) >= self.max_decode_packets: + raise ValueError('Too many packets in payload') + packet_len = 0 + i += 1 + while six.byte2int(encoded_payload[i:i + 1]) != 255: + packet_len = packet_len * 10 + six.byte2int( + encoded_payload[i:i + 1]) + i += 1 + self.packets.append(packet.Packet( + encoded_packet=encoded_payload[i + 1:i + 1 + packet_len])) + i += packet_len + 1 + else: + # assume text encoding + encoded_payload = encoded_payload.decode('utf-8') + while i < len(encoded_payload): + if len(self.packets) >= self.max_decode_packets: + raise ValueError('Too many packets in payload') + j = encoded_payload.find(':', i) + packet_len = int(encoded_payload[i:j]) + pkt = encoded_payload[j + 1:j + 1 + packet_len] + self.packets.append(packet.Packet(encoded_packet=pkt)) + i = j + 1 + packet_len diff --git a/libs/engineio/server.py b/libs/engineio/server.py new file mode 100644 index 000000000..e1543c2dc --- /dev/null +++ b/libs/engineio/server.py @@ -0,0 +1,675 @@ +import gzip +import importlib +import logging +import uuid +import zlib + +import six +from six.moves import urllib + +from . import exceptions +from . import packet +from . import payload +from . import socket + +default_logger = logging.getLogger('engineio.server') + + +class Server(object): + """An Engine.IO server. + + This class implements a fully compliant Engine.IO web server with support + for websocket and long-polling transports. + + :param async_mode: The asynchronous model to use. See the Deployment + section in the documentation for a description of the + available options. Valid async modes are "threading", + "eventlet", "gevent" and "gevent_uwsgi". If this + argument is not given, "eventlet" is tried first, then + "gevent_uwsgi", then "gevent", and finally "threading". 
+ The first async mode that has all its dependencies + installed is the one that is chosen. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. The default + is 60 seconds. + :param ping_interval: The interval in seconds at which the client pings + the server. The default is 25 seconds. For advanced + control, a two element tuple can be given, where + the first number is the ping interval and the second + is a grace period added by the server. The default + grace period is 5 seconds. + :param max_http_buffer_size: The maximum size of a message when using the + polling transport. The default is 100,000,000 + bytes. + :param allow_upgrades: Whether to allow transport upgrades or not. The + default is ``True``. + :param http_compression: Whether to compress packages when using the + polling transport. The default is ``True``. + :param compression_threshold: Only compress messages when their byte size + is greater than this value. The default is + 1024 bytes. + :param cookie: Name of the HTTP cookie that contains the client session + id. If set to ``None``, a cookie is not sent to the client. + The default is ``'io'``. + :param cors_allowed_origins: Origin or list of origins that are allowed to + connect to this server. Only the same origin + is allowed by default. Set this argument to + ``'*'`` to allow all origins, or to ``[]`` to + disable CORS handling. + :param cors_credentials: Whether credentials (cookies, authentication) are + allowed in requests to this server. The default + is ``True``. + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. The default is + ``False``. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. 
+ :param async_handlers: If set to ``True``, run message event handlers in + non-blocking threads. To run handlers synchronously, + set to ``False``. The default is ``True``. + :param monitor_clients: If set to ``True``, a background task will ensure + inactive clients are closed. Set to ``False`` to + disable the monitoring task (not recommended). The + default is ``True``. + :param kwargs: Reserved for future extensions, any additional parameters + given as keyword arguments will be silently ignored. + """ + compression_methods = ['gzip', 'deflate'] + event_names = ['connect', 'disconnect', 'message'] + _default_monitor_clients = True + + def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25, + max_http_buffer_size=100000000, allow_upgrades=True, + http_compression=True, compression_threshold=1024, + cookie='io', cors_allowed_origins=None, + cors_credentials=True, logger=False, json=None, + async_handlers=True, monitor_clients=None, **kwargs): + self.ping_timeout = ping_timeout + if isinstance(ping_interval, tuple): + self.ping_interval = ping_interval[0] + self.ping_interval_grace_period = ping_interval[1] + else: + self.ping_interval = ping_interval + self.ping_interval_grace_period = 5 + self.max_http_buffer_size = max_http_buffer_size + self.allow_upgrades = allow_upgrades + self.http_compression = http_compression + self.compression_threshold = compression_threshold + self.cookie = cookie + self.cors_allowed_origins = cors_allowed_origins + self.cors_credentials = cors_credentials + self.async_handlers = async_handlers + self.sockets = {} + self.handlers = {} + self.start_service_task = monitor_clients \ + if monitor_clients is not None else self._default_monitor_clients + if json is not None: + packet.Packet.json = json + if not isinstance(logger, bool): + self.logger = logger + else: + self.logger = default_logger + if not logging.root.handlers and \ + self.logger.level == logging.NOTSET: + if logger: + self.logger.setLevel(logging.INFO) + 
else: + self.logger.setLevel(logging.ERROR) + self.logger.addHandler(logging.StreamHandler()) + modes = self.async_modes() + if async_mode is not None: + modes = [async_mode] if async_mode in modes else [] + self._async = None + self.async_mode = None + for mode in modes: + try: + self._async = importlib.import_module( + 'engineio.async_drivers.' + mode)._async + asyncio_based = self._async['asyncio'] \ + if 'asyncio' in self._async else False + if asyncio_based != self.is_asyncio_based(): + continue # pragma: no cover + self.async_mode = mode + break + except ImportError: + pass + if self.async_mode is None: + raise ValueError('Invalid async_mode specified') + if self.is_asyncio_based() and \ + ('asyncio' not in self._async or not + self._async['asyncio']): # pragma: no cover + raise ValueError('The selected async_mode is not asyncio ' + 'compatible') + if not self.is_asyncio_based() and 'asyncio' in self._async and \ + self._async['asyncio']: # pragma: no cover + raise ValueError('The selected async_mode requires asyncio and ' + 'must use the AsyncServer class') + self.logger.info('Server initialized for %s.', self.async_mode) + + def is_asyncio_based(self): + return False + + def async_modes(self): + return ['eventlet', 'gevent_uwsgi', 'gevent', 'threading'] + + def on(self, event, handler=None): + """Register an event handler. + + :param event: The event name. Can be ``'connect'``, ``'message'`` or + ``'disconnect'``. + :param handler: The function that should be invoked to handle the + event. When this parameter is not given, the method + acts as a decorator for the handler function. 
+ + Example usage:: + + # as a decorator: + @eio.on('connect') + def connect_handler(sid, environ): + print('Connection request') + if environ['REMOTE_ADDR'] in blacklisted: + return False # reject + + # as a method: + def message_handler(sid, msg): + print('Received message: ', msg) + eio.send(sid, 'response') + eio.on('message', message_handler) + + The handler function receives the ``sid`` (session ID) for the + client as first argument. The ``'connect'`` event handler receives the + WSGI environment as a second argument, and can return ``False`` to + reject the connection. The ``'message'`` handler receives the message + payload as a second argument. The ``'disconnect'`` handler does not + take a second argument. + """ + if event not in self.event_names: + raise ValueError('Invalid event') + + def set_handler(handler): + self.handlers[event] = handler + return handler + + if handler is None: + return set_handler + set_handler(handler) + + def send(self, sid, data, binary=None): + """Send a message to a client. + + :param sid: The session id of the recipient client. + :param data: The data to send to the client. Data can be of type + ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` + or ``dict``, the data will be serialized as JSON. + :param binary: ``True`` to send packet as binary, ``False`` to send + as text. If not given, unicode (Python 2) and str + (Python 3) are sent as text, and str (Python 2) and + bytes (Python 3) are sent as binary. + """ + try: + socket = self._get_socket(sid) + except KeyError: + # the socket is not available + self.logger.warning('Cannot send to sid %s', sid) + return + socket.send(packet.Packet(packet.MESSAGE, data=data, binary=binary)) + + def get_session(self, sid): + """Return the user session for a client. + + :param sid: The session id of the client. + + The return value is a dictionary. 
Modifications made to this + dictionary are not guaranteed to be preserved unless + ``save_session()`` is called, or when the ``session`` context manager + is used. + """ + socket = self._get_socket(sid) + return socket.session + + def save_session(self, sid, session): + """Store the user session for a client. + + :param sid: The session id of the client. + :param session: The session dictionary. + """ + socket = self._get_socket(sid) + socket.session = session + + def session(self, sid): + """Return the user session for a client with context manager syntax. + + :param sid: The session id of the client. + + This is a context manager that returns the user session dictionary for + the client. Any changes that are made to this dictionary inside the + context manager block are saved back to the session. Example usage:: + + @eio.on('connect') + def on_connect(sid, environ): + username = authenticate_user(environ) + if not username: + return False + with eio.session(sid) as session: + session['username'] = username + + @eio.on('message') + def on_message(sid, msg): + with eio.session(sid) as session: + print('received message from ', session['username']) + """ + class _session_context_manager(object): + def __init__(self, server, sid): + self.server = server + self.sid = sid + self.session = None + + def __enter__(self): + self.session = self.server.get_session(sid) + return self.session + + def __exit__(self, *args): + self.server.save_session(sid, self.session) + + return _session_context_manager(self, sid) + + def disconnect(self, sid=None): + """Disconnect a client. + + :param sid: The session id of the client to close. If this parameter + is not given, then all clients are closed. 
+ """ + if sid is not None: + try: + socket = self._get_socket(sid) + except KeyError: # pragma: no cover + # the socket was already closed or gone + pass + else: + socket.close() + if sid in self.sockets: # pragma: no cover + del self.sockets[sid] + else: + for client in six.itervalues(self.sockets): + client.close() + self.sockets = {} + + def transport(self, sid): + """Return the name of the transport used by the client. + + The two possible values returned by this function are ``'polling'`` + and ``'websocket'``. + + :param sid: The session of the client. + """ + return 'websocket' if self._get_socket(sid).upgraded else 'polling' + + def handle_request(self, environ, start_response): + """Handle an HTTP request from the client. + + This is the entry point of the Engine.IO application, using the same + interface as a WSGI application. For the typical usage, this function + is invoked by the :class:`Middleware` instance, but it can be invoked + directly when the middleware is not used. + + :param environ: The WSGI environment. + :param start_response: The WSGI ``start_response`` function. + + This function returns the HTTP response body to deliver to the client + as a byte sequence. + """ + if self.cors_allowed_origins != []: + # Validate the origin header if present + # This is important for WebSocket more than for HTTP, since + # browsers only apply CORS controls to HTTP. 
+ origin = environ.get('HTTP_ORIGIN') + if origin: + allowed_origins = self._cors_allowed_origins(environ) + if allowed_origins is not None and origin not in \ + allowed_origins: + self.logger.info(origin + ' is not an accepted origin.') + r = self._bad_request() + start_response(r['status'], r['headers']) + return [r['response']] + + method = environ['REQUEST_METHOD'] + query = urllib.parse.parse_qs(environ.get('QUERY_STRING', '')) + + sid = query['sid'][0] if 'sid' in query else None + b64 = False + jsonp = False + jsonp_index = None + + if 'b64' in query: + if query['b64'][0] == "1" or query['b64'][0].lower() == "true": + b64 = True + if 'j' in query: + jsonp = True + try: + jsonp_index = int(query['j'][0]) + except (ValueError, KeyError, IndexError): + # Invalid JSONP index number + pass + + if jsonp and jsonp_index is None: + self.logger.warning('Invalid JSONP index number') + r = self._bad_request() + elif method == 'GET': + if sid is None: + transport = query.get('transport', ['polling'])[0] + if transport != 'polling' and transport != 'websocket': + self.logger.warning('Invalid transport %s', transport) + r = self._bad_request() + else: + r = self._handle_connect(environ, start_response, + transport, b64, jsonp_index) + else: + if sid not in self.sockets: + self.logger.warning('Invalid session %s', sid) + r = self._bad_request() + else: + socket = self._get_socket(sid) + try: + packets = socket.handle_get_request( + environ, start_response) + if isinstance(packets, list): + r = self._ok(packets, b64=b64, + jsonp_index=jsonp_index) + else: + r = packets + except exceptions.EngineIOError: + if sid in self.sockets: # pragma: no cover + self.disconnect(sid) + r = self._bad_request() + if sid in self.sockets and self.sockets[sid].closed: + del self.sockets[sid] + elif method == 'POST': + if sid is None or sid not in self.sockets: + self.logger.warning('Invalid session %s', sid) + r = self._bad_request() + else: + socket = self._get_socket(sid) + try: + 
socket.handle_post_request(environ) + r = self._ok(jsonp_index=jsonp_index) + except exceptions.EngineIOError: + if sid in self.sockets: # pragma: no cover + self.disconnect(sid) + r = self._bad_request() + except: # pragma: no cover + # for any other unexpected errors, we log the error + # and keep going + self.logger.exception('post request handler error') + r = self._ok(jsonp_index=jsonp_index) + elif method == 'OPTIONS': + r = self._ok() + else: + self.logger.warning('Method %s not supported', method) + r = self._method_not_found() + + if not isinstance(r, dict): + return r or [] + if self.http_compression and \ + len(r['response']) >= self.compression_threshold: + encodings = [e.split(';')[0].strip() for e in + environ.get('HTTP_ACCEPT_ENCODING', '').split(',')] + for encoding in encodings: + if encoding in self.compression_methods: + r['response'] = \ + getattr(self, '_' + encoding)(r['response']) + r['headers'] += [('Content-Encoding', encoding)] + break + cors_headers = self._cors_headers(environ) + start_response(r['status'], r['headers'] + cors_headers) + return [r['response']] + + def start_background_task(self, target, *args, **kwargs): + """Start a background task using the appropriate async model. + + This is a utility function that applications can use to start a + background task using the method that is compatible with the + selected async mode. + + :param target: the target function to execute. + :param args: arguments to pass to the function. + :param kwargs: keyword arguments to pass to the function. + + This function returns an object compatible with the `Thread` class in + the Python standard library. The `start()` method on this object is + already called by this function. + """ + th = self._async['thread'](target=target, args=args, kwargs=kwargs) + th.start() + return th # pragma: no cover + + def sleep(self, seconds=0): + """Sleep for the requested amount of time using the appropriate async + model. 
+ + This is a utility function that applications can use to put a task to + sleep without having to worry about using the correct call for the + selected async mode. + """ + return self._async['sleep'](seconds) + + def create_queue(self, *args, **kwargs): + """Create a queue object using the appropriate async model. + + This is a utility function that applications can use to create a queue + without having to worry about using the correct call for the selected + async mode. + """ + return self._async['queue'](*args, **kwargs) + + def get_queue_empty_exception(self): + """Return the queue empty exception for the appropriate async model. + + This is a utility function that applications can use to work with a + queue without having to worry about using the correct call for the + selected async mode. + """ + return self._async['queue_empty'] + + def create_event(self, *args, **kwargs): + """Create an event object using the appropriate async model. + + This is a utility function that applications can use to create an + event without having to worry about using the correct call for the + selected async mode. 
+ """ + return self._async['event'](*args, **kwargs) + + def _generate_id(self): + """Generate a unique session id.""" + return uuid.uuid4().hex + + def _handle_connect(self, environ, start_response, transport, b64=False, + jsonp_index=None): + """Handle a client connection request.""" + if self.start_service_task: + # start the service task to monitor connected clients + self.start_service_task = False + self.start_background_task(self._service_task) + + sid = self._generate_id() + s = socket.Socket(self, sid) + self.sockets[sid] = s + + pkt = packet.Packet( + packet.OPEN, {'sid': sid, + 'upgrades': self._upgrades(sid, transport), + 'pingTimeout': int(self.ping_timeout * 1000), + 'pingInterval': int(self.ping_interval * 1000)}) + s.send(pkt) + + ret = self._trigger_event('connect', sid, environ, run_async=False) + if ret is False: + del self.sockets[sid] + self.logger.warning('Application rejected connection') + return self._unauthorized() + + if transport == 'websocket': + ret = s.handle_get_request(environ, start_response) + if s.closed: + # websocket connection ended, so we are done + del self.sockets[sid] + return ret + else: + s.connected = True + headers = None + if self.cookie: + headers = [('Set-Cookie', self.cookie + '=' + sid)] + try: + return self._ok(s.poll(), headers=headers, b64=b64, + jsonp_index=jsonp_index) + except exceptions.QueueEmpty: + return self._bad_request() + + def _upgrades(self, sid, transport): + """Return the list of possible upgrades for a client connection.""" + if not self.allow_upgrades or self._get_socket(sid).upgraded or \ + self._async['websocket'] is None or transport == 'websocket': + return [] + return ['websocket'] + + def _trigger_event(self, event, *args, **kwargs): + """Invoke an event handler.""" + run_async = kwargs.pop('run_async', False) + if event in self.handlers: + if run_async: + return self.start_background_task(self.handlers[event], *args) + else: + try: + return self.handlers[event](*args) + except: + 
self.logger.exception(event + ' handler error') + if event == 'connect': + # if connect handler raised error we reject the + # connection + return False + + def _get_socket(self, sid): + """Return the socket object for a given session.""" + try: + s = self.sockets[sid] + except KeyError: + raise KeyError('Session not found') + if s.closed: + del self.sockets[sid] + raise KeyError('Session is disconnected') + return s + + def _ok(self, packets=None, headers=None, b64=False, jsonp_index=None): + """Generate a successful HTTP response.""" + if packets is not None: + if headers is None: + headers = [] + if b64: + headers += [('Content-Type', 'text/plain; charset=UTF-8')] + else: + headers += [('Content-Type', 'application/octet-stream')] + return {'status': '200 OK', + 'headers': headers, + 'response': payload.Payload(packets=packets).encode( + b64=b64, jsonp_index=jsonp_index)} + else: + return {'status': '200 OK', + 'headers': [('Content-Type', 'text/plain')], + 'response': b'OK'} + + def _bad_request(self): + """Generate a bad request HTTP error response.""" + return {'status': '400 BAD REQUEST', + 'headers': [('Content-Type', 'text/plain')], + 'response': b'Bad Request'} + + def _method_not_found(self): + """Generate a method not found HTTP error response.""" + return {'status': '405 METHOD NOT FOUND', + 'headers': [('Content-Type', 'text/plain')], + 'response': b'Method Not Found'} + + def _unauthorized(self): + """Generate a unauthorized HTTP error response.""" + return {'status': '401 UNAUTHORIZED', + 'headers': [('Content-Type', 'text/plain')], + 'response': b'Unauthorized'} + + def _cors_allowed_origins(self, environ): + default_origins = [] + if 'wsgi.url_scheme' in environ and 'HTTP_HOST' in environ: + default_origins.append('{scheme}://{host}'.format( + scheme=environ['wsgi.url_scheme'], host=environ['HTTP_HOST'])) + if 'HTTP_X_FORWARDED_HOST' in environ: + scheme = environ.get( + 'HTTP_X_FORWARDED_PROTO', + environ['wsgi.url_scheme']).split(',')[0].strip() 
+ default_origins.append('{scheme}://{host}'.format( + scheme=scheme, host=environ['HTTP_X_FORWARDED_HOST'].split( + ',')[0].strip())) + if self.cors_allowed_origins is None: + allowed_origins = default_origins + elif self.cors_allowed_origins == '*': + allowed_origins = None + elif isinstance(self.cors_allowed_origins, six.string_types): + allowed_origins = [self.cors_allowed_origins] + else: + allowed_origins = self.cors_allowed_origins + return allowed_origins + + def _cors_headers(self, environ): + """Return the cross-origin-resource-sharing headers.""" + if self.cors_allowed_origins == []: + # special case, CORS handling is completely disabled + return [] + headers = [] + allowed_origins = self._cors_allowed_origins(environ) + if 'HTTP_ORIGIN' in environ and \ + (allowed_origins is None or environ['HTTP_ORIGIN'] in + allowed_origins): + headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])] + if environ['REQUEST_METHOD'] == 'OPTIONS': + headers += [('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')] + if 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS' in environ: + headers += [('Access-Control-Allow-Headers', + environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])] + if self.cors_credentials: + headers += [('Access-Control-Allow-Credentials', 'true')] + return headers + + def _gzip(self, response): + """Apply gzip compression to a response.""" + bytesio = six.BytesIO() + with gzip.GzipFile(fileobj=bytesio, mode='w') as gz: + gz.write(response) + return bytesio.getvalue() + + def _deflate(self, response): + """Apply deflate compression to a response.""" + return zlib.compress(response) + + def _service_task(self): # pragma: no cover + """Monitor connected clients and clean up those that time out.""" + while True: + if len(self.sockets) == 0: + # nothing to do + self.sleep(self.ping_timeout) + continue + + # go through the entire client list in a ping interval cycle + sleep_interval = float(self.ping_timeout) / len(self.sockets) + + try: + # iterate over the 
current clients + for s in self.sockets.copy().values(): + if not s.closing and not s.closed: + s.check_ping_timeout() + self.sleep(sleep_interval) + except (SystemExit, KeyboardInterrupt): + self.logger.info('service task canceled') + break + except: + # an unexpected exception has occurred, log it and continue + self.logger.exception('service task exception') diff --git a/libs/engineio/socket.py b/libs/engineio/socket.py new file mode 100644 index 000000000..38593e7c7 --- /dev/null +++ b/libs/engineio/socket.py @@ -0,0 +1,248 @@ +import six +import sys +import time + +from . import exceptions +from . import packet +from . import payload + + +class Socket(object): + """An Engine.IO socket.""" + upgrade_protocols = ['websocket'] + + def __init__(self, server, sid): + self.server = server + self.sid = sid + self.queue = self.server.create_queue() + self.last_ping = time.time() + self.connected = False + self.upgrading = False + self.upgraded = False + self.packet_backlog = [] + self.closing = False + self.closed = False + self.session = {} + + def poll(self): + """Wait for packets to send to the client.""" + queue_empty = self.server.get_queue_empty_exception() + try: + packets = [self.queue.get(timeout=self.server.ping_timeout)] + self.queue.task_done() + except queue_empty: + raise exceptions.QueueEmpty() + if packets == [None]: + return [] + while True: + try: + packets.append(self.queue.get(block=False)) + self.queue.task_done() + except queue_empty: + break + return packets + + def receive(self, pkt): + """Receive packet from the client.""" + packet_name = packet.packet_names[pkt.packet_type] \ + if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN' + self.server.logger.info('%s: Received packet %s data %s', + self.sid, packet_name, + pkt.data if not isinstance(pkt.data, bytes) + else '<binary>') + if pkt.packet_type == packet.PING: + self.last_ping = time.time() + self.send(packet.Packet(packet.PONG, pkt.data)) + elif pkt.packet_type == packet.MESSAGE: 
+ self.server._trigger_event('message', self.sid, pkt.data, + run_async=self.server.async_handlers) + elif pkt.packet_type == packet.UPGRADE: + self.send(packet.Packet(packet.NOOP)) + elif pkt.packet_type == packet.CLOSE: + self.close(wait=False, abort=True) + else: + raise exceptions.UnknownPacketError() + + def check_ping_timeout(self): + """Make sure the client is still sending pings. + + This helps detect disconnections for long-polling clients. + """ + if self.closed: + raise exceptions.SocketIsClosedError() + if time.time() - self.last_ping > self.server.ping_interval + \ + self.server.ping_interval_grace_period: + self.server.logger.info('%s: Client is gone, closing socket', + self.sid) + # Passing abort=False here will cause close() to write a + # CLOSE packet. This has the effect of updating half-open sockets + # to their correct state of disconnected + self.close(wait=False, abort=False) + return False + return True + + def send(self, pkt): + """Send a packet to the client.""" + if not self.check_ping_timeout(): + return + if self.upgrading: + self.packet_backlog.append(pkt) + else: + self.queue.put(pkt) + self.server.logger.info('%s: Sending packet %s data %s', + self.sid, packet.packet_names[pkt.packet_type], + pkt.data if not isinstance(pkt.data, bytes) + else '<binary>') + + def handle_get_request(self, environ, start_response): + """Handle a long-polling GET request from the client.""" + connections = [ + s.strip() + for s in environ.get('HTTP_CONNECTION', '').lower().split(',')] + transport = environ.get('HTTP_UPGRADE', '').lower() + if 'upgrade' in connections and transport in self.upgrade_protocols: + self.server.logger.info('%s: Received request to upgrade to %s', + self.sid, transport) + return getattr(self, '_upgrade_' + transport)(environ, + start_response) + try: + packets = self.poll() + except exceptions.QueueEmpty: + exc = sys.exc_info() + self.close(wait=False) + six.reraise(*exc) + return packets + + def handle_post_request(self, 
environ): + """Handle a long-polling POST request from the client.""" + length = int(environ.get('CONTENT_LENGTH', '0')) + if length > self.server.max_http_buffer_size: + raise exceptions.ContentTooLongError() + else: + body = environ['wsgi.input'].read(length) + p = payload.Payload(encoded_payload=body) + for pkt in p.packets: + self.receive(pkt) + + def close(self, wait=True, abort=False): + """Close the socket connection.""" + if not self.closed and not self.closing: + self.closing = True + self.server._trigger_event('disconnect', self.sid, run_async=False) + if not abort: + self.send(packet.Packet(packet.CLOSE)) + self.closed = True + self.queue.put(None) + if wait: + self.queue.join() + + def _upgrade_websocket(self, environ, start_response): + """Upgrade the connection from polling to websocket.""" + if self.upgraded: + raise IOError('Socket has been upgraded already') + if self.server._async['websocket'] is None: + # the selected async mode does not support websocket + return self.server._bad_request() + ws = self.server._async['websocket'](self._websocket_handler) + return ws(environ, start_response) + + def _websocket_handler(self, ws): + """Engine.IO handler for websocket transport.""" + # try to set a socket timeout matching the configured ping interval + for attr in ['_sock', 'socket']: # pragma: no cover + if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'): + getattr(ws, attr).settimeout(self.server.ping_timeout) + + if self.connected: + # the socket was already connected, so this is an upgrade + self.upgrading = True # hold packet sends during the upgrade + + pkt = ws.wait() + decoded_pkt = packet.Packet(encoded_packet=pkt) + if decoded_pkt.packet_type != packet.PING or \ + decoded_pkt.data != 'probe': + self.server.logger.info( + '%s: Failed websocket upgrade, no PING packet', self.sid) + return [] + ws.send(packet.Packet( + packet.PONG, + data=six.text_type('probe')).encode(always_bytes=False)) + 
self.queue.put(packet.Packet(packet.NOOP)) # end poll + + pkt = ws.wait() + decoded_pkt = packet.Packet(encoded_packet=pkt) + if decoded_pkt.packet_type != packet.UPGRADE: + self.upgraded = False + self.server.logger.info( + ('%s: Failed websocket upgrade, expected UPGRADE packet, ' + 'received %s instead.'), + self.sid, pkt) + return [] + self.upgraded = True + + # flush any packets that were sent during the upgrade + for pkt in self.packet_backlog: + self.queue.put(pkt) + self.packet_backlog = [] + self.upgrading = False + else: + self.connected = True + self.upgraded = True + + # start separate writer thread + def writer(): + while True: + packets = None + try: + packets = self.poll() + except exceptions.QueueEmpty: + break + if not packets: + # empty packet list returned -> connection closed + break + try: + for pkt in packets: + ws.send(pkt.encode(always_bytes=False)) + except: + break + writer_task = self.server.start_background_task(writer) + + self.server.logger.info( + '%s: Upgrade to websocket successful', self.sid) + + while True: + p = None + try: + p = ws.wait() + except Exception as e: + # if the socket is already closed, we can assume this is a + # downstream error of that + if not self.closed: # pragma: no cover + self.server.logger.info( + '%s: Unexpected error "%s", closing connection', + self.sid, str(e)) + break + if p is None: + # connection closed by client + break + if isinstance(p, six.text_type): # pragma: no cover + p = p.encode('utf-8') + pkt = packet.Packet(encoded_packet=p) + try: + self.receive(pkt) + except exceptions.UnknownPacketError: # pragma: no cover + pass + except exceptions.SocketIsClosedError: # pragma: no cover + self.server.logger.info('Receive error -- socket is closed') + break + except: # pragma: no cover + # if we get an unexpected exception we log the error and exit + # the connection properly + self.server.logger.exception('Unknown receive error') + break + + self.queue.put(None) # unlock the writer task so that it 
can exit + writer_task.join() + self.close(wait=False, abort=True) + + return [] diff --git a/libs/engineio/static_files.py b/libs/engineio/static_files.py new file mode 100644 index 000000000..3058f6ea4 --- /dev/null +++ b/libs/engineio/static_files.py @@ -0,0 +1,55 @@ +content_types = { + 'css': 'text/css', + 'gif': 'image/gif', + 'html': 'text/html', + 'jpg': 'image/jpeg', + 'js': 'application/javascript', + 'json': 'application/json', + 'png': 'image/png', + 'txt': 'text/plain', +} + + +def get_static_file(path, static_files): + """Return the local filename and content type for the requested static + file URL. + + :param path: the path portion of the requested URL. + :param static_files: a static file configuration dictionary. + + This function returns a dictionary with two keys, "filename" and + "content_type". If the requested URL does not match any static file, the + return value is None. + """ + if path in static_files: + f = static_files[path] + else: + f = None + rest = '' + while path != '': + path, last = path.rsplit('/', 1) + rest = '/' + last + rest + if path in static_files: + f = static_files[path] + rest + break + elif path + '/' in static_files: + f = static_files[path + '/'] + rest[1:] + break + if f: + if isinstance(f, str): + f = {'filename': f} + if f['filename'].endswith('/'): + if '' in static_files: + if isinstance(static_files[''], str): + f['filename'] += static_files[''] + else: + f['filename'] += static_files['']['filename'] + if 'content_type' in static_files['']: + f['content_type'] = static_files['']['content_type'] + else: + f['filename'] += 'index.html' + if 'content_type' not in f: + ext = f['filename'].rsplit('.')[-1] + f['content_type'] = content_types.get( + ext, 'application/octet-stream') + return f diff --git a/libs/flask_socketio/__init__.py b/libs/flask_socketio/__init__.py new file mode 100644 index 000000000..e4209f1e9 --- /dev/null +++ b/libs/flask_socketio/__init__.py @@ -0,0 +1,922 @@ +from functools import wraps 
+import os +import sys + +# make sure gevent-socketio is not installed, as it conflicts with +# python-socketio +gevent_socketio_found = True +try: + from socketio import socketio_manage +except ImportError: + gevent_socketio_found = False +if gevent_socketio_found: + print('The gevent-socketio package is incompatible with this version of ' + 'the Flask-SocketIO extension. Please uninstall it, and then ' + 'install the latest version of python-socketio in its place.') + sys.exit(1) + +import flask +from flask import _request_ctx_stack, json as flask_json +from flask.sessions import SessionMixin +import socketio +from socketio.exceptions import ConnectionRefusedError +from werkzeug.debug import DebuggedApplication +from werkzeug.serving import run_with_reloader + +from .namespace import Namespace +from .test_client import SocketIOTestClient + +__version__ = '4.2.1' + + +class _SocketIOMiddleware(socketio.WSGIApp): + """This WSGI middleware simply exposes the Flask application in the WSGI + environment before executing the request. + """ + def __init__(self, socketio_app, flask_app, socketio_path='socket.io'): + self.flask_app = flask_app + super(_SocketIOMiddleware, self).__init__(socketio_app, + flask_app.wsgi_app, + socketio_path=socketio_path) + + def __call__(self, environ, start_response): + environ = environ.copy() + environ['flask.app'] = self.flask_app + return super(_SocketIOMiddleware, self).__call__(environ, + start_response) + + +class _ManagedSession(dict, SessionMixin): + """This class is used for user sessions that are managed by + Flask-SocketIO. It is simple dict, expanded with the Flask session + attributes.""" + pass + + +class SocketIO(object): + """Create a Flask-SocketIO server. + + :param app: The flask application instance. If the application instance + isn't known at the time this class is instantiated, then call + ``socketio.init_app(app)`` once the application instance is + available. 
+ :param manage_session: If set to ``True``, this extension manages the user + session for Socket.IO events. If set to ``False``, + Flask's own session management is used. When using + Flask's cookie based sessions it is recommended that + you leave this set to the default of ``True``. When + using server-side sessions, a ``False`` setting + enables sharing the user session between HTTP routes + and Socket.IO events. + :param message_queue: A connection URL for a message queue service the + server can use for multi-process communication. A + message queue is not required when using a single + server process. + :param channel: The channel name, when using a message queue. If a channel + isn't specified, a default channel will be used. If + multiple clusters of SocketIO processes need to use the + same message queue without interfering with each other, then + each cluster should use a different channel. + :param path: The path where the Socket.IO server is exposed. Defaults to + ``'socket.io'``. Leave this as is unless you know what you are + doing. + :param resource: Alias to ``path``. + :param kwargs: Socket.IO and Engine.IO server options. + + The Socket.IO server options are detailed below: + + :param client_manager: The client manager instance that will manage the + client list. When this is omitted, the client list + is stored in an in-memory structure, so the use of + multiple connected servers is not possible. In most + cases, this argument does not need to be set + explicitly. + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. The default is + ``False``. + :param binary: ``True`` to support binary payloads, ``False`` to treat all + payloads as text. On Python 2, if this is set to ``True``, + ``unicode`` values are treated as text, and ``str`` and + ``bytes`` values are treated as binary. 
This option has no + effect on Python 3, where text and binary payloads are + always automatically discovered. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. To use the same json encoder and decoder as a Flask + application, use ``flask.json``. + :param async_handlers: If set to ``True``, event handlers for a client are + executed in separate threads. To run handlers for a + client synchronously, set to ``False``. The default + is ``True``. + :param always_connect: When set to ``False``, new connections are + provisory until the connect handler returns + something other than ``False``, at which point they + are accepted. When set to ``True``, connections are + immediately accepted, and then if the connect + handler returns ``False`` a disconnect is issued. + Set to ``True`` if you need to emit events from the + connect handler and your client is confused when it + receives events before the connection acceptance. + In any other case use the default of ``False``. + + The Engine.IO server configuration supports the following settings: + + :param async_mode: The asynchronous model to use. See the Deployment + section in the documentation for a description of the + available options. Valid async modes are + ``threading``, ``eventlet``, ``gevent`` and + ``gevent_uwsgi``. If this argument is not given, + ``eventlet`` is tried first, then ``gevent_uwsgi``, + then ``gevent``, and finally ``threading``. The + first async mode that has all its dependencies installed + is then one that is chosen. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. The default is + 60 seconds. + :param ping_interval: The interval in seconds at which the client pings + the server. The default is 25 seconds. 
+ :param max_http_buffer_size: The maximum size of a message when using the + polling transport. The default is 100,000,000 + bytes. + :param allow_upgrades: Whether to allow transport upgrades or not. The + default is ``True``. + :param http_compression: Whether to compress packages when using the + polling transport. The default is ``True``. + :param compression_threshold: Only compress messages when their byte size + is greater than this value. The default is + 1024 bytes. + :param cookie: Name of the HTTP cookie that contains the client session + id. If set to ``None``, a cookie is not sent to the client. + The default is ``'io'``. + :param cors_allowed_origins: Origin or list of origins that are allowed to + connect to this server. Only the same origin + is allowed by default. Set this argument to + ``'*'`` to allow all origins, or to ``[]`` to + disable CORS handling. + :param cors_credentials: Whether credentials (cookies, authentication) are + allowed in requests to this server. The default is + ``True``. + :param monitor_clients: If set to ``True``, a background task will ensure + inactive clients are closed. Set to ``False`` to + disable the monitoring task (not recommended). The + default is ``True``. + :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass + a logger object to use. To disable logging set to + ``False``. The default is ``False``. 
+ """ + + def __init__(self, app=None, **kwargs): + self.server = None + self.server_options = {} + self.wsgi_server = None + self.handlers = [] + self.namespace_handlers = [] + self.exception_handlers = {} + self.default_exception_handler = None + self.manage_session = True + # We can call init_app when: + # - we were given the Flask app instance (standard initialization) + # - we were not given the app, but we were given a message_queue + # (standard initialization for auxiliary process) + # In all other cases we collect the arguments and assume the client + # will call init_app from an app factory function. + if app is not None or 'message_queue' in kwargs: + self.init_app(app, **kwargs) + else: + self.server_options.update(kwargs) + + def init_app(self, app, **kwargs): + if app is not None: + if not hasattr(app, 'extensions'): + app.extensions = {} # pragma: no cover + app.extensions['socketio'] = self + self.server_options.update(kwargs) + self.manage_session = self.server_options.pop('manage_session', + self.manage_session) + + if 'client_manager' not in self.server_options: + url = self.server_options.pop('message_queue', None) + channel = self.server_options.pop('channel', 'flask-socketio') + write_only = app is None + if url: + if url.startswith(('redis://', "rediss://")): + queue_class = socketio.RedisManager + elif url.startswith(('kafka://')): + queue_class = socketio.KafkaManager + elif url.startswith('zmq'): + queue_class = socketio.ZmqManager + else: + queue_class = socketio.KombuManager + queue = queue_class(url, channel=channel, + write_only=write_only) + self.server_options['client_manager'] = queue + + if 'json' in self.server_options and \ + self.server_options['json'] == flask_json: + # flask's json module is tricky to use because its output + # changes when it is invoked inside or outside the app context + # so here to prevent any ambiguities we replace it with wrappers + # that ensure that the app context is always present + class 
FlaskSafeJSON(object): + @staticmethod + def dumps(*args, **kwargs): + with app.app_context(): + return flask_json.dumps(*args, **kwargs) + + @staticmethod + def loads(*args, **kwargs): + with app.app_context(): + return flask_json.loads(*args, **kwargs) + + self.server_options['json'] = FlaskSafeJSON + + resource = self.server_options.pop('path', None) or \ + self.server_options.pop('resource', None) or 'socket.io' + if resource.startswith('/'): + resource = resource[1:] + if os.environ.get('FLASK_RUN_FROM_CLI'): + if self.server_options.get('async_mode') is None: + if app is not None: + app.logger.warning( + 'Flask-SocketIO is Running under Werkzeug, WebSocket ' + 'is not available.') + self.server_options['async_mode'] = 'threading' + self.server = socketio.Server(**self.server_options) + self.async_mode = self.server.async_mode + for handler in self.handlers: + self.server.on(handler[0], handler[1], namespace=handler[2]) + for namespace_handler in self.namespace_handlers: + self.server.register_namespace(namespace_handler) + + if app is not None: + # here we attach the SocketIO middlware to the SocketIO object so it + # can be referenced later if debug middleware needs to be inserted + self.sockio_mw = _SocketIOMiddleware(self.server, app, + socketio_path=resource) + app.wsgi_app = self.sockio_mw + + def on(self, message, namespace=None): + """Decorator to register a SocketIO event handler. + + This decorator must be applied to SocketIO event handlers. Example:: + + @socketio.on('my event', namespace='/chat') + def handle_my_custom_event(json): + print('received json: ' + str(json)) + + :param message: The name of the event. This is normally a user defined + string, but a few event names are already defined. Use + ``'message'`` to define a handler that takes a string + payload, ``'json'`` to define a handler that takes a + JSON blob payload, ``'connect'`` or ``'disconnect'`` + to create handlers for connection and disconnection + events. 
+ :param namespace: The namespace on which the handler is to be + registered. Defaults to the global namespace. + """ + namespace = namespace or '/' + + def decorator(handler): + @wraps(handler) + def _handler(sid, *args): + return self._handle_event(handler, message, namespace, sid, + *args) + + if self.server: + self.server.on(message, _handler, namespace=namespace) + else: + self.handlers.append((message, _handler, namespace)) + return handler + return decorator + + def on_error(self, namespace=None): + """Decorator to define a custom error handler for SocketIO events. + + This decorator can be applied to a function that acts as an error + handler for a namespace. This handler will be invoked when a SocketIO + event handler raises an exception. The handler function must accept one + argument, which is the exception raised. Example:: + + @socketio.on_error(namespace='/chat') + def chat_error_handler(e): + print('An error has occurred: ' + str(e)) + + :param namespace: The namespace for which to register the error + handler. Defaults to the global namespace. + """ + namespace = namespace or '/' + + def decorator(exception_handler): + if not callable(exception_handler): + raise ValueError('exception_handler must be callable') + self.exception_handlers[namespace] = exception_handler + return exception_handler + return decorator + + def on_error_default(self, exception_handler): + """Decorator to define a default error handler for SocketIO events. + + This decorator can be applied to a function that acts as a default + error handler for any namespaces that do not have a specific handler. 
+ Example:: + + @socketio.on_error_default + def error_handler(e): + print('An error has occurred: ' + str(e)) + """ + if not callable(exception_handler): + raise ValueError('exception_handler must be callable') + self.default_exception_handler = exception_handler + return exception_handler + + def on_event(self, message, handler, namespace=None): + """Register a SocketIO event handler. + + ``on_event`` is the non-decorator version of ``'on'``. + + Example:: + + def on_foo_event(json): + print('received json: ' + str(json)) + + socketio.on_event('my event', on_foo_event, namespace='/chat') + + :param message: The name of the event. This is normally a user defined + string, but a few event names are already defined. Use + ``'message'`` to define a handler that takes a string + payload, ``'json'`` to define a handler that takes a + JSON blob payload, ``'connect'`` or ``'disconnect'`` + to create handlers for connection and disconnection + events. + :param handler: The function that handles the event. + :param namespace: The namespace on which the handler is to be + registered. Defaults to the global namespace. + """ + self.on(message, namespace=namespace)(handler) + + def on_namespace(self, namespace_handler): + if not isinstance(namespace_handler, Namespace): + raise ValueError('Not a namespace instance.') + namespace_handler._set_socketio(self) + if self.server: + self.server.register_namespace(namespace_handler) + else: + self.namespace_handlers.append(namespace_handler) + + def emit(self, event, *args, **kwargs): + """Emit a server generated SocketIO event. + + This function emits a SocketIO event to one or more connected clients. + A JSON blob can be attached to the event as payload. This function can + be used outside of a SocketIO event context, so it is appropriate to + use when the server is the originator of an event, outside of any + client context, such as in a regular HTTP request handler or a + background task. 
Example:: + + @app.route('/ping') + def ping(): + socketio.emit('ping event', {'data': 42}, namespace='/chat') + + :param event: The name of the user event to emit. + :param args: A dictionary with the JSON data to send as payload. + :param namespace: The namespace under which the message is to be sent. + Defaults to the global namespace. + :param room: Send the message to all the users in the given room. If + this parameter is not included, the event is sent to + all connected users. + :param skip_sid: The session id of a client to ignore when broadcasting + or addressing a room. This is typically set to the + originator of the message, so that everyone except + that client receive the message. To skip multiple sids + pass a list. + :param callback: If given, this function will be called to acknowledge + that the client has received the message. The + arguments that will be passed to the function are + those provided by the client. Callback functions can + only be used when addressing an individual client. + """ + namespace = kwargs.pop('namespace', '/') + room = kwargs.pop('room', None) + include_self = kwargs.pop('include_self', True) + skip_sid = kwargs.pop('skip_sid', None) + if not include_self and not skip_sid: + skip_sid = flask.request.sid + callback = kwargs.pop('callback', None) + if callback: + # wrap the callback so that it sets app app and request contexts + sid = flask.request.sid + original_callback = callback + + def _callback_wrapper(*args): + return self._handle_event(original_callback, None, namespace, + sid, *args) + + callback = _callback_wrapper + self.server.emit(event, *args, namespace=namespace, room=room, + skip_sid=skip_sid, callback=callback, **kwargs) + + def send(self, data, json=False, namespace=None, room=None, + callback=None, include_self=True, skip_sid=None, **kwargs): + """Send a server-generated SocketIO message. + + This function sends a simple SocketIO message to one or more connected + clients. 
The message can be a string or a JSON blob. This is a simpler + version of ``emit()``, which should be preferred. This function can be + used outside of a SocketIO event context, so it is appropriate to use + when the server is the originator of an event. + + :param data: The message to send, either a string or a JSON blob. + :param json: ``True`` if ``message`` is a JSON blob, ``False`` + otherwise. + :param namespace: The namespace under which the message is to be sent. + Defaults to the global namespace. + :param room: Send the message only to the users in the given room. If + this parameter is not included, the message is sent to + all connected users. + :param skip_sid: The session id of a client to ignore when broadcasting + or addressing a room. This is typically set to the + originator of the message, so that everyone except + that client receive the message. To skip multiple sids + pass a list. + :param callback: If given, this function will be called to acknowledge + that the client has received the message. The + arguments that will be passed to the function are + those provided by the client. Callback functions can + only be used when addressing an individual client. + """ + skip_sid = flask.request.sid if not include_self else skip_sid + if json: + self.emit('json', data, namespace=namespace, room=room, + skip_sid=skip_sid, callback=callback, **kwargs) + else: + self.emit('message', data, namespace=namespace, room=room, + skip_sid=skip_sid, callback=callback, **kwargs) + + def close_room(self, room, namespace=None): + """Close a room. + + This function removes any users that are in the given room and then + deletes the room from the server. This function can be used outside + of a SocketIO event context. + + :param room: The name of the room to close. + :param namespace: The namespace under which the room exists. Defaults + to the global namespace. 
+ """ + self.server.close_room(room, namespace) + + def run(self, app, host=None, port=None, **kwargs): + """Run the SocketIO web server. + + :param app: The Flask application instance. + :param host: The hostname or IP address for the server to listen on. + Defaults to 127.0.0.1. + :param port: The port number for the server to listen on. Defaults to + 5000. + :param debug: ``True`` to start the server in debug mode, ``False`` to + start in normal mode. + :param use_reloader: ``True`` to enable the Flask reloader, ``False`` + to disable it. + :param extra_files: A list of additional files that the Flask + reloader should watch. Defaults to ``None`` + :param log_output: If ``True``, the server logs all incomming + connections. If ``False`` logging is disabled. + Defaults to ``True`` in debug mode, ``False`` + in normal mode. Unused when the threading async + mode is used. + :param kwargs: Additional web server options. The web server options + are specific to the server used in each of the supported + async modes. Note that options provided here will + not be seen when using an external web server such + as gunicorn, since this method is not called in that + case. 
+ """ + if host is None: + host = '127.0.0.1' + if port is None: + server_name = app.config['SERVER_NAME'] + if server_name and ':' in server_name: + port = int(server_name.rsplit(':', 1)[1]) + else: + port = 5000 + + debug = kwargs.pop('debug', app.debug) + log_output = kwargs.pop('log_output', debug) + use_reloader = kwargs.pop('use_reloader', debug) + extra_files = kwargs.pop('extra_files', None) + + app.debug = debug + if app.debug and self.server.eio.async_mode != 'threading': + # put the debug middleware between the SocketIO middleware + # and the Flask application instance + # + # mw1 mw2 mw3 Flask app + # o ---- o ---- o ---- o + # / + # o Flask-SocketIO + # \ middleware + # o + # Flask-SocketIO WebSocket handler + # + # BECOMES + # + # dbg-mw mw1 mw2 mw3 Flask app + # o ---- o ---- o ---- o ---- o + # / + # o Flask-SocketIO + # \ middleware + # o + # Flask-SocketIO WebSocket handler + # + self.sockio_mw.wsgi_app = DebuggedApplication(self.sockio_mw.wsgi_app, + evalex=True) + + if self.server.eio.async_mode == 'threading': + from werkzeug._internal import _log + _log('warning', 'WebSocket transport not available. 
Install ' + 'eventlet or gevent and gevent-websocket for ' + 'improved performance.') + app.run(host=host, port=port, threaded=True, + use_reloader=use_reloader, **kwargs) + elif self.server.eio.async_mode == 'eventlet': + def run_server(): + import eventlet + import eventlet.wsgi + import eventlet.green + addresses = eventlet.green.socket.getaddrinfo(host, port) + if not addresses: + raise RuntimeError('Could not resolve host to a valid address') + eventlet_socket = eventlet.listen(addresses[0][4], addresses[0][0]) + + # If provided an SSL argument, use an SSL socket + ssl_args = ['keyfile', 'certfile', 'server_side', 'cert_reqs', + 'ssl_version', 'ca_certs', + 'do_handshake_on_connect', 'suppress_ragged_eofs', + 'ciphers'] + ssl_params = {k: kwargs[k] for k in kwargs if k in ssl_args} + if len(ssl_params) > 0: + for k in ssl_params: + kwargs.pop(k) + ssl_params['server_side'] = True # Listening requires true + eventlet_socket = eventlet.wrap_ssl(eventlet_socket, + **ssl_params) + + eventlet.wsgi.server(eventlet_socket, app, + log_output=log_output, **kwargs) + + if use_reloader: + run_with_reloader(run_server, extra_files=extra_files) + else: + run_server() + elif self.server.eio.async_mode == 'gevent': + from gevent import pywsgi + try: + from geventwebsocket.handler import WebSocketHandler + websocket = True + except ImportError: + websocket = False + + log = 'default' + if not log_output: + log = None + if websocket: + self.wsgi_server = pywsgi.WSGIServer( + (host, port), app, handler_class=WebSocketHandler, + log=log, **kwargs) + else: + self.wsgi_server = pywsgi.WSGIServer((host, port), app, + log=log, **kwargs) + + if use_reloader: + # monkey patching is required by the reloader + from gevent import monkey + monkey.patch_thread() + monkey.patch_time() + + def run_server(): + self.wsgi_server.serve_forever() + + run_with_reloader(run_server, extra_files=extra_files) + else: + self.wsgi_server.serve_forever() + + def stop(self): + """Stop a running SocketIO 
web server. + + This method must be called from a HTTP or SocketIO handler function. + """ + if self.server.eio.async_mode == 'threading': + func = flask.request.environ.get('werkzeug.server.shutdown') + if func: + func() + else: + raise RuntimeError('Cannot stop unknown web server') + elif self.server.eio.async_mode == 'eventlet': + raise SystemExit + elif self.server.eio.async_mode == 'gevent': + self.wsgi_server.stop() + + def start_background_task(self, target, *args, **kwargs): + """Start a background task using the appropriate async model. + + This is a utility function that applications can use to start a + background task using the method that is compatible with the + selected async mode. + + :param target: the target function to execute. + :param args: arguments to pass to the function. + :param kwargs: keyword arguments to pass to the function. + + This function returns an object compatible with the `Thread` class in + the Python standard library. The `start()` method on this object is + already called by this function. + """ + return self.server.start_background_task(target, *args, **kwargs) + + def sleep(self, seconds=0): + """Sleep for the requested amount of time using the appropriate async + model. + + This is a utility function that applications can use to put a task to + sleep without having to worry about using the correct call for the + selected async mode. + """ + return self.server.sleep(seconds) + + def test_client(self, app, namespace=None, query_string=None, + headers=None, flask_test_client=None): + """The Socket.IO test client is useful for testing a Flask-SocketIO + server. It works in a similar way to the Flask Test Client, but + adapted to the Socket.IO server. + + :param app: The Flask application instance. + :param namespace: The namespace for the client. If not provided, the + client connects to the server on the global + namespace. + :param query_string: A string with custom query string arguments. 
+ :param headers: A dictionary with custom HTTP headers. + :param flask_test_client: The instance of the Flask test client + currently in use. Passing the Flask test + client is optional, but is necessary if you + want the Flask user session and any other + cookies set in HTTP routes accessible from + Socket.IO events. + """ + return SocketIOTestClient(app, self, namespace=namespace, + query_string=query_string, headers=headers, + flask_test_client=flask_test_client) + + def _handle_event(self, handler, message, namespace, sid, *args): + if sid not in self.server.environ: + # we don't have record of this client, ignore this event + return '', 400 + app = self.server.environ[sid]['flask.app'] + with app.request_context(self.server.environ[sid]): + if self.manage_session: + # manage a separate session for this client's Socket.IO events + # created as a copy of the regular user session + if 'saved_session' not in self.server.environ[sid]: + self.server.environ[sid]['saved_session'] = \ + _ManagedSession(flask.session) + session_obj = self.server.environ[sid]['saved_session'] + else: + # let Flask handle the user session + # for cookie based sessions, this effectively freezes the + # session to its state at connection time + # for server-side sessions, this allows HTTP and Socket.IO to + # share the session, with both having read/write access to it + session_obj = flask.session._get_current_object() + _request_ctx_stack.top.session = session_obj + flask.request.sid = sid + flask.request.namespace = namespace + flask.request.event = {'message': message, 'args': args} + try: + if message == 'connect': + ret = handler() + else: + ret = handler(*args) + except: + err_handler = self.exception_handlers.get( + namespace, self.default_exception_handler) + if err_handler is None: + raise + type, value, traceback = sys.exc_info() + return err_handler(value) + if not self.manage_session: + # when Flask is managing the user session, it needs to save it + if not 
hasattr(session_obj, 'modified') or session_obj.modified: + resp = app.response_class() + app.session_interface.save_session(app, session_obj, resp) + return ret + + +def emit(event, *args, **kwargs): + """Emit a SocketIO event. + + This function emits a SocketIO event to one or more connected clients. A + JSON blob can be attached to the event as payload. This is a function that + can only be called from a SocketIO event handler, as in obtains some + information from the current client context. Example:: + + @socketio.on('my event') + def handle_my_custom_event(json): + emit('my response', {'data': 42}) + + :param event: The name of the user event to emit. + :param args: A dictionary with the JSON data to send as payload. + :param namespace: The namespace under which the message is to be sent. + Defaults to the namespace used by the originating event. + A ``'/'`` can be used to explicitly specify the global + namespace. + :param callback: Callback function to invoke with the client's + acknowledgement. + :param broadcast: ``True`` to send the message to all clients, or ``False`` + to only reply to the sender of the originating event. + :param room: Send the message to all the users in the given room. If this + argument is set, then broadcast is implied to be ``True``. + :param include_self: ``True`` to include the sender when broadcasting or + addressing a room, or ``False`` to send to everyone + but the sender. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the event is emitted to the + clients directly, without going through the queue. + This is more efficient, but only works when a + single server process is used, or when there is a + single addresee. It is recommended to always leave + this parameter with its default value of ``False``. 
+ """ + if 'namespace' in kwargs: + namespace = kwargs['namespace'] + else: + namespace = flask.request.namespace + callback = kwargs.get('callback') + broadcast = kwargs.get('broadcast') + room = kwargs.get('room') + if room is None and not broadcast: + room = flask.request.sid + include_self = kwargs.get('include_self', True) + ignore_queue = kwargs.get('ignore_queue', False) + + socketio = flask.current_app.extensions['socketio'] + return socketio.emit(event, *args, namespace=namespace, room=room, + include_self=include_self, callback=callback, + ignore_queue=ignore_queue) + + +def send(message, **kwargs): + """Send a SocketIO message. + + This function sends a simple SocketIO message to one or more connected + clients. The message can be a string or a JSON blob. This is a simpler + version of ``emit()``, which should be preferred. This is a function that + can only be called from a SocketIO event handler. + + :param message: The message to send, either a string or a JSON blob. + :param json: ``True`` if ``message`` is a JSON blob, ``False`` + otherwise. + :param namespace: The namespace under which the message is to be sent. + Defaults to the namespace used by the originating event. + An empty string can be used to use the global namespace. + :param callback: Callback function to invoke with the client's + acknowledgement. + :param broadcast: ``True`` to send the message to all connected clients, or + ``False`` to only reply to the sender of the originating + event. + :param room: Send the message to all the users in the given room. + :param include_self: ``True`` to include the sender when broadcasting or + addressing a room, or ``False`` to send to everyone + but the sender. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the event is emitted to the + clients directly, without going through the queue. 
+ This is more efficient, but only works when a + single server process is used, or when there is a + single addresee. It is recommended to always leave + this parameter with its default value of ``False``. + """ + json = kwargs.get('json', False) + if 'namespace' in kwargs: + namespace = kwargs['namespace'] + else: + namespace = flask.request.namespace + callback = kwargs.get('callback') + broadcast = kwargs.get('broadcast') + room = kwargs.get('room') + if room is None and not broadcast: + room = flask.request.sid + include_self = kwargs.get('include_self', True) + ignore_queue = kwargs.get('ignore_queue', False) + + socketio = flask.current_app.extensions['socketio'] + return socketio.send(message, json=json, namespace=namespace, room=room, + include_self=include_self, callback=callback, + ignore_queue=ignore_queue) + + +def join_room(room, sid=None, namespace=None): + """Join a room. + + This function puts the user in a room, under the current namespace. The + user and the namespace are obtained from the event context. This is a + function that can only be called from a SocketIO event handler. Example:: + + @socketio.on('join') + def on_join(data): + username = session['username'] + room = data['room'] + join_room(room) + send(username + ' has entered the room.', room=room) + + :param room: The name of the room to join. + :param sid: The session id of the client. If not provided, the client is + obtained from the request context. + :param namespace: The namespace for the room. If not provided, the + namespace is obtained from the request context. + """ + socketio = flask.current_app.extensions['socketio'] + sid = sid or flask.request.sid + namespace = namespace or flask.request.namespace + socketio.server.enter_room(sid, room, namespace=namespace) + + +def leave_room(room, sid=None, namespace=None): + """Leave a room. + + This function removes the user from a room, under the current namespace. + The user and the namespace are obtained from the event context. 
Example:: + + @socketio.on('leave') + def on_leave(data): + username = session['username'] + room = data['room'] + leave_room(room) + send(username + ' has left the room.', room=room) + + :param room: The name of the room to leave. + :param sid: The session id of the client. If not provided, the client is + obtained from the request context. + :param namespace: The namespace for the room. If not provided, the + namespace is obtained from the request context. + """ + socketio = flask.current_app.extensions['socketio'] + sid = sid or flask.request.sid + namespace = namespace or flask.request.namespace + socketio.server.leave_room(sid, room, namespace=namespace) + + +def close_room(room, namespace=None): + """Close a room. + + This function removes any users that are in the given room and then deletes + the room from the server. + + :param room: The name of the room to close. + :param namespace: The namespace for the room. If not provided, the + namespace is obtained from the request context. + """ + socketio = flask.current_app.extensions['socketio'] + namespace = namespace or flask.request.namespace + socketio.server.close_room(room, namespace=namespace) + + +def rooms(sid=None, namespace=None): + """Return a list of the rooms the client is in. + + This function returns all the rooms the client has entered, including its + own room, assigned by the Socket.IO server. + + :param sid: The session id of the client. If not provided, the client is + obtained from the request context. + :param namespace: The namespace for the room. If not provided, the + namespace is obtained from the request context. + """ + socketio = flask.current_app.extensions['socketio'] + sid = sid or flask.request.sid + namespace = namespace or flask.request.namespace + return socketio.server.rooms(sid, namespace=namespace) + + +def disconnect(sid=None, namespace=None, silent=False): + """Disconnect the client. + + This function terminates the connection with the client. 
As a result of + this call the client will receive a disconnect event. Example:: + + @socketio.on('message') + def receive_message(msg): + if is_banned(session['username']): + disconnect() + else: + # ... + + :param sid: The session id of the client. If not provided, the client is + obtained from the request context. + :param namespace: The namespace for the room. If not provided, the + namespace is obtained from the request context. + :param silent: this option is deprecated. + """ + socketio = flask.current_app.extensions['socketio'] + sid = sid or flask.request.sid + namespace = namespace or flask.request.namespace + return socketio.server.disconnect(sid, namespace=namespace) diff --git a/libs/flask_socketio/namespace.py b/libs/flask_socketio/namespace.py new file mode 100644 index 000000000..914ff3816 --- /dev/null +++ b/libs/flask_socketio/namespace.py @@ -0,0 +1,47 @@ +from socketio import Namespace as _Namespace + + +class Namespace(_Namespace): + def __init__(self, namespace=None): + super(Namespace, self).__init__(namespace) + self.socketio = None + + def _set_socketio(self, socketio): + self.socketio = socketio + + def trigger_event(self, event, *args): + """Dispatch an event to the proper handler method. + + In the most common usage, this method is not overloaded by subclasses, + as it performs the routing of events to methods. However, this + method can be overriden if special dispatching rules are needed, or if + having a single method that catches all events is desired. 
+ """ + handler_name = 'on_' + event + if not hasattr(self, handler_name): + # there is no handler for this event, so we ignore it + return + handler = getattr(self, handler_name) + return self.socketio._handle_event(handler, event, self.namespace, + *args) + + def emit(self, event, data=None, room=None, include_self=True, + namespace=None, callback=None): + """Emit a custom event to one or more connected clients.""" + return self.socketio.emit(event, data, room=room, + include_self=include_self, + namespace=namespace or self.namespace, + callback=callback) + + def send(self, data, room=None, include_self=True, namespace=None, + callback=None): + """Send a message to one or more connected clients.""" + return self.socketio.send(data, room=room, include_self=include_self, + namespace=namespace or self.namespace, + callback=callback) + + def close_room(self, room, namespace=None): + """Close a room.""" + return self.socketio.close_room(room=room, + namespace=namespace or self.namespace) + diff --git a/libs/flask_socketio/test_client.py b/libs/flask_socketio/test_client.py new file mode 100644 index 000000000..0c4592034 --- /dev/null +++ b/libs/flask_socketio/test_client.py @@ -0,0 +1,205 @@ +import uuid + +from socketio import packet +from socketio.pubsub_manager import PubSubManager +from werkzeug.test import EnvironBuilder + + +class SocketIOTestClient(object): + """ + This class is useful for testing a Flask-SocketIO server. It works in a + similar way to the Flask Test Client, but adapted to the Socket.IO server. + + :param app: The Flask application instance. + :param socketio: The application's ``SocketIO`` instance. + :param namespace: The namespace for the client. If not provided, the client + connects to the server on the global namespace. + :param query_string: A string with custom query string arguments. + :param headers: A dictionary with custom HTTP headers. + :param flask_test_client: The instance of the Flask test client + currently in use. 
Passing the Flask test + client is optional, but is necessary if you + want the Flask user session and any other + cookies set in HTTP routes accessible from + Socket.IO events. + """ + queue = {} + acks = {} + + def __init__(self, app, socketio, namespace=None, query_string=None, + headers=None, flask_test_client=None): + def _mock_send_packet(sid, pkt): + if pkt.packet_type == packet.EVENT or \ + pkt.packet_type == packet.BINARY_EVENT: + if sid not in self.queue: + self.queue[sid] = [] + if pkt.data[0] == 'message' or pkt.data[0] == 'json': + self.queue[sid].append({'name': pkt.data[0], + 'args': pkt.data[1], + 'namespace': pkt.namespace or '/'}) + else: + self.queue[sid].append({'name': pkt.data[0], + 'args': pkt.data[1:], + 'namespace': pkt.namespace or '/'}) + elif pkt.packet_type == packet.ACK or \ + pkt.packet_type == packet.BINARY_ACK: + self.acks[sid] = {'args': pkt.data, + 'namespace': pkt.namespace or '/'} + elif pkt.packet_type == packet.DISCONNECT: + self.connected[pkt.namespace or '/'] = False + + self.app = app + self.flask_test_client = flask_test_client + self.sid = uuid.uuid4().hex + self.queue[self.sid] = [] + self.acks[self.sid] = None + self.callback_counter = 0 + self.socketio = socketio + self.connected = {} + socketio.server._send_packet = _mock_send_packet + socketio.server.environ[self.sid] = {} + socketio.server.async_handlers = False # easier to test when + socketio.server.eio.async_handlers = False # events are sync + if isinstance(socketio.server.manager, PubSubManager): + raise RuntimeError('Test client cannot be used with a message ' + 'queue. Disable the queue on your test ' + 'configuration.') + socketio.server.manager.initialize() + self.connect(namespace=namespace, query_string=query_string, + headers=headers) + + def is_connected(self, namespace=None): + """Check if a namespace is connected. + + :param namespace: The namespace to check. The global namespace is + assumed if this argument is not provided. 
+ """ + return self.connected.get(namespace or '/', False) + + def connect(self, namespace=None, query_string=None, headers=None): + """Connect the client. + + :param namespace: The namespace for the client. If not provided, the + client connects to the server on the global + namespace. + :param query_string: A string with custom query string arguments. + :param headers: A dictionary with custom HTTP headers. + + Note that it is usually not necessary to explicitly call this method, + since a connection is automatically established when an instance of + this class is created. An example where it this method would be useful + is when the application accepts multiple namespace connections. + """ + url = '/socket.io' + if query_string: + if query_string[0] != '?': + query_string = '?' + query_string + url += query_string + environ = EnvironBuilder(url, headers=headers).get_environ() + environ['flask.app'] = self.app + if self.flask_test_client: + # inject cookies from Flask + self.flask_test_client.cookie_jar.inject_wsgi(environ) + self.connected['/'] = True + if self.socketio.server._handle_eio_connect( + self.sid, environ) is False: + del self.connected['/'] + if namespace is not None and namespace != '/': + self.connected[namespace] = True + pkt = packet.Packet(packet.CONNECT, namespace=namespace) + with self.app.app_context(): + if self.socketio.server._handle_eio_message( + self.sid, pkt.encode()) is False: + del self.connected[namespace] + + def disconnect(self, namespace=None): + """Disconnect the client. + + :param namespace: The namespace to disconnect. The global namespace is + assumed if this argument is not provided. 
+ """ + if not self.is_connected(namespace): + raise RuntimeError('not connected') + pkt = packet.Packet(packet.DISCONNECT, namespace=namespace) + with self.app.app_context(): + self.socketio.server._handle_eio_message(self.sid, pkt.encode()) + del self.connected[namespace or '/'] + + def emit(self, event, *args, **kwargs): + """Emit an event to the server. + + :param event: The event name. + :param *args: The event arguments. + :param callback: ``True`` if the client requests a callback, ``False`` + if not. Note that client-side callbacks are not + implemented, a callback request will just tell the + server to provide the arguments to invoke the + callback, but no callback is invoked. Instead, the + arguments that the server provided for the callback + are returned by this function. + :param namespace: The namespace of the event. The global namespace is + assumed if this argument is not provided. + """ + namespace = kwargs.pop('namespace', None) + if not self.is_connected(namespace): + raise RuntimeError('not connected') + callback = kwargs.pop('callback', False) + id = None + if callback: + self.callback_counter += 1 + id = self.callback_counter + pkt = packet.Packet(packet.EVENT, data=[event] + list(args), + namespace=namespace, id=id) + with self.app.app_context(): + encoded_pkt = pkt.encode() + if isinstance(encoded_pkt, list): + for epkt in encoded_pkt: + self.socketio.server._handle_eio_message(self.sid, epkt) + else: + self.socketio.server._handle_eio_message(self.sid, encoded_pkt) + ack = self.acks.pop(self.sid, None) + if ack is not None: + return ack['args'][0] if len(ack['args']) == 1 \ + else ack['args'] + + def send(self, data, json=False, callback=False, namespace=None): + """Send a text or JSON message to the server. + + :param data: A string, dictionary or list to send to the server. + :param json: ``True`` to send a JSON message, ``False`` to send a text + message. + :param callback: ``True`` if the client requests a callback, ``False`` + if not. 
Note that client-side callbacks are not + implemented, a callback request will just tell the + server to provide the arguments to invoke the + callback, but no callback is invoked. Instead, the + arguments that the server provided for the callback + are returned by this function. + :param namespace: The namespace of the event. The global namespace is + assumed if this argument is not provided. + """ + if json: + msg = 'json' + else: + msg = 'message' + return self.emit(msg, data, callback=callback, namespace=namespace) + + def get_received(self, namespace=None): + """Return the list of messages received from the server. + + Since this is not a real client, any time the server emits an event, + the event is simply stored. The test code can invoke this method to + obtain the list of events that were received since the last call. + + :param namespace: The namespace to get events from. The global + namespace is assumed if this argument is not + provided. + """ + if not self.is_connected(namespace): + raise RuntimeError('not connected') + namespace = namespace or '/' + r = [pkt for pkt in self.queue[self.sid] + if pkt['namespace'] == namespace] + self.queue[self.sid] = [pkt for pkt in self.queue[self.sid] + if pkt not in r] + return r diff --git a/libs/socketio/__init__.py b/libs/socketio/__init__.py new file mode 100644 index 000000000..d3ee7242b --- /dev/null +++ b/libs/socketio/__init__.py @@ -0,0 +1,38 @@ +import sys + +from .client import Client +from .base_manager import BaseManager +from .pubsub_manager import PubSubManager +from .kombu_manager import KombuManager +from .redis_manager import RedisManager +from .kafka_manager import KafkaManager +from .zmq_manager import ZmqManager +from .server import Server +from .namespace import Namespace, ClientNamespace +from .middleware import WSGIApp, Middleware +from .tornado import get_tornado_handler +if sys.version_info >= (3, 5): # pragma: no cover + from .asyncio_client import AsyncClient + from .asyncio_server 
import AsyncServer + from .asyncio_manager import AsyncManager + from .asyncio_namespace import AsyncNamespace, AsyncClientNamespace + from .asyncio_redis_manager import AsyncRedisManager + from .asyncio_aiopika_manager import AsyncAioPikaManager + from .asgi import ASGIApp +else: # pragma: no cover + AsyncClient = None + AsyncServer = None + AsyncManager = None + AsyncNamespace = None + AsyncRedisManager = None + AsyncAioPikaManager = None + +__version__ = '4.4.0' + +__all__ = ['__version__', 'Client', 'Server', 'BaseManager', 'PubSubManager', + 'KombuManager', 'RedisManager', 'ZmqManager', 'KafkaManager', + 'Namespace', 'ClientNamespace', 'WSGIApp', 'Middleware'] +if AsyncServer is not None: # pragma: no cover + __all__ += ['AsyncClient', 'AsyncServer', 'AsyncNamespace', + 'AsyncClientNamespace', 'AsyncManager', 'AsyncRedisManager', + 'ASGIApp', 'get_tornado_handler', 'AsyncAioPikaManager'] diff --git a/libs/socketio/asgi.py b/libs/socketio/asgi.py new file mode 100644 index 000000000..9bcdd03ba --- /dev/null +++ b/libs/socketio/asgi.py @@ -0,0 +1,36 @@ +import engineio + + +class ASGIApp(engineio.ASGIApp): # pragma: no cover + """ASGI application middleware for Socket.IO. + + This middleware dispatches traffic to an Socket.IO application. It can + also serve a list of static files to the client, or forward unrelated + HTTP traffic to another ASGI application. + + :param socketio_server: The Socket.IO server. Must be an instance of the + ``socketio.AsyncServer`` class. + :param static_files: A dictionary with static file mapping rules. See the + documentation for details on this argument. + :param other_asgi_app: A separate ASGI app that receives all other traffic. + :param socketio_path: The endpoint where the Socket.IO application should + be installed. The default value is appropriate for + most cases. 
+ + Example usage:: + + import socketio + import uvicorn + + sio = socketio.AsyncServer() + app = engineio.ASGIApp(sio, static_files={ + '/': 'index.html', + '/static': './public', + }) + uvicorn.run(app, host='127.0.0.1', port=5000) + """ + def __init__(self, socketio_server, other_asgi_app=None, + static_files=None, socketio_path='socket.io'): + super().__init__(socketio_server, other_asgi_app, + static_files=static_files, + engineio_path=socketio_path) diff --git a/libs/socketio/asyncio_aiopika_manager.py b/libs/socketio/asyncio_aiopika_manager.py new file mode 100644 index 000000000..b20d6afd9 --- /dev/null +++ b/libs/socketio/asyncio_aiopika_manager.py @@ -0,0 +1,105 @@ +import asyncio +import pickle + +from socketio.asyncio_pubsub_manager import AsyncPubSubManager + +try: + import aio_pika +except ImportError: + aio_pika = None + + +class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover + """Client manager that uses aio_pika for inter-process messaging under + asyncio. + + This class implements a client manager backend for event sharing across + multiple processes, using RabbitMQ + + To use a aio_pika backend, initialize the :class:`Server` instance as + follows:: + + url = 'amqp://user:password@hostname:port//' + server = socketio.Server(client_manager=socketio.AsyncAioPikaManager( + url)) + + :param url: The connection URL for the backend messaging queue. Example + connection URLs are ``'amqp://guest:guest@localhost:5672//'`` + for RabbitMQ. + :param channel: The channel name on which the server sends and receives + notifications. Must be the same in all the servers. + With this manager, the channel name is the exchange name + in rabbitmq + :param write_only: If set ot ``True``, only initialize to emit events. The + default of ``False`` initializes the class for emitting + and receiving. 
+ """ + + name = 'asyncaiopika' + + def __init__(self, url='amqp://guest:guest@localhost:5672//', + channel='socketio', write_only=False, logger=None): + if aio_pika is None: + raise RuntimeError('aio_pika package is not installed ' + '(Run "pip install aio_pika" in your ' + 'virtualenv).') + self.url = url + self.listener_connection = None + self.listener_channel = None + self.listener_queue = None + super().__init__(channel=channel, write_only=write_only, logger=logger) + + async def _connection(self): + return await aio_pika.connect_robust(self.url) + + async def _channel(self, connection): + return await connection.channel() + + async def _exchange(self, channel): + return await channel.declare_exchange(self.channel, + aio_pika.ExchangeType.FANOUT) + + async def _queue(self, channel, exchange): + queue = await channel.declare_queue(durable=False, + arguments={'x-expires': 300000}) + await queue.bind(exchange) + return queue + + async def _publish(self, data): + connection = await self._connection() + channel = await self._channel(connection) + exchange = await self._exchange(channel) + await exchange.publish( + aio_pika.Message(body=pickle.dumps(data), + delivery_mode=aio_pika.DeliveryMode.PERSISTENT), + routing_key='*' + ) + + async def _listen(self): + retry_sleep = 1 + while True: + try: + if self.listener_connection is None: + self.listener_connection = await self._connection() + self.listener_channel = await self._channel( + self.listener_connection + ) + await self.listener_channel.set_qos(prefetch_count=1) + exchange = await self._exchange(self.listener_channel) + self.listener_queue = await self._queue( + self.listener_channel, exchange + ) + + async with self.listener_queue.iterator() as queue_iter: + async for message in queue_iter: + with message.process(): + return pickle.loads(message.body) + except Exception: + self._get_logger().error('Cannot receive from rabbitmq... 
' + 'retrying in ' + '{} secs'.format(retry_sleep)) + self.listener_connection = None + await asyncio.sleep(retry_sleep) + retry_sleep *= 2 + if retry_sleep > 60: + retry_sleep = 60 diff --git a/libs/socketio/asyncio_client.py b/libs/socketio/asyncio_client.py new file mode 100644 index 000000000..2b10434ae --- /dev/null +++ b/libs/socketio/asyncio_client.py @@ -0,0 +1,475 @@ +import asyncio +import logging +import random + +import engineio +import six + +from . import client +from . import exceptions +from . import packet + +default_logger = logging.getLogger('socketio.client') + + +class AsyncClient(client.Client): + """A Socket.IO client for asyncio. + + This class implements a fully compliant Socket.IO web client with support + for websocket and long-polling transports. + + :param reconnection: ``True`` if the client should automatically attempt to + reconnect to the server after an interruption, or + ``False`` to not reconnect. The default is ``True``. + :param reconnection_attempts: How many reconnection attempts to issue + before giving up, or 0 for infinity attempts. + The default is 0. + :param reconnection_delay: How long to wait in seconds before the first + reconnection attempt. Each successive attempt + doubles this delay. + :param reconnection_delay_max: The maximum delay between reconnection + attempts. + :param randomization_factor: Randomization amount for each delay between + reconnection attempts. The default is 0.5, + which means that each delay is randomly + adjusted by +/- 50%. + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. The default is + ``False``. + :param binary: ``True`` to support binary payloads, ``False`` to treat all + payloads as text. On Python 2, if this is set to ``True``, + ``unicode`` values are treated as text, and ``str`` and + ``bytes`` values are treated as binary. 
This option has no + effect on Python 3, where text and binary payloads are + always automatically discovered. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. + + The Engine.IO configuration supports the following settings: + + :param request_timeout: A timeout in seconds for requests. The default is + 5 seconds. + :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to + skip SSL certificate verification, allowing + connections to servers with self signed certificates. + The default is ``True``. + :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass + a logger object to use. To disable logging set to + ``False``. The default is ``False``. + """ + def is_asyncio_based(self): + return True + + async def connect(self, url, headers={}, transports=None, + namespaces=None, socketio_path='socket.io'): + """Connect to a Socket.IO server. + + :param url: The URL of the Socket.IO server. It can include custom + query string parameters if required by the server. + :param headers: A dictionary with custom headers to send with the + connection request. + :param transports: The list of allowed transports. Valid transports + are ``'polling'`` and ``'websocket'``. If not + given, the polling transport is connected first, + then an upgrade to websocket is attempted. + :param namespaces: The list of custom namespaces to connect, in + addition to the default namespace. If not given, + the namespace list is obtained from the registered + event handlers. + :param socketio_path: The endpoint where the Socket.IO server is + installed. The default value is appropriate for + most cases. + + Note: this method is a coroutine. 
+ + Example usage:: + + sio = socketio.Client() + sio.connect('http://localhost:5000') + """ + self.connection_url = url + self.connection_headers = headers + self.connection_transports = transports + self.connection_namespaces = namespaces + self.socketio_path = socketio_path + + if namespaces is None: + namespaces = set(self.handlers.keys()).union( + set(self.namespace_handlers.keys())) + elif isinstance(namespaces, six.string_types): + namespaces = [namespaces] + self.connection_namespaces = namespaces + self.namespaces = [n for n in namespaces if n != '/'] + try: + await self.eio.connect(url, headers=headers, + transports=transports, + engineio_path=socketio_path) + except engineio.exceptions.ConnectionError as exc: + six.raise_from(exceptions.ConnectionError(exc.args[0]), None) + self.connected = True + + async def wait(self): + """Wait until the connection with the server ends. + + Client applications can use this function to block the main thread + during the life of the connection. + + Note: this method is a coroutine. + """ + while True: + await self.eio.wait() + await self.sleep(1) # give the reconnect task time to start up + if not self._reconnect_task: + break + await self._reconnect_task + if self.eio.state != 'connected': + break + + async def emit(self, event, data=None, namespace=None, callback=None): + """Emit a custom event to one or more connected clients. + + :param event: The event name. It can be any string. The event names + ``'connect'``, ``'message'`` and ``'disconnect'`` are + reserved and should not be used. + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. 
+ :param callback: If given, this function will be called to acknowledge + the the client has received the message. The arguments + that will be passed to the function are those provided + by the client. Callback functions can only be used + when addressing an individual client. + + Note: this method is a coroutine. + """ + namespace = namespace or '/' + if namespace != '/' and namespace not in self.namespaces: + raise exceptions.BadNamespaceError( + namespace + ' is not a connected namespace.') + self.logger.info('Emitting event "%s" [%s]', event, namespace) + if callback is not None: + id = self._generate_ack_id(namespace, callback) + else: + id = None + if six.PY2 and not self.binary: + binary = False # pragma: nocover + else: + binary = None + # tuples are expanded to multiple arguments, everything else is sent + # as a single argument + if isinstance(data, tuple): + data = list(data) + elif data is not None: + data = [data] + else: + data = [] + await self._send_packet(packet.Packet( + packet.EVENT, namespace=namespace, data=[event] + data, id=id, + binary=binary)) + + async def send(self, data, namespace=None, callback=None): + """Send a message to one or more connected clients. + + This function emits an event with the name ``'message'``. Use + :func:`emit` to issue custom event names. + + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param callback: If given, this function will be called to acknowledge + the the client has received the message. The arguments + that will be passed to the function are those provided + by the client. Callback functions can only be used + when addressing an individual client. + + Note: this method is a coroutine. 
+ """ + await self.emit('message', data=data, namespace=namespace, + callback=callback) + + async def call(self, event, data=None, namespace=None, timeout=60): + """Emit a custom event to a client and wait for the response. + + :param event: The event name. It can be any string. The event names + ``'connect'``, ``'message'`` and ``'disconnect'`` are + reserved and should not be used. + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param timeout: The waiting timeout. If the timeout is reached before + the client acknowledges the event, then a + ``TimeoutError`` exception is raised. + + Note: this method is a coroutine. + """ + callback_event = self.eio.create_event() + callback_args = [] + + def event_callback(*args): + callback_args.append(args) + callback_event.set() + + await self.emit(event, data=data, namespace=namespace, + callback=event_callback) + try: + await asyncio.wait_for(callback_event.wait(), timeout) + except asyncio.TimeoutError: + six.raise_from(exceptions.TimeoutError(), None) + return callback_args[0] if len(callback_args[0]) > 1 \ + else callback_args[0][0] if len(callback_args[0]) == 1 \ + else None + + async def disconnect(self): + """Disconnect from the server. + + Note: this method is a coroutine. 
+ """ + # here we just request the disconnection + # later in _handle_eio_disconnect we invoke the disconnect handler + for n in self.namespaces: + await self._send_packet(packet.Packet(packet.DISCONNECT, + namespace=n)) + await self._send_packet(packet.Packet( + packet.DISCONNECT, namespace='/')) + self.connected = False + await self.eio.disconnect(abort=True) + + def start_background_task(self, target, *args, **kwargs): + """Start a background task using the appropriate async model. + + This is a utility function that applications can use to start a + background task using the method that is compatible with the + selected async mode. + + :param target: the target function to execute. + :param args: arguments to pass to the function. + :param kwargs: keyword arguments to pass to the function. + + This function returns an object compatible with the `Thread` class in + the Python standard library. The `start()` method on this object is + already called by this function. + """ + return self.eio.start_background_task(target, *args, **kwargs) + + async def sleep(self, seconds=0): + """Sleep for the requested amount of time using the appropriate async + model. + + This is a utility function that applications can use to put a task to + sleep without having to worry about using the correct call for the + selected async mode. + + Note: this method is a coroutine. 
+ """ + return await self.eio.sleep(seconds) + + async def _send_packet(self, pkt): + """Send a Socket.IO packet to the server.""" + encoded_packet = pkt.encode() + if isinstance(encoded_packet, list): + binary = False + for ep in encoded_packet: + await self.eio.send(ep, binary=binary) + binary = True + else: + await self.eio.send(encoded_packet, binary=False) + + async def _handle_connect(self, namespace): + namespace = namespace or '/' + self.logger.info('Namespace {} is connected'.format(namespace)) + await self._trigger_event('connect', namespace=namespace) + if namespace == '/': + for n in self.namespaces: + await self._send_packet(packet.Packet(packet.CONNECT, + namespace=n)) + elif namespace not in self.namespaces: + self.namespaces.append(namespace) + + async def _handle_disconnect(self, namespace): + if not self.connected: + return + namespace = namespace or '/' + if namespace == '/': + for n in self.namespaces: + await self._trigger_event('disconnect', namespace=n) + self.namespaces = [] + await self._trigger_event('disconnect', namespace=namespace) + if namespace in self.namespaces: + self.namespaces.remove(namespace) + if namespace == '/': + self.connected = False + + async def _handle_event(self, namespace, id, data): + namespace = namespace or '/' + self.logger.info('Received event "%s" [%s]', data[0], namespace) + r = await self._trigger_event(data[0], namespace, *data[1:]) + if id is not None: + # send ACK packet with the response returned by the handler + # tuples are expanded as multiple arguments + if r is None: + data = [] + elif isinstance(r, tuple): + data = list(r) + else: + data = [r] + if six.PY2 and not self.binary: + binary = False # pragma: nocover + else: + binary = None + await self._send_packet(packet.Packet( + packet.ACK, namespace=namespace, id=id, data=data, + binary=binary)) + + async def _handle_ack(self, namespace, id, data): + namespace = namespace or '/' + self.logger.info('Received ack [%s]', namespace) + callback = None + 
try: + callback = self.callbacks[namespace][id] + except KeyError: + # if we get an unknown callback we just ignore it + self.logger.warning('Unknown callback received, ignoring.') + else: + del self.callbacks[namespace][id] + if callback is not None: + if asyncio.iscoroutinefunction(callback): + await callback(*data) + else: + callback(*data) + + async def _handle_error(self, namespace, data): + namespace = namespace or '/' + self.logger.info('Connection to namespace {} was rejected'.format( + namespace)) + if data is None: + data = tuple() + elif not isinstance(data, (tuple, list)): + data = (data,) + await self._trigger_event('connect_error', namespace, *data) + if namespace in self.namespaces: + self.namespaces.remove(namespace) + if namespace == '/': + self.namespaces = [] + self.connected = False + + async def _trigger_event(self, event, namespace, *args): + """Invoke an application event handler.""" + # first see if we have an explicit handler for the event + if namespace in self.handlers and event in self.handlers[namespace]: + if asyncio.iscoroutinefunction(self.handlers[namespace][event]): + try: + ret = await self.handlers[namespace][event](*args) + except asyncio.CancelledError: # pragma: no cover + ret = None + else: + ret = self.handlers[namespace][event](*args) + return ret + + # or else, forward the event to a namepsace handler if one exists + elif namespace in self.namespace_handlers: + return await self.namespace_handlers[namespace].trigger_event( + event, *args) + + async def _handle_reconnect(self): + self._reconnect_abort.clear() + client.reconnecting_clients.append(self) + attempt_count = 0 + current_delay = self.reconnection_delay + while True: + delay = current_delay + current_delay *= 2 + if delay > self.reconnection_delay_max: + delay = self.reconnection_delay_max + delay += self.randomization_factor * (2 * random.random() - 1) + self.logger.info( + 'Connection failed, new attempt in {:.02f} seconds'.format( + delay)) + try: + await 
asyncio.wait_for(self._reconnect_abort.wait(), delay) + self.logger.info('Reconnect task aborted') + break + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + attempt_count += 1 + try: + await self.connect(self.connection_url, + headers=self.connection_headers, + transports=self.connection_transports, + namespaces=self.connection_namespaces, + socketio_path=self.socketio_path) + except (exceptions.ConnectionError, ValueError): + pass + else: + self.logger.info('Reconnection successful') + self._reconnect_task = None + break + if self.reconnection_attempts and \ + attempt_count >= self.reconnection_attempts: + self.logger.info( + 'Maximum reconnection attempts reached, giving up') + break + client.reconnecting_clients.remove(self) + + def _handle_eio_connect(self): + """Handle the Engine.IO connection event.""" + self.logger.info('Engine.IO connection established') + self.sid = self.eio.sid + + async def _handle_eio_message(self, data): + """Dispatch Engine.IO messages.""" + if self._binary_packet: + pkt = self._binary_packet + if pkt.add_attachment(data): + self._binary_packet = None + if pkt.packet_type == packet.BINARY_EVENT: + await self._handle_event(pkt.namespace, pkt.id, pkt.data) + else: + await self._handle_ack(pkt.namespace, pkt.id, pkt.data) + else: + pkt = packet.Packet(encoded_packet=data) + if pkt.packet_type == packet.CONNECT: + await self._handle_connect(pkt.namespace) + elif pkt.packet_type == packet.DISCONNECT: + await self._handle_disconnect(pkt.namespace) + elif pkt.packet_type == packet.EVENT: + await self._handle_event(pkt.namespace, pkt.id, pkt.data) + elif pkt.packet_type == packet.ACK: + await self._handle_ack(pkt.namespace, pkt.id, pkt.data) + elif pkt.packet_type == packet.BINARY_EVENT or \ + pkt.packet_type == packet.BINARY_ACK: + self._binary_packet = pkt + elif pkt.packet_type == packet.ERROR: + await self._handle_error(pkt.namespace, pkt.data) + else: + raise ValueError('Unknown packet type.') + + async def 
_handle_eio_disconnect(self): + """Handle the Engine.IO disconnection event.""" + self.logger.info('Engine.IO connection dropped') + self._reconnect_abort.set() + if self.connected: + for n in self.namespaces: + await self._trigger_event('disconnect', namespace=n) + await self._trigger_event('disconnect', namespace='/') + self.namespaces = [] + self.connected = False + self.callbacks = {} + self._binary_packet = None + self.sid = None + if self.eio.state == 'connected' and self.reconnection: + self._reconnect_task = self.start_background_task( + self._handle_reconnect) + + def _engineio_client_class(self): + return engineio.AsyncClient diff --git a/libs/socketio/asyncio_manager.py b/libs/socketio/asyncio_manager.py new file mode 100644 index 000000000..f4496ec7f --- /dev/null +++ b/libs/socketio/asyncio_manager.py @@ -0,0 +1,58 @@ +import asyncio + +from .base_manager import BaseManager + + +class AsyncManager(BaseManager): + """Manage a client list for an asyncio server.""" + async def emit(self, event, data, namespace, room=None, skip_sid=None, + callback=None, **kwargs): + """Emit a message to a single client, a room, or all the clients + connected to the namespace. + + Note: this method is a coroutine. + """ + if namespace not in self.rooms or room not in self.rooms[namespace]: + return + tasks = [] + if not isinstance(skip_sid, list): + skip_sid = [skip_sid] + for sid in self.get_participants(namespace, room): + if sid not in skip_sid: + if callback is not None: + id = self._generate_ack_id(sid, namespace, callback) + else: + id = None + tasks.append(self.server._emit_internal(sid, event, data, + namespace, id)) + if tasks == []: # pragma: no cover + return + await asyncio.wait(tasks) + + async def close_room(self, room, namespace): + """Remove all participants from a room. + + Note: this method is a coroutine. + """ + return super().close_room(room, namespace) + + async def trigger_callback(self, sid, namespace, id, data): + """Invoke an application callback. 
+ + Note: this method is a coroutine. + """ + callback = None + try: + callback = self.callbacks[sid][namespace][id] + except KeyError: + # if we get an unknown callback we just ignore it + self._get_logger().warning('Unknown callback received, ignoring.') + else: + del self.callbacks[sid][namespace][id] + if callback is not None: + ret = callback(*data) + if asyncio.iscoroutine(ret): + try: + await ret + except asyncio.CancelledError: # pragma: no cover + pass diff --git a/libs/socketio/asyncio_namespace.py b/libs/socketio/asyncio_namespace.py new file mode 100644 index 000000000..12e9c0fe6 --- /dev/null +++ b/libs/socketio/asyncio_namespace.py @@ -0,0 +1,204 @@ +import asyncio + +from socketio import namespace + + +class AsyncNamespace(namespace.Namespace): + """Base class for asyncio server-side class-based namespaces. + + A class-based namespace is a class that contains all the event handlers + for a Socket.IO namespace. The event handlers are methods of the class + with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``, + ``on_message``, ``on_json``, and so on. These can be regular functions or + coroutines. + + :param namespace: The Socket.IO namespace to be used with all the event + handlers defined in this class. If this argument is + omitted, the default namespace is used. + """ + def is_asyncio_based(self): + return True + + async def trigger_event(self, event, *args): + """Dispatch an event to the proper handler method. + + In the most common usage, this method is not overloaded by subclasses, + as it performs the routing of events to methods. However, this + method can be overriden if special dispatching rules are needed, or if + having a single method that catches all events is desired. + + Note: this method is a coroutine. 
+ """ + handler_name = 'on_' + event + if hasattr(self, handler_name): + handler = getattr(self, handler_name) + if asyncio.iscoroutinefunction(handler) is True: + try: + ret = await handler(*args) + except asyncio.CancelledError: # pragma: no cover + ret = None + else: + ret = handler(*args) + return ret + + async def emit(self, event, data=None, room=None, skip_sid=None, + namespace=None, callback=None): + """Emit a custom event to one or more connected clients. + + The only difference with the :func:`socketio.Server.emit` method is + that when the ``namespace`` argument is not given the namespace + associated with the class is used. + + Note: this method is a coroutine. + """ + return await self.server.emit(event, data=data, room=room, + skip_sid=skip_sid, + namespace=namespace or self.namespace, + callback=callback) + + async def send(self, data, room=None, skip_sid=None, namespace=None, + callback=None): + """Send a message to one or more connected clients. + + The only difference with the :func:`socketio.Server.send` method is + that when the ``namespace`` argument is not given the namespace + associated with the class is used. + + Note: this method is a coroutine. + """ + return await self.server.send(data, room=room, skip_sid=skip_sid, + namespace=namespace or self.namespace, + callback=callback) + + async def close_room(self, room, namespace=None): + """Close a room. + + The only difference with the :func:`socketio.Server.close_room` method + is that when the ``namespace`` argument is not given the namespace + associated with the class is used. + + Note: this method is a coroutine. + """ + return await self.server.close_room( + room, namespace=namespace or self.namespace) + + async def get_session(self, sid, namespace=None): + """Return the user session for a client. + + The only difference with the :func:`socketio.Server.get_session` + method is that when the ``namespace`` argument is not given the + namespace associated with the class is used. 
+ + Note: this method is a coroutine. + """ + return await self.server.get_session( + sid, namespace=namespace or self.namespace) + + async def save_session(self, sid, session, namespace=None): + """Store the user session for a client. + + The only difference with the :func:`socketio.Server.save_session` + method is that when the ``namespace`` argument is not given the + namespace associated with the class is used. + + Note: this method is a coroutine. + """ + return await self.server.save_session( + sid, session, namespace=namespace or self.namespace) + + def session(self, sid, namespace=None): + """Return the user session for a client with context manager syntax. + + The only difference with the :func:`socketio.Server.session` method is + that when the ``namespace`` argument is not given the namespace + associated with the class is used. + """ + return self.server.session(sid, namespace=namespace or self.namespace) + + async def disconnect(self, sid, namespace=None): + """Disconnect a client. + + The only difference with the :func:`socketio.Server.disconnect` method + is that when the ``namespace`` argument is not given the namespace + associated with the class is used. + + Note: this method is a coroutine. + """ + return await self.server.disconnect( + sid, namespace=namespace or self.namespace) + + +class AsyncClientNamespace(namespace.ClientNamespace): + """Base class for asyncio client-side class-based namespaces. + + A class-based namespace is a class that contains all the event handlers + for a Socket.IO namespace. The event handlers are methods of the class + with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``, + ``on_message``, ``on_json``, and so on. These can be regular functions or + coroutines. + + :param namespace: The Socket.IO namespace to be used with all the event + handlers defined in this class. If this argument is + omitted, the default namespace is used. 
+ """ + def is_asyncio_based(self): + return True + + async def trigger_event(self, event, *args): + """Dispatch an event to the proper handler method. + + In the most common usage, this method is not overloaded by subclasses, + as it performs the routing of events to methods. However, this + method can be overriden if special dispatching rules are needed, or if + having a single method that catches all events is desired. + + Note: this method is a coroutine. + """ + handler_name = 'on_' + event + if hasattr(self, handler_name): + handler = getattr(self, handler_name) + if asyncio.iscoroutinefunction(handler) is True: + try: + ret = await handler(*args) + except asyncio.CancelledError: # pragma: no cover + ret = None + else: + ret = handler(*args) + return ret + + async def emit(self, event, data=None, namespace=None, callback=None): + """Emit a custom event to the server. + + The only difference with the :func:`socketio.Client.emit` method is + that when the ``namespace`` argument is not given the namespace + associated with the class is used. + + Note: this method is a coroutine. + """ + return await self.client.emit(event, data=data, + namespace=namespace or self.namespace, + callback=callback) + + async def send(self, data, namespace=None, callback=None): + """Send a message to the server. + + The only difference with the :func:`socketio.Client.send` method is + that when the ``namespace`` argument is not given the namespace + associated with the class is used. + + Note: this method is a coroutine. + """ + return await self.client.send(data, + namespace=namespace or self.namespace, + callback=callback) + + async def disconnect(self): + """Disconnect a client. + + The only difference with the :func:`socketio.Client.disconnect` method + is that when the ``namespace`` argument is not given the namespace + associated with the class is used. + + Note: this method is a coroutine. 
+ """ + return await self.client.disconnect() diff --git a/libs/socketio/asyncio_pubsub_manager.py b/libs/socketio/asyncio_pubsub_manager.py new file mode 100644 index 000000000..6fdba6d0c --- /dev/null +++ b/libs/socketio/asyncio_pubsub_manager.py @@ -0,0 +1,163 @@ +from functools import partial +import uuid + +import json +import pickle +import six + +from .asyncio_manager import AsyncManager + + +class AsyncPubSubManager(AsyncManager): + """Manage a client list attached to a pub/sub backend under asyncio. + + This is a base class that enables multiple servers to share the list of + clients, with the servers communicating events through a pub/sub backend. + The use of a pub/sub backend also allows any client connected to the + backend to emit events addressed to Socket.IO clients. + + The actual backends must be implemented by subclasses, this class only + provides a pub/sub generic framework for asyncio applications. + + :param channel: The channel name on which the server sends and receives + notifications. + """ + name = 'asyncpubsub' + + def __init__(self, channel='socketio', write_only=False, logger=None): + super().__init__() + self.channel = channel + self.write_only = write_only + self.host_id = uuid.uuid4().hex + self.logger = logger + + def initialize(self): + super().initialize() + if not self.write_only: + self.thread = self.server.start_background_task(self._thread) + self._get_logger().info(self.name + ' backend initialized.') + + async def emit(self, event, data, namespace=None, room=None, skip_sid=None, + callback=None, **kwargs): + """Emit a message to a single client, a room, or all the clients + connected to the namespace. + + This method takes care or propagating the message to all the servers + that are connected through the message queue. + + The parameters are the same as in :meth:`.Server.emit`. + + Note: this method is a coroutine. 
+ """ + if kwargs.get('ignore_queue'): + return await super().emit( + event, data, namespace=namespace, room=room, skip_sid=skip_sid, + callback=callback) + namespace = namespace or '/' + if callback is not None: + if self.server is None: + raise RuntimeError('Callbacks can only be issued from the ' + 'context of a server.') + if room is None: + raise ValueError('Cannot use callback without a room set.') + id = self._generate_ack_id(room, namespace, callback) + callback = (room, namespace, id) + else: + callback = None + await self._publish({'method': 'emit', 'event': event, 'data': data, + 'namespace': namespace, 'room': room, + 'skip_sid': skip_sid, 'callback': callback, + 'host_id': self.host_id}) + + async def close_room(self, room, namespace=None): + await self._publish({'method': 'close_room', 'room': room, + 'namespace': namespace or '/'}) + + async def _publish(self, data): + """Publish a message on the Socket.IO channel. + + This method needs to be implemented by the different subclasses that + support pub/sub backends. + """ + raise NotImplementedError('This method must be implemented in a ' + 'subclass.') # pragma: no cover + + async def _listen(self): + """Return the next message published on the Socket.IO channel, + blocking until a message is available. + + This method needs to be implemented by the different subclasses that + support pub/sub backends. 
+ """ + raise NotImplementedError('This method must be implemented in a ' + 'subclass.') # pragma: no cover + + async def _handle_emit(self, message): + # Events with callbacks are very tricky to handle across hosts + # Here in the receiving end we set up a local callback that preserves + # the callback host and id from the sender + remote_callback = message.get('callback') + remote_host_id = message.get('host_id') + if remote_callback is not None and len(remote_callback) == 3: + callback = partial(self._return_callback, remote_host_id, + *remote_callback) + else: + callback = None + await super().emit(message['event'], message['data'], + namespace=message.get('namespace'), + room=message.get('room'), + skip_sid=message.get('skip_sid'), + callback=callback) + + async def _handle_callback(self, message): + if self.host_id == message.get('host_id'): + try: + sid = message['sid'] + namespace = message['namespace'] + id = message['id'] + args = message['args'] + except KeyError: + return + await self.trigger_callback(sid, namespace, id, args) + + async def _return_callback(self, host_id, sid, namespace, callback_id, + *args): + # When an event callback is received, the callback is returned back + # the sender, which is identified by the host_id + await self._publish({'method': 'callback', 'host_id': host_id, + 'sid': sid, 'namespace': namespace, + 'id': callback_id, 'args': args}) + + async def _handle_close_room(self, message): + await super().close_room( + room=message.get('room'), namespace=message.get('namespace')) + + async def _thread(self): + while True: + try: + message = await self._listen() + except: + import traceback + traceback.print_exc() + break + data = None + if isinstance(message, dict): + data = message + else: + if isinstance(message, six.binary_type): # pragma: no cover + try: + data = pickle.loads(message) + except: + pass + if data is None: + try: + data = json.loads(message) + except: + pass + if data and 'method' in data: + if data['method'] == 
'emit': + await self._handle_emit(data) + elif data['method'] == 'callback': + await self._handle_callback(data) + elif data['method'] == 'close_room': + await self._handle_close_room(data) diff --git a/libs/socketio/asyncio_redis_manager.py b/libs/socketio/asyncio_redis_manager.py new file mode 100644 index 000000000..21499c26c --- /dev/null +++ b/libs/socketio/asyncio_redis_manager.py @@ -0,0 +1,107 @@ +import asyncio +import pickle +from urllib.parse import urlparse + +try: + import aioredis +except ImportError: + aioredis = None + +from .asyncio_pubsub_manager import AsyncPubSubManager + + +def _parse_redis_url(url): + p = urlparse(url) + if p.scheme not in {'redis', 'rediss'}: + raise ValueError('Invalid redis url') + ssl = p.scheme == 'rediss' + host = p.hostname or 'localhost' + port = p.port or 6379 + password = p.password + if p.path: + db = int(p.path[1:]) + else: + db = 0 + return host, port, password, db, ssl + + +class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover + """Redis based client manager for asyncio servers. + + This class implements a Redis backend for event sharing across multiple + processes. Only kept here as one more example of how to build a custom + backend, since the kombu backend is perfectly adequate to support a Redis + message queue. + + To use a Redis backend, initialize the :class:`Server` instance as + follows:: + + server = socketio.Server(client_manager=socketio.AsyncRedisManager( + 'redis://hostname:port/0')) + + :param url: The connection URL for the Redis server. For a default Redis + store running on the same host, use ``redis://``. To use an + SSL connection, use ``rediss://``. + :param channel: The channel name on which the server sends and receives + notifications. Must be the same in all the servers. + :param write_only: If set ot ``True``, only initialize to emit events. The + default of ``False`` initializes the class for emitting + and receiving. 
+ """ + name = 'aioredis' + + def __init__(self, url='redis://localhost:6379/0', channel='socketio', + write_only=False, logger=None): + if aioredis is None: + raise RuntimeError('Redis package is not installed ' + '(Run "pip install aioredis" in your ' + 'virtualenv).') + ( + self.host, self.port, self.password, self.db, self.ssl + ) = _parse_redis_url(url) + self.pub = None + self.sub = None + super().__init__(channel=channel, write_only=write_only, logger=logger) + + async def _publish(self, data): + retry = True + while True: + try: + if self.pub is None: + self.pub = await aioredis.create_redis( + (self.host, self.port), db=self.db, + password=self.password, ssl=self.ssl + ) + return await self.pub.publish(self.channel, + pickle.dumps(data)) + except (aioredis.RedisError, OSError): + if retry: + self._get_logger().error('Cannot publish to redis... ' + 'retrying') + self.pub = None + retry = False + else: + self._get_logger().error('Cannot publish to redis... ' + 'giving up') + break + + async def _listen(self): + retry_sleep = 1 + while True: + try: + if self.sub is None: + self.sub = await aioredis.create_redis( + (self.host, self.port), db=self.db, + password=self.password, ssl=self.ssl + ) + self.ch = (await self.sub.subscribe(self.channel))[0] + return await self.ch.get() + except (aioredis.RedisError, OSError): + self._get_logger().error('Cannot receive from redis... ' + 'retrying in ' + '{} secs'.format(retry_sleep)) + self.sub = None + await asyncio.sleep(retry_sleep) + retry_sleep *= 2 + if retry_sleep > 60: + retry_sleep = 60 diff --git a/libs/socketio/asyncio_server.py b/libs/socketio/asyncio_server.py new file mode 100644 index 000000000..251d58180 --- /dev/null +++ b/libs/socketio/asyncio_server.py @@ -0,0 +1,526 @@ +import asyncio + +import engineio +import six + +from . import asyncio_manager +from . import exceptions +from . import packet +from . import server + + +class AsyncServer(server.Server): + """A Socket.IO server for asyncio. 
+ + This class implements a fully compliant Socket.IO web server with support + for websocket and long-polling transports, compatible with the asyncio + framework on Python 3.5 or newer. + + :param client_manager: The client manager instance that will manage the + client list. When this is omitted, the client list + is stored in an in-memory structure, so the use of + multiple connected servers is not possible. + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. + :param async_handlers: If set to ``True``, event handlers are executed in + separate threads. To run handlers synchronously, + set to ``False``. The default is ``True``. + :param kwargs: Connection parameters for the underlying Engine.IO server. + + The Engine.IO configuration supports the following settings: + + :param async_mode: The asynchronous model to use. See the Deployment + section in the documentation for a description of the + available options. Valid async modes are "aiohttp". If + this argument is not given, an async mode is chosen + based on the installed packages. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. + :param ping_interval: The interval in seconds at which the client pings + the server. + :param max_http_buffer_size: The maximum size of a message when using the + polling transport. + :param allow_upgrades: Whether to allow transport upgrades or not. + :param http_compression: Whether to compress packages when using the + polling transport. + :param compression_threshold: Only compress messages when their byte size + is greater than this value. + :param cookie: Name of the HTTP cookie that contains the client session + id. 
If set to ``None``, a cookie is not sent to the client. + :param cors_allowed_origins: Origin or list of origins that are allowed to + connect to this server. Only the same origin + is allowed by default. Set this argument to + ``'*'`` to allow all origins, or to ``[]`` to + disable CORS handling. + :param cors_credentials: Whether credentials (cookies, authentication) are + allowed in requests to this server. + :param monitor_clients: If set to ``True``, a background task will ensure + inactive clients are closed. Set to ``False`` to + disable the monitoring task (not recommended). The + default is ``True``. + :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass + a logger object to use. To disable logging set to + ``False``. + """ + def __init__(self, client_manager=None, logger=False, json=None, + async_handlers=True, **kwargs): + if client_manager is None: + client_manager = asyncio_manager.AsyncManager() + super().__init__(client_manager=client_manager, logger=logger, + binary=False, json=json, + async_handlers=async_handlers, **kwargs) + + def is_asyncio_based(self): + return True + + def attach(self, app, socketio_path='socket.io'): + """Attach the Socket.IO server to an application.""" + self.eio.attach(app, socketio_path) + + async def emit(self, event, data=None, to=None, room=None, skip_sid=None, + namespace=None, callback=None, **kwargs): + """Emit a custom event to one or more connected clients. + + :param event: The event name. It can be any string. The event names + ``'connect'``, ``'message'`` and ``'disconnect'`` are + reserved and should not be used. + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. + :param to: The recipient of the message. 
This can be set to the + session ID of a client to address only that client, or to + to any custom room created by the application to address all + the clients in that room, If this argument is omitted the + event is broadcasted to all connected clients. + :param room: Alias for the ``to`` parameter. + :param skip_sid: The session ID of a client to skip when broadcasting + to a room or to all clients. This can be used to + prevent a message from being sent to the sender. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param callback: If given, this function will be called to acknowledge + the the client has received the message. The arguments + that will be passed to the function are those provided + by the client. Callback functions can only be used + when addressing an individual client. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the event is emitted to the + clients directly, without going through the queue. + This is more efficient, but only works when a + single server process is used. It is recommended + to always leave this parameter with its default + value of ``False``. + + Note: this method is a coroutine. + """ + namespace = namespace or '/' + room = to or room + self.logger.info('emitting event "%s" to %s [%s]', event, + room or 'all', namespace) + await self.manager.emit(event, data, namespace, room=room, + skip_sid=skip_sid, callback=callback, + **kwargs) + + async def send(self, data, to=None, room=None, skip_sid=None, + namespace=None, callback=None, **kwargs): + """Send a message to one or more connected clients. + + This function emits an event with the name ``'message'``. Use + :func:`emit` to issue custom event names. + + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. 
async def call(self, event, data=None, to=None, sid=None, namespace=None,
               timeout=60, **kwargs):
    """Emit a custom event to one client and wait for its acknowledgement.

    :param event: The event name. It can be any string. The event names
                  ``'connect'``, ``'message'`` and ``'disconnect'`` are
                  reserved and should not be used.
    :param data: The data to send to the client. Data can be of type
                 ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or
                 ``dict``, the data will be serialized as JSON.
    :param to: The session ID of the recipient client.
    :param sid: Alias for the ``to`` parameter; only consulted when ``to``
                is falsy.
    :param namespace: The Socket.IO namespace for the event, passed through
                      to :func:`emit` unchanged.
    :param timeout: The waiting timeout. If the timeout is reached before
                    the client acknowledges the event, then a
                    ``TimeoutError`` exception is raised.
    :param kwargs: Extra keyword arguments (for instance ``ignore_queue``)
                   forwarded verbatim to :func:`emit`.

    :returns: ``None`` if the acknowledgement carried no arguments, the
              argument itself if it carried exactly one, and the tuple of
              arguments otherwise.
    :raises RuntimeError: if the server was created with
                          ``async_handlers`` disabled.

    Note: this method is a coroutine.
    """
    if not self.async_handlers:
        raise RuntimeError(
            'Cannot use call() when async_handlers is False.')

    ack_received = self.eio.create_event()
    replies = []

    def on_ack(*args):
        # keep the acknowledgement arguments and wake up the waiter below
        replies.append(args)
        ack_received.set()

    recipient = to or sid
    await self.emit(event, data=data, room=recipient, namespace=namespace,
                    callback=on_ack, **kwargs)
    try:
        await asyncio.wait_for(ack_received.wait(), timeout)
    except asyncio.TimeoutError:
        # translate into the package's own timeout error, without chaining
        six.raise_from(exceptions.TimeoutError(), None)

    reply = replies[0]
    if len(reply) > 1:
        return reply
    if len(reply) == 1:
        return reply[0]
    return None
def session(self, sid, namespace=None):
    """Return the user session for a client with context manager syntax.

    :param sid: The session id of the client.
    :param namespace: The Socket.IO namespace, passed unchanged to
                      ``get_session()`` and ``save_session()``.

    This returns an *asynchronous* context manager: it only implements
    ``__aenter__``/``__aexit__``, so it must be used with ``async with``
    from inside a coroutine. On entry the session dictionary is loaded
    with ``get_session()``; on exit (including exit through an exception)
    it is written back with ``save_session()``. Example usage::

        @sio.on('connect')
        async def on_connect(sid, environ):
            username = authenticate_user(environ)
            if not username:
                return False
            async with sio.session(sid) as session:
                session['username'] = username

        @sio.on('message')
        async def on_message(sid, msg):
            async with sio.session(sid) as session:
                print('received message from ', session['username'])
    """
    class _session_context_manager(object):
        def __init__(self, server, sid, namespace):
            self.server = server
            self.sid = sid
            self.namespace = namespace
            self.session = None

        async def __aenter__(self):
            # use the state captured at construction time rather than the
            # enclosing function's variables
            self.session = await self.server.get_session(
                self.sid, namespace=self.namespace)
            return self.session

        async def __aexit__(self, *args):
            # exceptions are not suppressed (implicit ``None`` return); the
            # session is saved regardless of how the block was left
            await self.server.save_session(self.sid, self.session,
                                           namespace=self.namespace)

    return _session_context_manager(self, sid, namespace)
def start_background_task(self, target, *args, **kwargs):
    """Start a background task using the appropriate async model.

    This is a utility function that applications can use to start a
    background task using the method that is compatible with the
    selected async mode. The call is delegated unchanged to
    ``self.eio.start_background_task()``.

    :param target: the target function to execute. Must be a coroutine.
    :param args: arguments to pass to the function.
    :param kwargs: keyword arguments to pass to the function.

    The return value is whatever the Engine.IO server returns for the
    task (described by this method's original documentation as an
    ``asyncio.Task`` object).

    Note: unlike most methods of this class, this one is a regular
    function, not a coroutine; call it without ``await``.
    """
    return self.eio.start_background_task(target, *args, **kwargs)
async def _send_packet(self, sid, pkt):
    """Send a Socket.IO packet to a client.

    ``pkt.encode()`` yields either a single payload or a list of payloads.
    A single payload is sent as text. For a list, the first element is
    sent as text and every following element is sent with ``binary=True``.
    """
    payload = pkt.encode()
    if not isinstance(payload, list):
        await self.eio.send(sid, payload, binary=False)
        return
    for position, part in enumerate(payload):
        # only the leading element of a multi-part packet is text
        await self.eio.send(sid, part, binary=position > 0)
async def _trigger_event(self, event, namespace, *args):
    """Invoke an application event handler.

    A handler registered explicitly for ``event`` in ``namespace`` wins.
    Coroutine functions are awaited, anything else is called directly.
    When there is no explicit handler, the event is forwarded to the
    namespace handler object registered for ``namespace``, if any.

    :returns: the handler's return value, or ``None`` when nothing handled
              the event.
    """
    # first see if we have an explicit handler for the event
    if namespace in self.handlers and event in self.handlers[namespace]:
        handler = self.handlers[namespace][event]
        if not asyncio.iscoroutinefunction(handler):
            return handler(*args)
        try:
            return await handler(*args)
        except asyncio.CancelledError:  # pragma: no cover
            return None

    # or else, forward the event to a namespace handler if one exists
    if namespace in self.namespace_handlers:
        delegate = self.namespace_handlers[namespace]
        return await delegate.trigger_event(event, *args)
    return None
class BaseManager(object):
    """Manage client connections.

    This class keeps track of all the clients and the rooms they are in, to
    support the broadcasting of messages. The data used by this class is
    stored in a memory structure, making it appropriate only for single process
    services. More sophisticated storage backends can be implemented by
    subclasses.
    """
    def __init__(self):
        self.logger = None
        self.server = None
        # namespace -> room name -> {sid: True}; the room named ``None``
        # holds every client connected to the namespace
        self.rooms = {}
        # sid -> namespace -> {0: ack id counter, <ack id>: callback}
        self.callbacks = {}
        # namespace -> list of sids whose disconnect is in progress
        self.pending_disconnect = {}

    def set_server(self, server):
        """Attach the server used to deliver packets and for logging."""
        self.server = server

    def initialize(self):
        """Invoked before the first request is received. Subclasses can add
        their initialization code here.
        """
        pass

    def get_namespaces(self):
        """Return an iterable with the active namespace names."""
        return iter(self.rooms)

    def get_participants(self, namespace, room):
        """Return an iterable with the active participants in a room.

        This is a generator, so a ``KeyError`` for an unknown namespace or
        room is raised when iteration starts, not when it is called. The
        membership is snapshotted so the room may be modified while the
        caller iterates.
        """
        for sid in list(self.rooms[namespace][room]):
            yield sid

    def connect(self, sid, namespace):
        """Register a client connection to a namespace."""
        # every client joins the namespace-wide room and its personal room
        self.enter_room(sid, namespace, None)
        self.enter_room(sid, namespace, sid)

    def is_connected(self, sid, namespace):
        """Return ``True`` if the client is connected to the namespace.

        Clients that are in the process of being disconnected, and clients
        or namespaces that are not known, report ``False``.
        """
        if sid in self.pending_disconnect.get(namespace, ()):
            # the client is in the process of being disconnected
            return False
        try:
            return self.rooms[namespace][None][sid]
        except KeyError:
            return False

    def pre_disconnect(self, sid, namespace):
        """Put the client in the to-be-disconnected list.

        This allows the client data structures to be present while the
        disconnect handler is invoked, but still recognize the fact that the
        client is soon going away.
        """
        self.pending_disconnect.setdefault(namespace, []).append(sid)

    def disconnect(self, sid, namespace):
        """Register a client disconnect from a namespace.

        The client leaves every room of the namespace, and its pending
        acknowledgement callbacks and to-be-disconnected entry for that
        namespace are discarded.
        """
        if namespace not in self.rooms:
            return
        joined = [name for name, members in list(self.rooms[namespace].items())
                  if sid in members]
        for room in joined:
            self.leave_room(sid, namespace, room)
        if sid in self.callbacks and namespace in self.callbacks[sid]:
            del self.callbacks[sid][namespace]
            if not self.callbacks[sid]:
                del self.callbacks[sid]
        pending = self.pending_disconnect.get(namespace)
        if pending is not None and sid in pending:
            pending.remove(sid)
            if not pending:
                del self.pending_disconnect[namespace]

    def enter_room(self, sid, namespace, room):
        """Add a client to a room."""
        members = self.rooms.setdefault(namespace, {}).setdefault(room, {})
        members[sid] = True

    def leave_room(self, sid, namespace, room):
        """Remove a client from a room.

        Rooms and namespaces that become empty are dropped. Unknown
        namespaces, rooms or clients are ignored.
        """
        try:
            del self.rooms[namespace][room][sid]
            if not self.rooms[namespace][room]:
                del self.rooms[namespace][room]
                if not self.rooms[namespace]:
                    del self.rooms[namespace]
        except KeyError:
            pass

    def close_room(self, room, namespace):
        """Remove all participants from a room."""
        try:
            for sid in self.get_participants(namespace, room):
                self.leave_room(sid, namespace, room)
        except KeyError:
            # unknown namespace or room: nothing to close
            pass

    def get_rooms(self, sid, namespace):
        """Return the rooms a client is in.

        The namespace-wide room (named ``None``) is never included.
        """
        try:
            return [name for name, members in self.rooms[namespace].items()
                    if name is not None and sid in members and members[sid]]
        except KeyError:
            return []

    def emit(self, event, data, namespace, room=None, skip_sid=None,
             callback=None, **kwargs):
        """Emit a message to a single client, a room, or all the clients
        connected to the namespace.

        ``skip_sid`` may be a single session id or a list of them. When a
        ``callback`` is given, a separate acknowledgement id is generated
        for each recipient.
        """
        if namespace not in self.rooms or room not in self.rooms[namespace]:
            return
        if not isinstance(skip_sid, list):
            skip_sid = [skip_sid]
        for sid in self.get_participants(namespace, room):
            if sid not in skip_sid:
                if callback is not None:
                    id = self._generate_ack_id(sid, namespace, callback)
                else:
                    id = None
                self.server._emit_internal(sid, event, data, namespace, id)

    def trigger_callback(self, sid, namespace, id, data):
        """Invoke an application callback.

        Unknown ids are logged and ignored. Id ``0`` is always treated as
        unknown: that key stores the id counter created by
        ``_generate_ack_id()``, not a callback, so it must be neither
        called nor deleted.
        """
        pending = self.callbacks.get(sid, {}).get(namespace, {})
        if id == 0 or id not in pending:
            # if we get an unknown callback we just ignore it
            self._get_logger().warning('Unknown callback received, ignoring.')
            return
        callback = pending.pop(id)
        if callback is not None:
            callback(*data)

    def _generate_ack_id(self, sid, namespace, callback):
        """Generate a unique identifier for an ACK packet."""
        namespace = namespace or '/'
        per_client = self.callbacks.setdefault(sid, {})
        if namespace not in per_client:
            # key 0 is reserved for the counter; real ids start at 1
            per_client[namespace] = {0: itertools.count(1)}
        id = next(per_client[namespace][0])
        per_client[namespace][id] = callback
        return id

    def _get_logger(self):
        """Get the appropriate logger

        Prevents uninitialized servers in write-only mode from failing.
        """

        if self.logger:
            return self.logger
        elif self.server:
            return self.server.logger
        else:
            return default_logger
The default is + ``False``. + :param binary: ``True`` to support binary payloads, ``False`` to treat all + payloads as text. On Python 2, if this is set to ``True``, + ``unicode`` values are treated as text, and ``str`` and + ``bytes`` values are treated as binary. This option has no + effect on Python 3, where text and binary payloads are + always automatically discovered. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. + + The Engine.IO configuration supports the following settings: + + :param request_timeout: A timeout in seconds for requests. The default is + 5 seconds. + :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to + skip SSL certificate verification, allowing + connections to servers with self signed certificates. + The default is ``True``. + :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass + a logger object to use. To disable logging set to + ``False``. The default is ``False``. 
def on(self, event, handler=None, namespace=None):
    """Register an event handler.

    :param event: The event name. It can be any string. The event names
                  ``'connect'``, ``'message'`` and ``'disconnect'`` are
                  reserved and should not be used.
    :param handler: The function that should be invoked to handle the
                    event. When this parameter is not given, the method
                    acts as a decorator for the handler function.
    :param namespace: The Socket.IO namespace for the event. If this
                      argument is omitted the handler is associated with
                      the default namespace (``'/'``).

    :returns: ``None`` when ``handler`` is given; otherwise a decorator
              that registers the function it receives and returns it
              unchanged.

    Example usage::

        # as a decorator:
        @sio.on('connect')
        def connect_handler():
            print('Connected!')

        # as a method:
        def message_handler(msg):
            print('Received message: ', msg)
            sio.send('response')
        sio.on('message', message_handler)

    Registering a second handler for the same event and namespace
    replaces the first one.
    """
    target_namespace = namespace or '/'

    def set_handler(func):
        # create the per-namespace table on first use
        self.handlers.setdefault(target_namespace, {})[event] = func
        return func

    if handler is not None:
        set_handler(handler)
        return None
    return set_handler
+ + Example usage:: + + @sio.event + def my_event(data): + print('Received data: ', data) + + The above example is equivalent to:: + + @sio.on('my_event') + def my_event(data): + print('Received data: ', data) + + A custom namespace can be given as an argument to the decorator:: + + @sio.event(namespace='/test') + def my_event(data): + print('Received data: ', data) + """ + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # the decorator was invoked without arguments + # args[0] is the decorated function + return self.on(args[0].__name__)(args[0]) + else: + # the decorator was invoked with arguments + def set_handler(handler): + return self.on(handler.__name__, *args, **kwargs)(handler) + + return set_handler + + def register_namespace(self, namespace_handler): + """Register a namespace handler object. + + :param namespace_handler: An instance of a :class:`Namespace` + subclass that handles all the event traffic + for a namespace. + """ + if not isinstance(namespace_handler, namespace.ClientNamespace): + raise ValueError('Not a namespace instance') + if self.is_asyncio_based() != namespace_handler.is_asyncio_based(): + raise ValueError('Not a valid namespace class for this client') + namespace_handler._set_client(self) + self.namespace_handlers[namespace_handler.namespace] = \ + namespace_handler + + def connect(self, url, headers={}, transports=None, + namespaces=None, socketio_path='socket.io'): + """Connect to a Socket.IO server. + + :param url: The URL of the Socket.IO server. It can include custom + query string parameters if required by the server. + :param headers: A dictionary with custom headers to send with the + connection request. + :param transports: The list of allowed transports. Valid transports + are ``'polling'`` and ``'websocket'``. If not + given, the polling transport is connected first, + then an upgrade to websocket is attempted. 
+ :param namespaces: The list of custom namespaces to connect, in + addition to the default namespace. If not given, + the namespace list is obtained from the registered + event handlers. + :param socketio_path: The endpoint where the Socket.IO server is + installed. The default value is appropriate for + most cases. + + Example usage:: + + sio = socketio.Client() + sio.connect('http://localhost:5000') + """ + self.connection_url = url + self.connection_headers = headers + self.connection_transports = transports + self.connection_namespaces = namespaces + self.socketio_path = socketio_path + + if namespaces is None: + namespaces = set(self.handlers.keys()).union( + set(self.namespace_handlers.keys())) + elif isinstance(namespaces, six.string_types): + namespaces = [namespaces] + self.connection_namespaces = namespaces + self.namespaces = [n for n in namespaces if n != '/'] + try: + self.eio.connect(url, headers=headers, transports=transports, + engineio_path=socketio_path) + except engineio.exceptions.ConnectionError as exc: + six.raise_from(exceptions.ConnectionError(exc.args[0]), None) + self.connected = True + + def wait(self): + """Wait until the connection with the server ends. + + Client applications can use this function to block the main thread + during the life of the connection. + """ + while True: + self.eio.wait() + self.sleep(1) # give the reconnect task time to start up + if not self._reconnect_task: + break + self._reconnect_task.join() + if self.eio.state != 'connected': + break + + def emit(self, event, data=None, namespace=None, callback=None): + """Emit a custom event to one or more connected clients. + + :param event: The event name. It can be any string. The event names + ``'connect'``, ``'message'`` and ``'disconnect'`` are + reserved and should not be used. + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. 
+ :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param callback: If given, this function will be called to acknowledge + the the client has received the message. The arguments + that will be passed to the function are those provided + by the client. Callback functions can only be used + when addressing an individual client. + """ + namespace = namespace or '/' + if namespace != '/' and namespace not in self.namespaces: + raise exceptions.BadNamespaceError( + namespace + ' is not a connected namespace.') + self.logger.info('Emitting event "%s" [%s]', event, namespace) + if callback is not None: + id = self._generate_ack_id(namespace, callback) + else: + id = None + if six.PY2 and not self.binary: + binary = False # pragma: nocover + else: + binary = None + # tuples are expanded to multiple arguments, everything else is sent + # as a single argument + if isinstance(data, tuple): + data = list(data) + elif data is not None: + data = [data] + else: + data = [] + self._send_packet(packet.Packet(packet.EVENT, namespace=namespace, + data=[event] + data, id=id, + binary=binary)) + + def send(self, data, namespace=None, callback=None): + """Send a message to one or more connected clients. + + This function emits an event with the name ``'message'``. Use + :func:`emit` to issue custom event names. + + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param callback: If given, this function will be called to acknowledge + the the client has received the message. The arguments + that will be passed to the function are those provided + by the client. 
def call(self, event, data=None, namespace=None, timeout=60):
    """Emit a custom event and wait for the acknowledgement.

    :param event: The event name. It can be any string. The event names
                  ``'connect'``, ``'message'`` and ``'disconnect'`` are
                  reserved and should not be used.
    :param data: The data to send. Data can be of type ``str``,
                 ``bytes``, ``list`` or ``dict``. If a ``list`` or
                 ``dict``, the data will be serialized as JSON.
    :param namespace: The Socket.IO namespace for the event, passed
                      through to :func:`emit` unchanged.
    :param timeout: The waiting timeout. If the timeout is reached before
                    the event is acknowledged, then a ``TimeoutError``
                    exception is raised.

    :returns: ``None`` if the acknowledgement carried no arguments, the
              argument itself if it carried exactly one, and the tuple of
              arguments otherwise.
    """
    ack_received = self.eio.create_event()
    replies = []

    def on_ack(*args):
        # keep the acknowledgement arguments and release the waiter below
        replies.append(args)
        ack_received.set()

    self.emit(event, data=data, namespace=namespace,
              callback=on_ack)
    acknowledged = ack_received.wait(timeout=timeout)
    if not acknowledged:
        raise exceptions.TimeoutError()

    reply = replies[0]
    if len(reply) > 1:
        return reply
    if len(reply) == 1:
        return reply[0]
    return None
def start_background_task(self, target, *args, **kwargs):
    """Start a background task using the appropriate async model.

    Applications can use this helper to launch a background task in a way
    that is compatible with the selected async mode. The work is delegated
    unchanged to ``self.eio.start_background_task()``.

    :param target: the target function to execute.
    :param args: arguments to pass to the function.
    :param kwargs: keyword arguments to pass to the function.

    This function returns whatever the Engine.IO client returns, described
    as an object compatible with the `Thread` class in the Python standard
    library whose `start()` method has already been called.
    """
    task = self.eio.start_background_task(target, *args, **kwargs)
    return task
+ """ + return self.eio.sleep(seconds) + + def _send_packet(self, pkt): + """Send a Socket.IO packet to the server.""" + encoded_packet = pkt.encode() + if isinstance(encoded_packet, list): + binary = False + for ep in encoded_packet: + self.eio.send(ep, binary=binary) + binary = True + else: + self.eio.send(encoded_packet, binary=False) + + def _generate_ack_id(self, namespace, callback): + """Generate a unique identifier for an ACK packet.""" + namespace = namespace or '/' + if namespace not in self.callbacks: + self.callbacks[namespace] = {0: itertools.count(1)} + id = six.next(self.callbacks[namespace][0]) + self.callbacks[namespace][id] = callback + return id + + def _handle_connect(self, namespace): + namespace = namespace or '/' + self.logger.info('Namespace {} is connected'.format(namespace)) + self._trigger_event('connect', namespace=namespace) + if namespace == '/': + for n in self.namespaces: + self._send_packet(packet.Packet(packet.CONNECT, namespace=n)) + elif namespace not in self.namespaces: + self.namespaces.append(namespace) + + def _handle_disconnect(self, namespace): + if not self.connected: + return + namespace = namespace or '/' + if namespace == '/': + for n in self.namespaces: + self._trigger_event('disconnect', namespace=n) + self.namespaces = [] + self._trigger_event('disconnect', namespace=namespace) + if namespace in self.namespaces: + self.namespaces.remove(namespace) + if namespace == '/': + self.connected = False + + def _handle_event(self, namespace, id, data): + namespace = namespace or '/' + self.logger.info('Received event "%s" [%s]', data[0], namespace) + r = self._trigger_event(data[0], namespace, *data[1:]) + if id is not None: + # send ACK packet with the response returned by the handler + # tuples are expanded as multiple arguments + if r is None: + data = [] + elif isinstance(r, tuple): + data = list(r) + else: + data = [r] + if six.PY2 and not self.binary: + binary = False # pragma: nocover + else: + binary = None + 
self._send_packet(packet.Packet(packet.ACK, namespace=namespace, + id=id, data=data, binary=binary)) + + def _handle_ack(self, namespace, id, data): + namespace = namespace or '/' + self.logger.info('Received ack [%s]', namespace) + callback = None + try: + callback = self.callbacks[namespace][id] + except KeyError: + # if we get an unknown callback we just ignore it + self.logger.warning('Unknown callback received, ignoring.') + else: + del self.callbacks[namespace][id] + if callback is not None: + callback(*data) + + def _handle_error(self, namespace, data): + namespace = namespace or '/' + self.logger.info('Connection to namespace {} was rejected'.format( + namespace)) + if data is None: + data = tuple() + elif not isinstance(data, (tuple, list)): + data = (data,) + self._trigger_event('connect_error', namespace, *data) + if namespace in self.namespaces: + self.namespaces.remove(namespace) + if namespace == '/': + self.namespaces = [] + self.connected = False + + def _trigger_event(self, event, namespace, *args): + """Invoke an application event handler.""" + # first see if we have an explicit handler for the event + if namespace in self.handlers and event in self.handlers[namespace]: + return self.handlers[namespace][event](*args) + + # or else, forward the event to a namespace handler if one exists + elif namespace in self.namespace_handlers: + return self.namespace_handlers[namespace].trigger_event( + event, *args) + + def _handle_reconnect(self): + self._reconnect_abort.clear() + reconnecting_clients.append(self) + attempt_count = 0 + current_delay = self.reconnection_delay + while True: + delay = current_delay + current_delay *= 2 + if delay > self.reconnection_delay_max: + delay = self.reconnection_delay_max + delay += self.randomization_factor * (2 * random.random() - 1) + self.logger.info( + 'Connection failed, new attempt in {:.02f} seconds'.format( + delay)) + if self._reconnect_abort.wait(delay): + self.logger.info('Reconnect task aborted') + break + 
def _handle_eio_message(self, data):
    """Dispatch Engine.IO messages.

    While a binary packet is waiting for its attachments, every incoming
    message is treated as the next attachment; the packet is dispatched
    once ``add_attachment`` reports that it is complete. Any other
    message is decoded as a Socket.IO packet and routed by packet type.
    Raises ``ValueError`` for an unknown packet type.
    """
    pending = self._binary_packet
    if pending:
        if not pending.add_attachment(data):
            # more attachments are still expected
            return
        self._binary_packet = None
        if pending.packet_type == packet.BINARY_EVENT:
            self._handle_event(pending.namespace, pending.id, pending.data)
        else:
            self._handle_ack(pending.namespace, pending.id, pending.data)
        return

    pkt = packet.Packet(encoded_packet=data)
    kind = pkt.packet_type
    if kind == packet.CONNECT:
        self._handle_connect(pkt.namespace)
    elif kind == packet.DISCONNECT:
        self._handle_disconnect(pkt.namespace)
    elif kind == packet.EVENT:
        self._handle_event(pkt.namespace, pkt.id, pkt.data)
    elif kind == packet.ACK:
        self._handle_ack(pkt.namespace, pkt.id, pkt.data)
    elif kind in (packet.BINARY_EVENT, packet.BINARY_ACK):
        # park the packet until all of its attachments have arrived
        self._binary_packet = pkt
    elif kind == packet.ERROR:
        self._handle_error(pkt.namespace, pkt.data)
    else:
        raise ValueError('Unknown packet type.')
class ConnectionRefusedError(ConnectionError):
    """Connection refused exception.

    This exception can be raised from a connect handler when the connection
    is not accepted. The positional arguments given to the exception are
    stored in ``error_args`` so that they can be returned to the client
    with the error packet:

    - no arguments: ``error_args`` is ``None``;
    - a single argument that is not a list: ``error_args`` is that argument;
    - anything else: ``error_args`` is the tuple of all the arguments.
    """
    def __init__(self, *args):
        if not args:
            payload = None
        elif len(args) == 1 and not isinstance(args[0], list):
            payload = args[0]
        else:
            payload = args
        self.error_args = payload
class KafkaManager(PubSubManager):  # pragma: no cover
    """Kafka based client manager.

    This class implements a Kafka backend for event sharing across multiple
    processes.

    To use a Kafka backend, initialize the :class:`Server` instance as
    follows::

        url = 'kafka://hostname:port'
        server = socketio.Server(client_manager=socketio.KafkaManager(url))

    :param url: The connection URL for the Kafka server. For a default Kafka
                store running on the same host, use ``kafka://``.
    :param channel: The channel name (topic) on which the server sends and
                    receives notifications. Must be the same in all the
                    servers.
    :param write_only: If set to ``True``, only initialize to emit events.
                       The default of ``False`` initializes the class for
                       emitting and receiving.
    """
    name = 'kafka'

    def __init__(self, url='kafka://localhost:9092', channel='socketio',
                 write_only=False):
        if kafka is None:
            raise RuntimeError('kafka-python package is not installed '
                               '(Run "pip install kafka-python" in your '
                               'virtualenv).')

        super(KafkaManager, self).__init__(channel=channel,
                                           write_only=write_only)

        # a bare 'kafka://' means the default local broker; otherwise the
        # first eight characters (the 'kafka://' prefix) are dropped
        if url == 'kafka://':
            bootstrap = 'localhost:9092'
        else:
            bootstrap = url[8:]
        self.kafka_url = bootstrap
        self.producer = kafka.KafkaProducer(bootstrap_servers=bootstrap)
        self.consumer = kafka.KafkaConsumer(self.channel,
                                            bootstrap_servers=bootstrap)

    def _publish(self, data):
        """Pickle ``data`` and send it to the topic, flushing right away."""
        payload = pickle.dumps(data)
        self.producer.send(self.channel, value=payload)
        self.producer.flush()

    def _kafka_listen(self):
        """Yield the raw records delivered by the Kafka consumer."""
        for record in self.consumer:
            yield record

    def _listen(self):
        """Yield the unpickled payload of every record on our topic."""
        for record in self._kafka_listen():
            if record.topic != self.channel:
                continue
            # NOTE: the payload is unpickled as-is, so the topic must only
            # be writable by trusted producers
            yield pickle.loads(record.value)
class KombuManager(PubSubManager):  # pragma: no cover
    """Client manager that uses kombu for inter-process messaging.

    This class implements a client manager backend for event sharing across
    multiple processes, using RabbitMQ, Redis or any other messaging
    mechanism supported by `kombu <http://kombu.readthedocs.org/en/latest/>`_.

    To use a kombu backend, initialize the :class:`Server` instance as
    follows::

        url = 'amqp://user:password@hostname:port//'
        server = socketio.Server(client_manager=socketio.KombuManager(url))

    :param url: The connection URL for the backend messaging queue. Example
                connection URLs are ``'amqp://guest:guest@localhost:5672//'``
                and ``'redis://localhost:6379/'`` for RabbitMQ and Redis
                respectively. Consult the kombu documentation for more on
                how to construct connection URLs.
    :param channel: The channel name on which the server sends and receives
                    notifications. Must be the same in all the servers.
    :param write_only: If set to ``True``, only initialize to emit events.
                       The default of ``False`` initializes the class for
                       emitting and receiving.
    :param connection_options: additional keyword arguments to be passed to
                               ``kombu.Connection()``.
    :param exchange_options: additional keyword arguments to be passed to
                             ``kombu.Exchange()``.
    :param queue_options: additional keyword arguments to be passed to
                          ``kombu.Queue()``.
    :param producer_options: additional keyword arguments to be passed to
                             ``kombu.Producer()``.
    """
    name = 'kombu'

    def __init__(self, url='amqp://guest:guest@localhost:5672//',
                 channel='socketio', write_only=False, logger=None,
                 connection_options=None, exchange_options=None,
                 queue_options=None, producer_options=None):
        if kombu is None:
            raise RuntimeError('Kombu package is not installed '
                               '(Run "pip install kombu" in your '
                               'virtualenv).')
        super(KombuManager, self).__init__(channel=channel,
                                           write_only=write_only,
                                           logger=logger)
        self.url = url
        # each options argument falls back to a fresh empty dict
        self.connection_options = connection_options or {}
        self.exchange_options = exchange_options or {}
        self.queue_options = queue_options or {}
        self.producer_options = producer_options or {}
        self.producer = self._producer()

    def initialize(self):
        """Initialize the manager and check the socket monkey patching.

        Raises ``RuntimeError`` when the server runs under eventlet or
        gevent and the ``socket`` module has not been monkey patched.
        """
        super(KombuManager, self).initialize()

        if self.server.async_mode == 'eventlet':
            from eventlet.patcher import is_monkey_patched
            socket_patched = is_monkey_patched('socket')
        elif 'gevent' in self.server.async_mode:
            from gevent.monkey import is_module_patched
            socket_patched = is_module_patched('socket')
        else:
            # the other async modes need no patching
            socket_patched = True
        if socket_patched:
            return
        raise RuntimeError(
            'Kombu requires a monkey patched socket library to work '
            'with ' + self.server.async_mode)

    def _connection(self):
        """Create a new kombu connection to the configured URL."""
        return kombu.Connection(self.url, **self.connection_options)

    def _exchange(self):
        """Create the exchange; user options override the defaults."""
        exchange_kwargs = {'type': 'fanout', 'durable': False}
        exchange_kwargs.update(self.exchange_options)
        return kombu.Exchange(self.channel, **exchange_kwargs)

    def _queue(self):
        """Create a uniquely named queue bound to the exchange."""
        unique_name = 'flask-socketio.' + str(uuid.uuid4())
        queue_kwargs = {'durable': False,
                        'queue_arguments': {'x-expires': 300000}}
        queue_kwargs.update(self.queue_options)
        return kombu.Queue(unique_name, self._exchange(), **queue_kwargs)

    def _producer(self):
        """Create a producer that publishes to the exchange."""
        connection = self._connection()
        return connection.Producer(exchange=self._exchange(),
                                   **self.producer_options)

    def __error_callback(self, exception, interval):
        self._get_logger().exception('Sleeping {}s'.format(interval))

    def _publish(self, data):
        """Pickle ``data`` and publish it through kombu's ``ensure``."""
        connection = self._connection()
        guarded_publish = connection.ensure(
            self.producer, self.producer.publish,
            errback=self.__error_callback)
        guarded_publish(pickle.dumps(data))

    def _listen(self):
        """Yield message payloads forever, reconnecting after errors."""
        reader_queue = self._queue()

        while True:
            connection = self._connection().ensure_connection(
                errback=self.__error_callback)
            try:
                with connection.SimpleQueue(reader_queue) as inbox:
                    while True:
                        incoming = inbox.get(block=True)
                        incoming.ack()
                        yield incoming.payload
            except connection.connection_errors:
                # log and go around the outer loop to reconnect
                self._get_logger().exception("Connection error "
                                             "while reading from queue")
class BaseNamespace(object):
    """Common base for class-based namespaces.

    :param namespace: The Socket.IO namespace handled by the instance. An
                      omitted or empty value selects the default
                      namespace ``'/'``.
    """
    def __init__(self, namespace=None):
        self.namespace = namespace if namespace else '/'

    def is_asyncio_based(self):
        """Return ``False``: this namespace is not asyncio based."""
        return False

    def trigger_event(self, event, *args):
        """Dispatch an event to the proper handler method.

        The handler is the method named ``on_`` followed by the event
        name. Its return value is passed back to the caller; when there is
        no such method, ``None`` is returned.

        In the most common usage, this method is not overloaded by
        subclasses, as it performs the routing of events to methods.
        However, this method can be overridden if special dispatching rules
        are needed, or if having a single method that catches all events
        is desired.
        """
        try:
            handler = getattr(self, 'on_' + event)
        except AttributeError:
            return None
        return handler(*args)
class Namespace(BaseNamespace):
    """Base class for server-side class-based namespaces.

    A class-based namespace is a class that contains all the event handlers
    for a Socket.IO namespace. The event handlers are methods of the class
    with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
    ``on_message``, ``on_json``, and so on.

    Every helper below forwards to the method of the same name on the
    attached server; when the ``namespace`` argument is omitted or empty,
    the namespace associated with the class is used.

    :param namespace: The Socket.IO namespace to be used with all the event
                      handlers defined in this class. If this argument is
                      omitted, the default namespace is used.
    """
    def __init__(self, namespace=None):
        super(Namespace, self).__init__(namespace=namespace)
        self.server = None

    def _set_server(self, server):
        self.server = server

    def emit(self, event, data=None, room=None, skip_sid=None, namespace=None,
             callback=None):
        """Emit a custom event to one or more connected clients.

        Same as :func:`socketio.Server.emit`, defaulting to the namespace
        of this class.
        """
        target = namespace or self.namespace
        return self.server.emit(event, data=data, room=room,
                                skip_sid=skip_sid, namespace=target,
                                callback=callback)

    def send(self, data, room=None, skip_sid=None, namespace=None,
             callback=None):
        """Send a message to one or more connected clients.

        Same as :func:`socketio.Server.send`, defaulting to the namespace
        of this class.
        """
        target = namespace or self.namespace
        return self.server.send(data, room=room, skip_sid=skip_sid,
                                namespace=target, callback=callback)

    def enter_room(self, sid, room, namespace=None):
        """Enter a room.

        Same as :func:`socketio.Server.enter_room`, defaulting to the
        namespace of this class.
        """
        target = namespace or self.namespace
        return self.server.enter_room(sid, room, namespace=target)

    def leave_room(self, sid, room, namespace=None):
        """Leave a room.

        Same as :func:`socketio.Server.leave_room`, defaulting to the
        namespace of this class.
        """
        target = namespace or self.namespace
        return self.server.leave_room(sid, room, namespace=target)

    def close_room(self, room, namespace=None):
        """Close a room.

        Same as :func:`socketio.Server.close_room`, defaulting to the
        namespace of this class.
        """
        target = namespace or self.namespace
        return self.server.close_room(room, namespace=target)

    def rooms(self, sid, namespace=None):
        """Return the rooms a client is in.

        Same as :func:`socketio.Server.rooms`, defaulting to the namespace
        of this class.
        """
        target = namespace or self.namespace
        return self.server.rooms(sid, namespace=target)

    def get_session(self, sid, namespace=None):
        """Return the user session for a client.

        Same as :func:`socketio.Server.get_session`, defaulting to the
        namespace of this class.
        """
        target = namespace or self.namespace
        return self.server.get_session(sid, namespace=target)

    def save_session(self, sid, session, namespace=None):
        """Store the user session for a client.

        Same as :func:`socketio.Server.save_session`, defaulting to the
        namespace of this class.
        """
        target = namespace or self.namespace
        return self.server.save_session(sid, session, namespace=target)

    def session(self, sid, namespace=None):
        """Return the user session for a client with context manager syntax.

        Same as :func:`socketio.Server.session`, defaulting to the
        namespace of this class.
        """
        target = namespace or self.namespace
        return self.server.session(sid, namespace=target)

    def disconnect(self, sid, namespace=None):
        """Disconnect a client.

        Same as :func:`socketio.Server.disconnect`, defaulting to the
        namespace of this class.
        """
        target = namespace or self.namespace
        return self.server.disconnect(sid, namespace=target)
class ClientNamespace(BaseNamespace):
    """Base class for client-side class-based namespaces.

    A class-based namespace is a class that contains all the event handlers
    for a Socket.IO namespace. The event handlers are methods of the class
    with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
    ``on_message``, ``on_json``, and so on.

    :param namespace: The Socket.IO namespace to be used with all the event
                      handlers defined in this class. If this argument is
                      omitted, the default namespace is used.
    """
    def __init__(self, namespace=None):
        super(ClientNamespace, self).__init__(namespace=namespace)
        self.client = None

    def _set_client(self, client):
        self.client = client

    def emit(self, event, data=None, namespace=None, callback=None):
        """Emit a custom event to the server.

        Same as :func:`socketio.Client.emit`, except that the namespace
        associated with the class is used when the ``namespace`` argument
        is omitted or empty.
        """
        target = namespace or self.namespace
        return self.client.emit(event, data=data, namespace=target,
                                callback=callback)

    def send(self, data, room=None, skip_sid=None, namespace=None,
             callback=None):
        """Send a message to the server.

        Same as :func:`socketio.Client.send`, except that the namespace
        associated with the class is used when the ``namespace`` argument
        is omitted or empty.

        The ``room`` and ``skip_sid`` arguments are accepted but are not
        forwarded to the client.
        """
        target = namespace or self.namespace
        return self.client.send(data, namespace=target, callback=callback)

    def disconnect(self):
        """Disconnect from the server.

        This simply calls :func:`socketio.Client.disconnect` on the
        attached client; it takes no namespace argument.
        """
        client = self.client
        return client.disconnect()
class Packet(object):
    """Socket.IO packet."""

    # the format of the Socket.IO packet is as follows:
    #
    # packet type: 1 byte, values 0-6
    # num_attachments: ASCII encoded, only if num_attachments != 0
    # '-': only if num_attachments != 0
    # namespace: only if namespace != '/'
    # ',': only if namespace and one of id and data are defined in this packet
    # id: ASCII encoded, only if id is not None
    # data: JSON dump of data payload

    # JSON module used by encode() and decode(); kept as a class attribute
    # so that it can be replaced
    json = _json

    def __init__(self, packet_type=EVENT, data=None, namespace=None, id=None,
                 binary=None, encoded_packet=None):
        """Create a packet, either from its parts or from an encoded form.

        :param packet_type: One of the packet type constants.
        :param data: The payload of the packet.
        :param namespace: The namespace of the packet.
        :param id: The packet id, or ``None``.
        :param binary: ``True`` to force a binary packet, ``False`` to
                       force a non-binary one, ``None`` to decide by
                       inspecting ``data``.
        :param encoded_packet: When given and not empty, the packet is
                               decoded from it.

        Raises ``ValueError`` when a binary payload is requested for a
        packet type other than EVENT or ACK.
        """
        self.packet_type = packet_type
        self.data = data
        self.namespace = namespace
        self.id = id
        if binary or (binary is None and self._data_is_binary(self.data)):
            if self.packet_type == EVENT:
                self.packet_type = BINARY_EVENT
            elif self.packet_type == ACK:
                self.packet_type = BINARY_ACK
            else:
                raise ValueError('Packet does not support binary payload.')
        self.attachment_count = 0
        self.attachments = []
        if encoded_packet:
            self.attachment_count = self.decode(encoded_packet)

    def encode(self):
        """Encode the packet for transmission.

        If the packet contains binary elements, this function returns a list
        of packets where the first is the original packet with placeholders for
        the binary components and the remaining ones the binary attachments.
        """
        encoded_packet = six.text_type(self.packet_type)
        if self.packet_type == BINARY_EVENT or self.packet_type == BINARY_ACK:
            data, attachments = self._deconstruct_binary(self.data)
            encoded_packet += six.text_type(len(attachments)) + '-'
        else:
            data = self.data
            attachments = None
        # a comma is needed between the namespace and whatever follows it
        needs_comma = False
        if self.namespace is not None and self.namespace != '/':
            encoded_packet += self.namespace
            needs_comma = True
        if self.id is not None:
            if needs_comma:
                encoded_packet += ','
                needs_comma = False
            encoded_packet += six.text_type(self.id)
        if data is not None:
            if needs_comma:
                encoded_packet += ','
            encoded_packet += self.json.dumps(data, separators=(',', ':'))
        if attachments is not None:
            encoded_packet = [encoded_packet] + attachments
        return encoded_packet

    def decode(self, encoded_packet):
        """Decode a transmitted package.

        The return value indicates how many binary attachment packets are
        necessary to fully decode the packet.
        """
        ep = encoded_packet
        try:
            self.packet_type = int(ep[0:1])
        except TypeError:
            # the value cannot be sliced or converted: use it as the type
            self.packet_type = ep
            ep = ''
        self.namespace = None
        self.data = None
        ep = ep[1:]
        dash = ep.find('-')
        attachment_count = 0
        if dash > 0 and ep[0:dash].isdigit():
            attachment_count = int(ep[0:dash])
            ep = ep[dash + 1:]
        if ep and ep[0:1] == '/':
            sep = ep.find(',')
            if sep == -1:
                self.namespace = ep
                ep = ''
            else:
                self.namespace = ep[0:sep]
                ep = ep[sep + 1:]
            # drop the query string, if any, from the namespace
            q = self.namespace.find('?')
            if q != -1:
                self.namespace = self.namespace[0:q]
        if ep and ep[0].isdigit():
            self.id = 0
            while ep and ep[0].isdigit():
                self.id = self.id * 10 + int(ep[0])
                ep = ep[1:]
        if ep:
            self.data = self.json.loads(ep)
        return attachment_count

    def add_attachment(self, attachment):
        """Store a binary attachment received for this packet.

        Returns ``True`` when this was the last expected attachment, in
        which case the payload has been reconstructed, and ``False`` when
        more attachments are expected. Raises ``ValueError`` when no more
        attachments were expected.
        """
        if self.attachment_count <= len(self.attachments):
            raise ValueError('Unexpected binary attachment')
        self.attachments.append(attachment)
        if self.attachment_count == len(self.attachments):
            self.reconstruct_binary(self.attachments)
            return True
        return False

    def reconstruct_binary(self, attachments):
        """Reconstruct a decoded packet using the given list of binary
        attachments.

        The placeholders found in ``self.data`` are replaced with the
        entries of ``attachments``. The argument used to be ignored in
        favor of ``self.attachments``; it is now honored.
        """
        self.data = self._reconstruct_binary_internal(self.data,
                                                      attachments)

    def _reconstruct_binary_internal(self, data, attachments):
        if isinstance(data, list):
            return [self._reconstruct_binary_internal(item, attachments)
                    for item in data]
        elif isinstance(data, dict):
            if data.get('_placeholder') and 'num' in data:
                return attachments[data['num']]
            else:
                return {key: self._reconstruct_binary_internal(value,
                                                               attachments)
                        for key, value in six.iteritems(data)}
        else:
            return data

    def _deconstruct_binary(self, data):
        """Extract binary components in the packet."""
        attachments = []
        data = self._deconstruct_binary_internal(data, attachments)
        return data, attachments

    def _deconstruct_binary_internal(self, data, attachments):
        if isinstance(data, six.binary_type):
            # replace the binary value with a placeholder that records its
            # position in the list of attachments
            attachments.append(data)
            return {'_placeholder': True, 'num': len(attachments) - 1}
        elif isinstance(data, list):
            return [self._deconstruct_binary_internal(item, attachments)
                    for item in data]
        elif isinstance(data, dict):
            return {key: self._deconstruct_binary_internal(value, attachments)
                    for key, value in six.iteritems(data)}
        else:
            return data

    def _data_is_binary(self, data):
        """Check if the data contains binary components."""
        if isinstance(data, six.binary_type):
            return True
        elif isinstance(data, list):
            return any(self._data_is_binary(item) for item in data)
        elif isinstance(data, dict):
            return any(self._data_is_binary(item)
                       for item in six.itervalues(data))
        else:
            return False
class PubSubManager(BaseManager):
    """Manage a client list attached to a pub/sub backend.

    This is a base class that enables multiple servers to share the list of
    clients, with the servers communicating events through a pub/sub backend.
    The use of a pub/sub backend also allows any client connected to the
    backend to emit events addressed to Socket.IO clients.

    The actual backends must be implemented by subclasses, this class only
    provides a pub/sub generic framework.

    :param channel: The channel name on which the server sends and receives
                    notifications.
    :param write_only: If set to ``True``, no listening task is started by
                       ``initialize()``.
    :param logger: Stored in the ``logger`` attribute of the manager.
    """
    name = 'pubsub'

    def __init__(self, channel='socketio', write_only=False, logger=None):
        super(PubSubManager, self).__init__()
        self.channel = channel
        self.write_only = write_only
        # identifies this process in the messages that it publishes
        self.host_id = uuid.uuid4().hex
        self.logger = logger

    def initialize(self):
        """Initialize the manager and, unless it is write-only, start the
        background task that listens to the backend."""
        super(PubSubManager, self).initialize()
        if not self.write_only:
            self.thread = self.server.start_background_task(self._thread)
        self._get_logger().info(self.name + ' backend initialized.')

    def emit(self, event, data, namespace=None, room=None, skip_sid=None,
             callback=None, **kwargs):
        """Emit a message to a single client, a room, or all the clients
        connected to the namespace.

        This method takes care of propagating the message to all the servers
        that are connected through the message queue.

        The parameters are the same as in :meth:`.Server.emit`. Passing
        ``ignore_queue=True`` bypasses the queue and emits through the base
        manager directly.

        Raises ``RuntimeError`` when a callback is given outside of the
        context of a server, and ``ValueError`` when a callback is given
        without a room.
        """
        if kwargs.get('ignore_queue'):
            return super(PubSubManager, self).emit(
                event, data, namespace=namespace, room=room, skip_sid=skip_sid,
                callback=callback)
        namespace = namespace or '/'
        if callback is not None:
            if self.server is None:
                raise RuntimeError('Callbacks can only be issued from the '
                                   'context of a server.')
            if room is None:
                raise ValueError('Cannot use callback without a room set.')
            ack_id = self._generate_ack_id(room, namespace, callback)
            # the callable itself cannot travel through the queue, so a
            # (room, namespace, id) reference is published in its place
            callback = (room, namespace, ack_id)
        self._publish({'method': 'emit', 'event': event, 'data': data,
                       'namespace': namespace, 'room': room,
                       'skip_sid': skip_sid, 'callback': callback,
                       'host_id': self.host_id})

    def close_room(self, room, namespace=None):
        """Publish a request to close a room in the given namespace."""
        self._publish({'method': 'close_room', 'room': room,
                       'namespace': namespace or '/'})

    def _publish(self, data):
        """Publish a message on the Socket.IO channel.

        This method needs to be implemented by the different subclasses that
        support pub/sub backends.
        """
        raise NotImplementedError('This method must be implemented in a '
                                  'subclass.')  # pragma: no cover

    def _listen(self):
        """Return the next message published on the Socket.IO channel,
        blocking until a message is available.

        This method needs to be implemented by the different subclasses that
        support pub/sub backends.
        """
        raise NotImplementedError('This method must be implemented in a '
                                  'subclass.')  # pragma: no cover

    def _handle_emit(self, message):
        # Events with callbacks are very tricky to handle across hosts
        # Here in the receiving end we set up a local callback that preserves
        # the callback host and id from the sender
        remote_callback = message.get('callback')
        remote_host_id = message.get('host_id')
        if remote_callback is not None and len(remote_callback) == 3:
            callback = partial(self._return_callback, remote_host_id,
                               *remote_callback)
        else:
            callback = None
        super(PubSubManager, self).emit(message['event'], message['data'],
                                        namespace=message.get('namespace'),
                                        room=message.get('room'),
                                        skip_sid=message.get('skip_sid'),
                                        callback=callback)

    def _handle_callback(self, message):
        # only the host that issued the original emit runs the callback
        if self.host_id == message.get('host_id'):
            try:
                sid = message['sid']
                namespace = message['namespace']
                id = message['id']
                args = message['args']
            except KeyError:
                # incomplete callback messages are ignored
                return
            self.trigger_callback(sid, namespace, id, args)

    def _return_callback(self, host_id, sid, namespace, callback_id, *args):
        # When an event callback is received, the callback is returned back
        # to the sender, which is identified by the host_id
        self._publish({'method': 'callback', 'host_id': host_id,
                       'sid': sid, 'namespace': namespace, 'id': callback_id,
                       'args': args})

    def _handle_close_room(self, message):
        super(PubSubManager, self).close_room(
            room=message.get('room'), namespace=message.get('namespace'))

    def _thread(self):
        """Listen to the backend and dispatch every message received.

        A message can be a dict, a pickled dict or a JSON encoded dict.
        Messages that cannot be decoded, that do not decode to a dict or
        that have no ``'method'`` key are ignored.
        """
        for message in self._listen():
            data = None
            if isinstance(message, dict):
                data = message
            else:
                if isinstance(message, six.binary_type):  # pragma: no cover
                    # NOTE: unpickling is only safe if the backend channel
                    # can be written to by trusted parties alone
                    try:
                        data = pickle.loads(message)
                    except Exception:
                        # not a pickle, the JSON decoder is tried next.
                        # Catching Exception instead of using a bare except
                        # lets task termination signals propagate.
                        pass
                if data is None:
                    try:
                        data = json.loads(message)
                    except Exception:
                        pass
            # a payload that is not a dict (for example a JSON number) is
            # ignored instead of crashing the listening task
            if isinstance(data, dict) and 'method' in data:
                if data['method'] == 'emit':
                    self._handle_emit(data)
                elif data['method'] == 'callback':
                    self._handle_callback(data)
                elif data['method'] == 'close_room':
                    self._handle_close_room(data)
+ """ + name = 'redis' + + def __init__(self, url='redis://localhost:6379/0', channel='socketio', + write_only=False, logger=None, redis_options=None): + if redis is None: + raise RuntimeError('Redis package is not installed ' + '(Run "pip install redis" in your ' + 'virtualenv).') + self.redis_url = url + self.redis_options = redis_options or {} + self._redis_connect() + super(RedisManager, self).__init__(channel=channel, + write_only=write_only, + logger=logger) + + def initialize(self): + super(RedisManager, self).initialize() + + monkey_patched = True + if self.server.async_mode == 'eventlet': + from eventlet.patcher import is_monkey_patched + monkey_patched = is_monkey_patched('socket') + elif 'gevent' in self.server.async_mode: + from gevent.monkey import is_module_patched + monkey_patched = is_module_patched('socket') + if not monkey_patched: + raise RuntimeError( + 'Redis requires a monkey patched socket library to work ' + 'with ' + self.server.async_mode) + + def _redis_connect(self): + self.redis = redis.Redis.from_url(self.redis_url, + **self.redis_options) + self.pubsub = self.redis.pubsub() + + def _publish(self, data): + retry = True + while True: + try: + if not retry: + self._redis_connect() + return self.redis.publish(self.channel, pickle.dumps(data)) + except redis.exceptions.ConnectionError: + if retry: + logger.error('Cannot publish to redis... retrying') + retry = False + else: + logger.error('Cannot publish to redis... giving up') + break + + def _redis_listen_with_retries(self): + retry_sleep = 1 + connect = False + while True: + try: + if connect: + self._redis_connect() + self.pubsub.subscribe(self.channel) + for message in self.pubsub.listen(): + yield message + except redis.exceptions.ConnectionError: + logger.error('Cannot receive from redis... 
' + 'retrying in {} secs'.format(retry_sleep)) + connect = True + time.sleep(retry_sleep) + retry_sleep *= 2 + if retry_sleep > 60: + retry_sleep = 60 + + def _listen(self): + channel = self.channel.encode('utf-8') + self.pubsub.subscribe(self.channel) + for message in self._redis_listen_with_retries(): + if message['channel'] == channel and \ + message['type'] == 'message' and 'data' in message: + yield message['data'] + self.pubsub.unsubscribe(self.channel) diff --git a/libs/socketio/server.py b/libs/socketio/server.py new file mode 100644 index 000000000..76b7d2e8f --- /dev/null +++ b/libs/socketio/server.py @@ -0,0 +1,730 @@ +import logging + +import engineio +import six + +from . import base_manager +from . import exceptions +from . import namespace +from . import packet + +default_logger = logging.getLogger('socketio.server') + + +class Server(object): + """A Socket.IO server. + + This class implements a fully compliant Socket.IO web server with support + for websocket and long-polling transports. + + :param client_manager: The client manager instance that will manage the + client list. When this is omitted, the client list + is stored in an in-memory structure, so the use of + multiple connected servers is not possible. + :param logger: To enable logging set to ``True`` or pass a logger object to + use. To disable logging set to ``False``. The default is + ``False``. + :param binary: ``True`` to support binary payloads, ``False`` to treat all + payloads as text. On Python 2, if this is set to ``True``, + ``unicode`` values are treated as text, and ``str`` and + ``bytes`` values are treated as binary. This option has no + effect on Python 3, where text and binary payloads are + always automatically discovered. + :param json: An alternative json module to use for encoding and decoding + packets. Custom json modules must have ``dumps`` and ``loads`` + functions that are compatible with the standard library + versions. 
+ :param async_handlers: If set to ``True``, event handlers for a client are + executed in separate threads. To run handlers for a + client synchronously, set to ``False``. The default + is ``True``. + :param always_connect: When set to ``False``, new connections are + provisory until the connect handler returns + something other than ``False``, at which point they + are accepted. When set to ``True``, connections are + immediately accepted, and then if the connect + handler returns ``False`` a disconnect is issued. + Set to ``True`` if you need to emit events from the + connect handler and your client is confused when it + receives events before the connection acceptance. + In any other case use the default of ``False``. + :param kwargs: Connection parameters for the underlying Engine.IO server. + + The Engine.IO configuration supports the following settings: + + :param async_mode: The asynchronous model to use. See the Deployment + section in the documentation for a description of the + available options. Valid async modes are "threading", + "eventlet", "gevent" and "gevent_uwsgi". If this + argument is not given, "eventlet" is tried first, then + "gevent_uwsgi", then "gevent", and finally "threading". + The first async mode that has all its dependencies + installed is then one that is chosen. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. The default + is 60 seconds. + :param ping_interval: The interval in seconds at which the client pings + the server. The default is 25 seconds. + :param max_http_buffer_size: The maximum size of a message when using the + polling transport. The default is 100,000,000 + bytes. + :param allow_upgrades: Whether to allow transport upgrades or not. The + default is ``True``. + :param http_compression: Whether to compress packages when using the + polling transport. The default is ``True``. 
+ :param compression_threshold: Only compress messages when their byte size + is greater than this value. The default is + 1024 bytes. + :param cookie: Name of the HTTP cookie that contains the client session + id. If set to ``None``, a cookie is not sent to the client. + The default is ``'io'``. + :param cors_allowed_origins: Origin or list of origins that are allowed to + connect to this server. Only the same origin + is allowed by default. Set this argument to + ``'*'`` to allow all origins, or to ``[]`` to + disable CORS handling. + :param cors_credentials: Whether credentials (cookies, authentication) are + allowed in requests to this server. The default is + ``True``. + :param monitor_clients: If set to ``True``, a background task will ensure + inactive clients are closed. Set to ``False`` to + disable the monitoring task (not recommended). The + default is ``True``. + :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass + a logger object to use. To disable logging set to + ``False``. The default is ``False``. 
+ """ + def __init__(self, client_manager=None, logger=False, binary=False, + json=None, async_handlers=True, always_connect=False, + **kwargs): + engineio_options = kwargs + engineio_logger = engineio_options.pop('engineio_logger', None) + if engineio_logger is not None: + engineio_options['logger'] = engineio_logger + if json is not None: + packet.Packet.json = json + engineio_options['json'] = json + engineio_options['async_handlers'] = False + self.eio = self._engineio_server_class()(**engineio_options) + self.eio.on('connect', self._handle_eio_connect) + self.eio.on('message', self._handle_eio_message) + self.eio.on('disconnect', self._handle_eio_disconnect) + self.binary = binary + + self.environ = {} + self.handlers = {} + self.namespace_handlers = {} + + self._binary_packet = {} + + if not isinstance(logger, bool): + self.logger = logger + else: + self.logger = default_logger + if not logging.root.handlers and \ + self.logger.level == logging.NOTSET: + if logger: + self.logger.setLevel(logging.INFO) + else: + self.logger.setLevel(logging.ERROR) + self.logger.addHandler(logging.StreamHandler()) + + if client_manager is None: + client_manager = base_manager.BaseManager() + self.manager = client_manager + self.manager.set_server(self) + self.manager_initialized = False + + self.async_handlers = async_handlers + self.always_connect = always_connect + + self.async_mode = self.eio.async_mode + + def is_asyncio_based(self): + return False + + def on(self, event, handler=None, namespace=None): + """Register an event handler. + + :param event: The event name. It can be any string. The event names + ``'connect'``, ``'message'`` and ``'disconnect'`` are + reserved and should not be used. + :param handler: The function that should be invoked to handle the + event. When this parameter is not given, the method + acts as a decorator for the handler function. + :param namespace: The Socket.IO namespace for the event. 
If this + argument is omitted the handler is associated with + the default namespace. + + Example usage:: + + # as a decorator: + @socket_io.on('connect', namespace='/chat') + def connect_handler(sid, environ): + print('Connection request') + if environ['REMOTE_ADDR'] in blacklisted: + return False # reject + + # as a method: + def message_handler(sid, msg): + print('Received message: ', msg) + eio.send(sid, 'response') + socket_io.on('message', namespace='/chat', message_handler) + + The handler function receives the ``sid`` (session ID) for the + client as first argument. The ``'connect'`` event handler receives the + WSGI environment as a second argument, and can return ``False`` to + reject the connection. The ``'message'`` handler and handlers for + custom event names receive the message payload as a second argument. + Any values returned from a message handler will be passed to the + client's acknowledgement callback function if it exists. The + ``'disconnect'`` handler does not take a second argument. + """ + namespace = namespace or '/' + + def set_handler(handler): + if namespace not in self.handlers: + self.handlers[namespace] = {} + self.handlers[namespace][event] = handler + return handler + + if handler is None: + return set_handler + set_handler(handler) + + def event(self, *args, **kwargs): + """Decorator to register an event handler. + + This is a simplified version of the ``on()`` method that takes the + event name from the decorated function. 
+ + Example usage:: + + @sio.event + def my_event(data): + print('Received data: ', data) + + The above example is equivalent to:: + + @sio.on('my_event') + def my_event(data): + print('Received data: ', data) + + A custom namespace can be given as an argument to the decorator:: + + @sio.event(namespace='/test') + def my_event(data): + print('Received data: ', data) + """ + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # the decorator was invoked without arguments + # args[0] is the decorated function + return self.on(args[0].__name__)(args[0]) + else: + # the decorator was invoked with arguments + def set_handler(handler): + return self.on(handler.__name__, *args, **kwargs)(handler) + + return set_handler + + def register_namespace(self, namespace_handler): + """Register a namespace handler object. + + :param namespace_handler: An instance of a :class:`Namespace` + subclass that handles all the event traffic + for a namespace. + """ + if not isinstance(namespace_handler, namespace.Namespace): + raise ValueError('Not a namespace instance') + if self.is_asyncio_based() != namespace_handler.is_asyncio_based(): + raise ValueError('Not a valid namespace class for this server') + namespace_handler._set_server(self) + self.namespace_handlers[namespace_handler.namespace] = \ + namespace_handler + + def emit(self, event, data=None, to=None, room=None, skip_sid=None, + namespace=None, callback=None, **kwargs): + """Emit a custom event to one or more connected clients. + + :param event: The event name. It can be any string. The event names + ``'connect'``, ``'message'`` and ``'disconnect'`` are + reserved and should not be used. + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. + :param to: The recipient of the message. 
This can be set to the + session ID of a client to address only that client, or to + to any custom room created by the application to address all + the clients in that room, If this argument is omitted the + event is broadcasted to all connected clients. + :param room: Alias for the ``to`` parameter. + :param skip_sid: The session ID of a client to skip when broadcasting + to a room or to all clients. This can be used to + prevent a message from being sent to the sender. To + skip multiple sids, pass a list. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param callback: If given, this function will be called to acknowledge + the the client has received the message. The arguments + that will be passed to the function are those provided + by the client. Callback functions can only be used + when addressing an individual client. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the event is emitted to the + clients directly, without going through the queue. + This is more efficient, but only works when a + single server process is used. It is recommended + to always leave this parameter with its default + value of ``False``. + """ + namespace = namespace or '/' + room = to or room + self.logger.info('emitting event "%s" to %s [%s]', event, + room or 'all', namespace) + self.manager.emit(event, data, namespace, room=room, + skip_sid=skip_sid, callback=callback, **kwargs) + + def send(self, data, to=None, room=None, skip_sid=None, namespace=None, + callback=None, **kwargs): + """Send a message to one or more connected clients. + + This function emits an event with the name ``'message'``. Use + :func:`emit` to issue custom event names. + + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. 
+ :param to: The recipient of the message. This can be set to the + session ID of a client to address only that client, or to + to any custom room created by the application to address all + the clients in that room, If this argument is omitted the + event is broadcasted to all connected clients. + :param room: Alias for the ``to`` parameter. + :param skip_sid: The session ID of a client to skip when broadcasting + to a room or to all clients. This can be used to + prevent a message from being sent to the sender. To + skip multiple sids, pass a list. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param callback: If given, this function will be called to acknowledge + the the client has received the message. The arguments + that will be passed to the function are those provided + by the client. Callback functions can only be used + when addressing an individual client. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the event is emitted to the + clients directly, without going through the queue. + This is more efficient, but only works when a + single server process is used. It is recommended + to always leave this parameter with its default + value of ``False``. + """ + self.emit('message', data=data, to=to, room=room, skip_sid=skip_sid, + namespace=namespace, callback=callback, **kwargs) + + def call(self, event, data=None, to=None, sid=None, namespace=None, + timeout=60, **kwargs): + """Emit a custom event to a client and wait for the response. + + :param event: The event name. It can be any string. The event names + ``'connect'``, ``'message'`` and ``'disconnect'`` are + reserved and should not be used. + :param data: The data to send to the client or clients. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. If a + ``list`` or ``dict``, the data will be serialized as JSON. 
+ :param to: The session ID of the recipient client. + :param sid: Alias for the ``to`` parameter. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the event is emitted to the + default namespace. + :param timeout: The waiting timeout. If the timeout is reached before + the client acknowledges the event, then a + ``TimeoutError`` exception is raised. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the event is emitted to the + client directly, without going through the queue. + This is more efficient, but only works when a + single server process is used. It is recommended + to always leave this parameter with its default + value of ``False``. + """ + if not self.async_handlers: + raise RuntimeError( + 'Cannot use call() when async_handlers is False.') + callback_event = self.eio.create_event() + callback_args = [] + + def event_callback(*args): + callback_args.append(args) + callback_event.set() + + self.emit(event, data=data, room=to or sid, namespace=namespace, + callback=event_callback, **kwargs) + if not callback_event.wait(timeout=timeout): + raise exceptions.TimeoutError() + return callback_args[0] if len(callback_args[0]) > 1 \ + else callback_args[0][0] if len(callback_args[0]) == 1 \ + else None + + def enter_room(self, sid, room, namespace=None): + """Enter a room. + + This function adds the client to a room. The :func:`emit` and + :func:`send` functions can optionally broadcast events to all the + clients in a room. + + :param sid: Session ID of the client. + :param room: Room name. If the room does not exist it is created. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the default namespace is used. + """ + namespace = namespace or '/' + self.logger.info('%s is entering room %s [%s]', sid, room, namespace) + self.manager.enter_room(sid, namespace, room) + + def leave_room(self, sid, room, namespace=None): + """Leave a room. 
+ + This function removes the client from a room. + + :param sid: Session ID of the client. + :param room: Room name. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the default namespace is used. + """ + namespace = namespace or '/' + self.logger.info('%s is leaving room %s [%s]', sid, room, namespace) + self.manager.leave_room(sid, namespace, room) + + def close_room(self, room, namespace=None): + """Close a room. + + This function removes all the clients from the given room. + + :param room: Room name. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the default namespace is used. + """ + namespace = namespace or '/' + self.logger.info('room %s is closing [%s]', room, namespace) + self.manager.close_room(room, namespace) + + def rooms(self, sid, namespace=None): + """Return the rooms a client is in. + + :param sid: Session ID of the client. + :param namespace: The Socket.IO namespace for the event. If this + argument is omitted the default namespace is used. + """ + namespace = namespace or '/' + return self.manager.get_rooms(sid, namespace) + + def get_session(self, sid, namespace=None): + """Return the user session for a client. + + :param sid: The session id of the client. + :param namespace: The Socket.IO namespace. If this argument is omitted + the default namespace is used. + + The return value is a dictionary. Modifications made to this + dictionary are not guaranteed to be preserved unless + ``save_session()`` is called, or when the ``session`` context manager + is used. + """ + namespace = namespace or '/' + eio_session = self.eio.get_session(sid) + return eio_session.setdefault(namespace, {}) + + def save_session(self, sid, session, namespace=None): + """Store the user session for a client. + + :param sid: The session id of the client. + :param session: The session dictionary. + :param namespace: The Socket.IO namespace. 
If this argument is omitted + the default namespace is used. + """ + namespace = namespace or '/' + eio_session = self.eio.get_session(sid) + eio_session[namespace] = session + + def session(self, sid, namespace=None): + """Return the user session for a client with context manager syntax. + + :param sid: The session id of the client. + + This is a context manager that returns the user session dictionary for + the client. Any changes that are made to this dictionary inside the + context manager block are saved back to the session. Example usage:: + + @sio.on('connect') + def on_connect(sid, environ): + username = authenticate_user(environ) + if not username: + return False + with sio.session(sid) as session: + session['username'] = username + + @sio.on('message') + def on_message(sid, msg): + with sio.session(sid) as session: + print('received message from ', session['username']) + """ + class _session_context_manager(object): + def __init__(self, server, sid, namespace): + self.server = server + self.sid = sid + self.namespace = namespace + self.session = None + + def __enter__(self): + self.session = self.server.get_session(sid, + namespace=namespace) + return self.session + + def __exit__(self, *args): + self.server.save_session(sid, self.session, + namespace=namespace) + + return _session_context_manager(self, sid, namespace) + + def disconnect(self, sid, namespace=None): + """Disconnect a client. + + :param sid: Session ID of the client. + :param namespace: The Socket.IO namespace to disconnect. If this + argument is omitted the default namespace is used. 
+ """ + namespace = namespace or '/' + if self.manager.is_connected(sid, namespace=namespace): + self.logger.info('Disconnecting %s [%s]', sid, namespace) + self.manager.pre_disconnect(sid, namespace=namespace) + self._send_packet(sid, packet.Packet(packet.DISCONNECT, + namespace=namespace)) + self._trigger_event('disconnect', namespace, sid) + self.manager.disconnect(sid, namespace=namespace) + if namespace == '/': + self.eio.disconnect(sid) + + def transport(self, sid): + """Return the name of the transport used by the client. + + The two possible values returned by this function are ``'polling'`` + and ``'websocket'``. + + :param sid: The session of the client. + """ + return self.eio.transport(sid) + + def handle_request(self, environ, start_response): + """Handle an HTTP request from the client. + + This is the entry point of the Socket.IO application, using the same + interface as a WSGI application. For the typical usage, this function + is invoked by the :class:`Middleware` instance, but it can be invoked + directly when the middleware is not used. + + :param environ: The WSGI environment. + :param start_response: The WSGI ``start_response`` function. + + This function returns the HTTP response body to deliver to the client + as a byte sequence. + """ + return self.eio.handle_request(environ, start_response) + + def start_background_task(self, target, *args, **kwargs): + """Start a background task using the appropriate async model. + + This is a utility function that applications can use to start a + background task using the method that is compatible with the + selected async mode. + + :param target: the target function to execute. + :param args: arguments to pass to the function. + :param kwargs: keyword arguments to pass to the function. + + This function returns an object compatible with the `Thread` class in + the Python standard library. The `start()` method on this object is + already called by this function. 
+ """ + return self.eio.start_background_task(target, *args, **kwargs) + + def sleep(self, seconds=0): + """Sleep for the requested amount of time using the appropriate async + model. + + This is a utility function that applications can use to put a task to + sleep without having to worry about using the correct call for the + selected async mode. + """ + return self.eio.sleep(seconds) + + def _emit_internal(self, sid, event, data, namespace=None, id=None): + """Send a message to a client.""" + if six.PY2 and not self.binary: + binary = False # pragma: nocover + else: + binary = None + # tuples are expanded to multiple arguments, everything else is sent + # as a single argument + if isinstance(data, tuple): + data = list(data) + else: + data = [data] + self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace, + data=[event] + data, id=id, + binary=binary)) + + def _send_packet(self, sid, pkt): + """Send a Socket.IO packet to a client.""" + encoded_packet = pkt.encode() + if isinstance(encoded_packet, list): + binary = False + for ep in encoded_packet: + self.eio.send(sid, ep, binary=binary) + binary = True + else: + self.eio.send(sid, encoded_packet, binary=False) + + def _handle_connect(self, sid, namespace): + """Handle a client connection request.""" + namespace = namespace or '/' + self.manager.connect(sid, namespace) + if self.always_connect: + self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) + fail_reason = None + try: + success = self._trigger_event('connect', namespace, sid, + self.environ[sid]) + except exceptions.ConnectionRefusedError as exc: + fail_reason = exc.error_args + success = False + + if success is False: + if self.always_connect: + self.manager.pre_disconnect(sid, namespace) + self._send_packet(sid, packet.Packet( + packet.DISCONNECT, data=fail_reason, namespace=namespace)) + self.manager.disconnect(sid, namespace) + if not self.always_connect: + self._send_packet(sid, packet.Packet( + packet.ERROR, 
data=fail_reason, namespace=namespace)) + if sid in self.environ: # pragma: no cover + del self.environ[sid] + elif not self.always_connect: + self._send_packet(sid, packet.Packet(packet.CONNECT, + namespace=namespace)) + + def _handle_disconnect(self, sid, namespace): + """Handle a client disconnect.""" + namespace = namespace or '/' + if namespace == '/': + namespace_list = list(self.manager.get_namespaces()) + else: + namespace_list = [namespace] + for n in namespace_list: + if n != '/' and self.manager.is_connected(sid, n): + self._trigger_event('disconnect', n, sid) + self.manager.disconnect(sid, n) + if namespace == '/' and self.manager.is_connected(sid, namespace): + self._trigger_event('disconnect', '/', sid) + self.manager.disconnect(sid, '/') + + def _handle_event(self, sid, namespace, id, data): + """Handle an incoming client event.""" + namespace = namespace or '/' + self.logger.info('received event "%s" from %s [%s]', data[0], sid, + namespace) + if not self.manager.is_connected(sid, namespace): + self.logger.warning('%s is not connected to namespace %s', + sid, namespace) + return + if self.async_handlers: + self.start_background_task(self._handle_event_internal, self, sid, + data, namespace, id) + else: + self._handle_event_internal(self, sid, data, namespace, id) + + def _handle_event_internal(self, server, sid, data, namespace, id): + r = server._trigger_event(data[0], namespace, sid, *data[1:]) + if id is not None: + # send ACK packet with the response returned by the handler + # tuples are expanded as multiple arguments + if r is None: + data = [] + elif isinstance(r, tuple): + data = list(r) + else: + data = [r] + if six.PY2 and not self.binary: + binary = False # pragma: nocover + else: + binary = None + server._send_packet(sid, packet.Packet(packet.ACK, + namespace=namespace, + id=id, data=data, + binary=binary)) + + def _handle_ack(self, sid, namespace, id, data): + """Handle ACK packets from the client.""" + namespace = namespace or '/' + 
self.logger.info('received ack from %s [%s]', sid, namespace) + self.manager.trigger_callback(sid, namespace, id, data) + + def _trigger_event(self, event, namespace, *args): + """Invoke an application event handler.""" + # first see if we have an explicit handler for the event + if namespace in self.handlers and event in self.handlers[namespace]: + return self.handlers[namespace][event](*args) + + # or else, forward the event to a namespace handler if one exists + elif namespace in self.namespace_handlers: + return self.namespace_handlers[namespace].trigger_event( + event, *args) + + def _handle_eio_connect(self, sid, environ): + """Handle the Engine.IO connection event.""" + if not self.manager_initialized: + self.manager_initialized = True + self.manager.initialize() + self.environ[sid] = environ + return self._handle_connect(sid, '/') + + def _handle_eio_message(self, sid, data): + """Dispatch Engine.IO messages.""" + if sid in self._binary_packet: + pkt = self._binary_packet[sid] + if pkt.add_attachment(data): + del self._binary_packet[sid] + if pkt.packet_type == packet.BINARY_EVENT: + self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + else: + self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + else: + pkt = packet.Packet(encoded_packet=data) + if pkt.packet_type == packet.CONNECT: + self._handle_connect(sid, pkt.namespace) + elif pkt.packet_type == packet.DISCONNECT: + self._handle_disconnect(sid, pkt.namespace) + elif pkt.packet_type == packet.EVENT: + self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + elif pkt.packet_type == packet.ACK: + self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + elif pkt.packet_type == packet.BINARY_EVENT or \ + pkt.packet_type == packet.BINARY_ACK: + self._binary_packet[sid] = pkt + elif pkt.packet_type == packet.ERROR: + raise ValueError('Unexpected ERROR packet.') + else: + raise ValueError('Unknown packet type.') + + def _handle_eio_disconnect(self, sid): + """Handle Engine.IO disconnect event.""" + 
self._handle_disconnect(sid, '/') + if sid in self.environ: + del self.environ[sid] + + def _engineio_server_class(self): + return engineio.Server diff --git a/libs/socketio/tornado.py b/libs/socketio/tornado.py new file mode 100644 index 000000000..5b2e6f684 --- /dev/null +++ b/libs/socketio/tornado.py @@ -0,0 +1,11 @@ +import sys +if sys.version_info >= (3, 5): + try: + from engineio.async_drivers.tornado import get_tornado_handler as \ + get_engineio_handler + except ImportError: # pragma: no cover + get_engineio_handler = None + + +def get_tornado_handler(socketio_server): # pragma: no cover + return get_engineio_handler(socketio_server.eio) diff --git a/libs/socketio/zmq_manager.py b/libs/socketio/zmq_manager.py new file mode 100644 index 000000000..f2a2ae5dc --- /dev/null +++ b/libs/socketio/zmq_manager.py @@ -0,0 +1,111 @@ +import pickle +import re + +try: + import eventlet.green.zmq as zmq +except ImportError: + zmq = None +import six + +from .pubsub_manager import PubSubManager + + +class ZmqManager(PubSubManager): # pragma: no cover + """zmq based client manager. + + NOTE: this zmq implementation should be considered experimental at this + time. At this time, eventlet is required to use zmq. + + This class implements a zmq backend for event sharing across multiple + processes. To use a zmq backend, initialize the :class:`Server` instance as + follows:: + + url = 'zmq+tcp://hostname:port1+port2' + server = socketio.Server(client_manager=socketio.ZmqManager(url)) + + :param url: The connection URL for the zmq message broker, + which will need to be provided and running. + :param channel: The channel name on which the server sends and receives + notifications. Must be the same in all the servers. + :param write_only: If set to ``True``, only initialize to emit events. The + default of ``False`` initializes the class for emitting + and receiving. + + A zmq message broker must be running for the zmq_manager to work. 
def __init__(self, url='zmq+tcp://localhost:5555+5556',
             channel='socketio',
             write_only=False,
             logger=None):
    """Create the PUSH (sink) and SUB sockets described by ``url``.

    :param url: connection string of the form
                ``zmq+tcp://host:sink_port+sub_port``.
    :param channel: channel name stored on the instance and passed on to
                    the parent class.
    :param write_only: passed on to the parent class.
    :param logger: passed on to the parent class.
    :raises RuntimeError: if zmq is not importable or ``url`` does not have
                          the expected shape.
    """
    if zmq is None:
        raise RuntimeError('zmq package is not installed '
                           '(Run "pip install pyzmq" in your '
                           'virtualenv).')

    r = re.compile(r':\d+\+\d+$')
    if not (url.startswith('zmq+tcp://') and r.search(url)):
        raise RuntimeError('unexpected connection string: ' + url)

    # Strip only the leading 'zmq+' marker and split on the LAST '+' / ':'.
    # Textual replacement of the port would also rewrite matching digits
    # inside the host (e.g. 192.168.80.1:80+81).
    url = url[len('zmq+'):]
    sink_url, sub_port = url.rsplit('+', 1)
    endpoint = sink_url.rsplit(':', 1)[0]
    sub_url = endpoint + ':' + sub_port

    sink = zmq.Context().socket(zmq.PUSH)
    sink.connect(sink_url)

    sub = zmq.Context().socket(zmq.SUB)
    sub.setsockopt_string(zmq.SUBSCRIBE, u'')
    sub.connect(sub_url)

    self.sink = sink
    self.sub = sub
    self.channel = channel
    super(ZmqManager, self).__init__(channel=channel,
                                     write_only=write_only,
                                     logger=logger)
#------------------------------------------------------------------------------
# Warnings control
#------------------------------------------------------------------------------

# 'Global' warnings state:
_warnings_enabled = {
    'YAMLLoadWarning': True,
}

# Get or set global warnings' state
def warnings(settings=None):
    """
    Get or update the module-wide warning switches.

    Called without arguments, return the dict of current switches.
    Called with a mapping, copy the values of the keys that are already
    known switches (unknown keys are ignored) and return None.
    """
    if settings is None:
        return _warnings_enabled

    # isinstance() rather than an exact type check, so that dict subclasses
    # such as OrderedDict are honoured instead of being silently ignored.
    if isinstance(settings, dict):
        for key in settings:
            if key in _warnings_enabled:
                _warnings_enabled[key] = settings[key]

# Warn when load() is called without Loader=...
class YAMLLoadWarning(RuntimeWarning):
    pass

def load_warning(method):
    """
    Emit a YAMLLoadWarning naming yaml.<method>(), unless that warning
    has been switched off through warnings().
    """
    if _warnings_enabled['YAMLLoadWarning'] is False:
        return

    # Local import: at module level the name 'warnings' is the function above.
    import warnings

    message = (
        "calling yaml.%s() without Loader=... is deprecated, as the "
        "default Loader is unsafe. Please read "
        "https://msg.pyyaml.org/load for full details."
    ) % method

    warnings.warn(message, YAMLLoadWarning, stacklevel=3)
def load(stream, Loader=None):
    """
    Parse the first YAML document in a stream
    and produce the corresponding Python object.

    When no Loader is given, a deprecation warning is issued through
    load_warning() and FullLoader is used.
    """
    if Loader is None:
        load_warning('load')
        Loader = FullLoader

    document_loader = Loader(stream)
    try:
        result = document_loader.get_single_data()
    finally:
        # Always release the loader, even when parsing fails.
        document_loader.dispose()
    return result
def emit(events, stream=None, Dumper=Dumper,
        canonical=None, indent=None, width=None,
        allow_unicode=None, line_break=None):
    """
    Emit YAML parsing events into a stream.
    If stream is None, return the produced string instead.
    """
    # Without a caller-supplied stream, collect the output in memory.
    in_memory = stream is None
    if in_memory:
        stream = io.StringIO()
    dumper = Dumper(stream, canonical=canonical, indent=indent, width=width,
            allow_unicode=allow_unicode, line_break=line_break)
    try:
        for item in events:
            dumper.emit(item)
    finally:
        dumper.dispose()
    if in_memory:
        return stream.getvalue()
def dump_all(documents, stream=None, Dumper=Dumper,
        default_style=None, default_flow_style=False,
        canonical=None, indent=None, width=None,
        allow_unicode=None, line_break=None,
        encoding=None, explicit_start=None, explicit_end=None,
        version=None, tags=None, sort_keys=True):
    """
    Serialize a sequence of Python objects into a YAML stream.
    If stream is None, return the produced string instead
    (bytes when an encoding is given, text otherwise).
    """
    in_memory = stream is None
    if in_memory:
        # An explicit encoding means the dumper writes bytes.
        stream = io.StringIO() if encoding is None else io.BytesIO()
    dumper = Dumper(stream, default_style=default_style,
            default_flow_style=default_flow_style,
            canonical=canonical, indent=indent, width=width,
            allow_unicode=allow_unicode, line_break=line_break,
            encoding=encoding, version=version, tags=tags,
            explicit_start=explicit_start, explicit_end=explicit_end,
            sort_keys=sort_keys)
    try:
        dumper.open()
        for document in documents:
            dumper.represent(document)
        dumper.close()
    finally:
        dumper.dispose()
    if in_memory:
        return stream.getvalue()
def safe_dump(data, stream=None, **kwds):
    """
    Serialize a Python object into a YAML stream.
    Produce only basic YAML tags.
    If stream is None, return the produced string instead.

    The object is wrapped in a one-element list and handed to dump_all()
    with Dumper=SafeDumper; any extra keyword arguments are forwarded
    unchanged.
    """
    return dump_all([data], stream, Dumper=SafeDumper, **kwds)
class YAMLObjectMetaclass(type):
    """
    The metaclass for YAMLObject.

    When a class body defines a non-None yaml_tag, the new class is
    registered with its yaml_loader (constructor) and yaml_dumper
    (representer).
    """
    def __init__(cls, name, bases, kwds):
        super().__init__(name, bases, kwds)
        # Only a tag declared in this class body triggers registration;
        # an inherited tag does not.
        if kwds.get('yaml_tag') is None:
            return
        cls.yaml_loader.add_constructor(cls.yaml_tag, cls.from_yaml)
        cls.yaml_dumper.add_representer(cls, cls.to_yaml)
def get_single_node(self):
    """
    Compose and return the only document of the stream.

    Returns None for an empty stream and raises ComposerError when a
    second document follows the first one.
    """
    # Drop the STREAM-START event.
    self.get_event()

    # An empty stream yields no document at all.
    stream_is_empty = self.check_event(StreamEndEvent)
    document = None if stream_is_empty else self.compose_document()

    # Anything other than STREAM-END here means another document follows.
    if not self.check_event(StreamEndEvent):
        extra_event = self.get_event()
        raise ComposerError("expected a single document in the stream",
                document.start_mark, "but found another document",
                extra_event.start_mark)

    # Drop the STREAM-END event.
    self.get_event()
    return document
def compose_node(self, parent, index):
    """
    Compose the node described by the next event(s).

    An alias event is resolved to the node previously stored under its
    anchor. Otherwise a scalar, sequence or mapping node is composed,
    depending on the type of the next event. Raises ComposerError for an
    undefined alias or a duplicate anchor.
    """
    if self.check_event(AliasEvent):
        event = self.get_event()
        anchor = event.anchor
        if anchor not in self.anchors:
            raise ComposerError(None, None, "found undefined alias %r"
                    % anchor, event.start_mark)
        return self.anchors[anchor]
    # Peek only: the compose_*_node helper called below consumes the event.
    event = self.peek_event()
    anchor = event.anchor
    if anchor is not None:
        if anchor in self.anchors:
            raise ComposerError("found duplicate anchor %r; first occurrence"
                    % anchor, self.anchors[anchor].start_mark,
                    "second occurrence", event.start_mark)
    # descend_resolver()/ascend_resolver() bracket the composition of the
    # node.
    self.descend_resolver(parent, index)
    # NOTE(review): if the next event is none of the three kinds below,
    # 'node' is never bound and the final return raises UnboundLocalError;
    # presumably the parser never produces such a stream -- confirm.
    if self.check_event(ScalarEvent):
        node = self.compose_scalar_node(anchor)
    elif self.check_event(SequenceStartEvent):
        node = self.compose_sequence_node(anchor)
    elif self.check_event(MappingStartEvent):
        node = self.compose_mapping_node(anchor)
    self.ascend_resolver()
    return node
def get_data(self):
    """
    Construct and return the next document, or None when check_node()
    reports that no further document is available.
    """
    if not self.check_node():
        return None
    next_node = self.get_node()
    return self.construct_document(next_node)
def construct_document(self, node):
    """
    Construct the Python object for a document root node.

    After the root object is built, every queued state generator is run
    to completion (generators queued meanwhile are processed in later
    rounds), then the per-document caches are reset.
    """
    document = self.construct_object(node)
    while self.state_generators:
        # Swap in a fresh list so generators queued during this round are
        # picked up by the next iteration of the while loop.
        pending, self.state_generators = self.state_generators, []
        for generator in pending:
            for _ in generator:
                pass
    self.constructed_objects = {}
    self.recursive_objects = {}
    self.deep_construct = False
    return document
def construct_mapping(self, node, deep=False):
    """
    Build a dict from a mapping node.

    Keys and values are constructed with construct_object(). Raises
    ConstructorError when the node is not a MappingNode or when a
    constructed key cannot be hashed.
    """
    if not isinstance(node, MappingNode):
        raise ConstructorError(None, None,
                "expected a mapping node, but found %s" % node.id,
                node.start_mark)
    mapping = {}
    for key_node, value_node in node.value:
        key = self.construct_object(key_node, deep=deep)
        # Probe with hash() instead of isinstance(key, Hashable): a tuple
        # holding a list is an instance of Hashable but still cannot be
        # used as a dict key, which would surface as a bare TypeError.
        try:
            hash(key)
        except TypeError:
            raise ConstructorError("while constructing a mapping", node.start_mark,
                    "found unhashable key", key_node.start_mark)
        value = self.construct_object(value_node, deep=deep)
        mapping[key] = value
    return mapping
def flatten_mapping(self, node):
    """
    Expand merge keys of a mapping node in place.

    Entries whose key is tagged 'tag:yaml.org,2002:merge' are removed from
    node.value and the (recursively flattened) entries of the mapping, or
    list of mappings, they refer to are prepended to node.value. Keys
    tagged 'tag:yaml.org,2002:value' are retagged as plain strings.
    Raises ConstructorError when a merge value is neither a mapping nor a
    sequence of mappings.
    """
    merge = []
    index = 0
    while index < len(node.value):
        key_node, value_node = node.value[index]
        if key_node.tag == 'tag:yaml.org,2002:merge':
            # The entry is deleted, so 'index' already points at the next
            # entry and must not be advanced in this branch.
            del node.value[index]
            if isinstance(value_node, MappingNode):
                self.flatten_mapping(value_node)
                merge.extend(value_node.value)
            elif isinstance(value_node, SequenceNode):
                submerge = []
                for subnode in value_node.value:
                    if not isinstance(subnode, MappingNode):
                        raise ConstructorError("while constructing a mapping",
                                node.start_mark,
                                "expected a mapping for merging, but found %s"
                                % subnode.id, subnode.start_mark)
                    self.flatten_mapping(subnode)
                    submerge.append(subnode.value)
                # Reversed so that entries of later mappings in the list
                # come first in 'merge'.
                submerge.reverse()
                for value in submerge:
                    merge.extend(value)
            else:
                raise ConstructorError("while constructing a mapping", node.start_mark,
                        "expected a mapping or list of mappings for merging, but found %s"
                        % value_node.id, value_node.start_mark)
        elif key_node.tag == 'tag:yaml.org,2002:value':
            key_node.tag = 'tag:yaml.org,2002:str'
            index += 1
        else:
            index += 1
    if merge:
        # Merged entries go first; the node's own entries follow them.
        node.value = merge + node.value
def construct_yaml_float(self, node):
    """
    Convert a scalar node to a float.

    Underscores are ignored and matching is case-insensitive. Handles an
    optional sign, '.inf', '.nan' and colon-separated base-60 notation.
    """
    text = self.construct_scalar(node).replace('_', '').lower()
    sign = -1 if text[0] == '-' else +1
    if text[0] in '+-':
        text = text[1:]
    if text == '.inf':
        return sign*self.inf_value
    if text == '.nan':
        return self.nan_value
    if ':' in text:
        # Sexagesimal: convert every part first, then sum from the least
        # significant part upwards.
        parts = [float(part) for part in text.split(':')]
        total = 0.0
        base = 1
        for digit in reversed(parts):
            total += digit*base
            base *= 60
        return sign*total
    return sign*float(text)
def construct_yaml_timestamp(self, node):
    """
    Convert a timestamp scalar to a date or datetime.

    A date-only value yields datetime.date. Otherwise a naive
    datetime.datetime is returned; when a numeric UTC offset is present
    it is subtracted, so the result is expressed in UTC.
    Raises ConstructorError when the scalar does not match
    timestamp_regexp.
    """
    value = self.construct_scalar(node)
    # Match the constructed scalar (not node.value) and fail with a YAML
    # error rather than AttributeError on None when the text is not a
    # timestamp, e.g. for an explicitly tagged scalar.
    match = self.timestamp_regexp.match(value)
    if match is None:
        raise ConstructorError(None, None,
                "failed to construct timestamp from %r" % value,
                node.start_mark)
    values = match.groupdict()
    year = int(values['year'])
    month = int(values['month'])
    day = int(values['day'])
    if not values['hour']:
        return datetime.date(year, month, day)
    hour = int(values['hour'])
    minute = int(values['minute'])
    second = int(values['second'])
    fraction = 0
    if values['fraction']:
        # Keep microsecond precision: truncate to, or zero-pad up to,
        # six digits.
        fraction = int(values['fraction'][:6].ljust(6, '0'))
    delta = None
    if values['tz_sign']:
        tz_hour = int(values['tz_hour'])
        tz_minute = int(values['tz_minute'] or 0)
        delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
        if values['tz_sign'] == '-':
            delta = -delta
    data = datetime.datetime(year, month, day, hour, minute, second, fraction)
    if delta:
        data -= delta
    return data
def construct_yaml_set(self, node):
    """
    Two-step constructor for sets: yield the (still empty) set first,
    then fill it with the keys of the mapping built from ``node``.
    """
    members = set()
    yield members
    members.update(self.construct_mapping(node))
def construct_yaml_object(self, node, cls):
    """
    Two-step constructor for instances of ``cls``.

    The instance is created with cls.__new__ and yielded before its state
    is filled in. State is passed to __setstate__ when the instance has
    one (using a deep-constructed mapping); otherwise it is merged into
    the instance __dict__.
    """
    instance = cls.__new__(cls)
    yield instance
    if not hasattr(instance, '__setstate__'):
        attributes = self.construct_mapping(node)
        instance.__dict__.update(attributes)
        return
    # __setstate__ receives a fully constructed mapping.
    attributes = self.construct_mapping(node, deep=True)
    instance.__setstate__(attributes)
def find_python_module(self, name, mark, unsafe=False):
    """
    Return the module called ``name`` from sys.modules.

    With unsafe=True the module is imported first; otherwise only modules
    that are already imported can be found. Raises ConstructorError for
    an empty name, a failed import, or a module that is not imported.
    """
    context = "while constructing a Python module"
    if not name:
        raise ConstructorError(context, mark,
                "expected non-empty name appended to the tag", mark)
    if unsafe:
        try:
            __import__(name)
        except ImportError as exc:
            raise ConstructorError(context, mark,
                    "cannot find module %r (%s)" % (name, exc), mark)
    if name not in sys.modules:
        raise ConstructorError(context, mark,
                "module %r is not imported" % name, mark)
    return sys.modules[name]
def set_python_instance_state(self, instance, state):
    """Apply *state* to a freshly created *instance*.

    instance -- the object whose state is being restored.
    state    -- either a mapping of attributes, or a two-element tuple
                ``(state, slotstate)`` where either half may be empty
                or None.

    If the instance defines ``__setstate__`` it receives *state*
    unchanged.  Otherwise the dictionary part is merged into
    ``instance.__dict__`` when there is one; for objects without a
    ``__dict__`` it is applied attribute by attribute together with the
    slot state.
    """
    if hasattr(instance, '__setstate__'):
        instance.__setstate__(state)
        return
    slotstate = {}
    if isinstance(state, tuple) and len(state) == 2:
        state, slotstate = state
    # Work on a copy so the caller's mapping is never modified, and
    # tolerate a None slot state.
    slotstate = dict(slotstate) if slotstate else {}
    if hasattr(instance, '__dict__'):
        # Either half of a (state, slotstate) pair may be None.
        if state:
            instance.__dict__.update(state)
    elif state:
        slotstate.update(state)
    for key, value in slotstate.items():
        # The attributes belong on the instance being restored, not on
        # the builtin ``object`` type.
        setattr(instance, key, value)
def construct_python_object_new(self, suffix, node):
    """Construct an object via ``construct_python_object_apply`` with ``newobj=True``.

    suffix -- the part of the tag that follows the registered prefix.
    node   -- the node to construct from.

    Returns whatever ``construct_python_object_apply`` returns.
    """
    apply_constructor = self.construct_python_object_apply
    return apply_constructor(suffix, node, newobj=True)
class UnsafeConstructor(FullConstructor):
    """FullConstructor whose lookups always run with ``unsafe=True``.

    Each override keeps the parent's signature minus the ``unsafe``
    flag and forwards to the parent implementation with the flag set.
    """

    def find_python_module(self, name, mark):
        """Resolve a module by name, importing it when necessary."""
        return super().find_python_module(name, mark, unsafe=True)

    def find_python_name(self, name, mark):
        """Resolve a dotted object name, importing its module when necessary."""
        return super().find_python_name(name, mark, unsafe=True)

    def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False):
        """Create an instance without the "must be a class" restriction."""
        return super().make_python_instance(
            suffix, node, args=args, kwds=kwds, newobj=newobj, unsafe=True)
class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver):
    """Dumper combining the C emitter with the base representer and resolver.

    The keyword options are split between the parts: output and document
    options go to ``CEmitter``, representation options to
    ``BaseRepresenter``.
    """

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        CEmitter.__init__(self, stream, canonical=canonical,
                indent=indent, width=width, encoding=encoding,
                allow_unicode=allow_unicode, line_break=line_break,
                explicit_start=explicit_start, explicit_end=explicit_end,
                version=version, tags=tags)
        # Initialise through the classes this dumper actually derives
        # from, mirroring how CBaseLoader initialises its own bases.
        # NOTE(review): assumes Representer/Resolver do not override the
        # base __init__, so the same code runs as before -- confirm in
        # representer.py / resolver.py.
        BaseRepresenter.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        BaseResolver.__init__(self)
class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver):
    """Dumper built from Emitter, Serializer, SafeRepresenter and Resolver.

    The constructor hands each component the subset of keyword options
    it accepts and initialises them in that order.
    """

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        # Output formatting options.
        Emitter.__init__(
            self,
            stream,
            canonical=canonical,
            indent=indent,
            width=width,
            allow_unicode=allow_unicode,
            line_break=line_break,
        )
        # Stream and document level options.
        Serializer.__init__(
            self,
            encoding=encoding,
            explicit_start=explicit_start,
            explicit_end=explicit_end,
            version=version,
            tags=tags,
        )
        # Options controlling how values are represented.
        SafeRepresenter.__init__(
            self,
            default_style=default_style,
            default_flow_style=default_flow_style,
            sort_keys=sort_keys,
        )
        Resolver.__init__(self)
class ScalarAnalysis:
    """Record of the properties determined for one scalar value.

    Holds the scalar itself, whether it is empty or multiline, and one
    flag per output style saying whether that style is allowed.
    """

    def __init__(self, scalar, empty, multiline,
            allow_flow_plain, allow_block_plain,
            allow_single_quoted, allow_double_quoted,
            allow_block):
        # The value and its basic shape.
        self.scalar, self.empty, self.multiline = scalar, empty, multiline
        # Plain (unquoted) styles, in flow and in block context.
        self.allow_flow_plain, self.allow_block_plain = \
                allow_flow_plain, allow_block_plain
        # Quoted styles.
        self.allow_single_quoted, self.allow_double_quoted = \
                allow_single_quoted, allow_double_quoted
        # Block scalar styles.
        self.allow_block = allow_block
def emit(self, event):
    """Queue *event* and process queued events while enough are available.

    Each processed event is popped from the front of ``self.events``,
    exposed as ``self.event`` while the current state handler runs, and
    cleared afterwards.  Processing stops as soon as
    ``need_more_events()`` reports that more input is required.
    """
    self.events.append(event)
    while True:
        if self.need_more_events():
            break
        self.event = self.events.pop(0)
        self.state()
        self.event = None
def increase_indent(self, flow=False, indentless=False):
    """Push the current indent on the stack and compute the new one.

    With no indent set yet, the new indent is ``best_indent`` in flow
    context and 0 otherwise.  With an indent already set, it grows by
    ``best_indent`` unless *indentless* is true, in which case it stays
    as it is.
    """
    current = self.indent
    self.indents.append(current)
    if current is None:
        self.indent = self.best_indent if flow else 0
    elif not indentless:
        self.indent = current + self.best_indent
def expect_document_end(self):
    """State handler for the event that must close the current document.

    Writes an indent, adds the explicit ``...`` end marker when the
    event requests it, flushes the stream and switches the state to
    ``expect_document_start``.

    Raises EmitterError if the current event is not a DocumentEndEvent.
    """
    if not isinstance(self.event, DocumentEndEvent):
        raise EmitterError("expected DocumentEndEvent, but got %s"
                % self.event)
    self.write_indent()
    if self.event.explicit:
        self.write_indicator('...', True)
        self.write_indent()
    self.flush_stream()
    self.state = self.expect_document_start
def expect_alias(self):
    """Emit the current alias event and return to the previous state.

    Raises EmitterError if the event carries no anchor.
    """
    anchor = self.event.anchor
    if anchor is None:
        raise EmitterError("anchor is not specified for alias")
    self.process_anchor('*')
    # Resume whichever state was saved before this node was entered.
    self.state = self.states.pop()
def expect_flow_sequence_item(self):
    """State handler for every flow sequence event after the first item.

    For any event other than SequenceEndEvent a ',' separator is
    written (plus a line break in canonical mode or past the preferred
    width) and the event is handled as the next item.  On
    SequenceEndEvent the indent and flow level are restored, ']' is
    written and the previous state is resumed.
    """
    if not isinstance(self.event, SequenceEndEvent):
        self.write_indicator(',', False)
        if self.canonical or self.column > self.best_width:
            self.write_indent()
        self.states.append(self.expect_flow_sequence_item)
        self.expect_node(sequence=True)
        return
    self.indent = self.indents.pop()
    self.flow_level -= 1
    if self.canonical:
        # Canonical output puts a trailing ',' and a line break before
        # the closing bracket.
        self.write_indicator(',', False)
        self.write_indent()
    self.write_indicator(']', False)
    self.state = self.states.pop()
def expect_flow_mapping_value(self):
    """Write the ':' indicator for a flow mapping value and emit the value node.

    A line break is inserted first in canonical mode or when the current
    column is past the preferred width.  ``expect_flow_mapping_key`` is
    pushed as the state to resume once the value has been handled.
    """
    needs_break = self.canonical or self.column > self.best_width
    if needs_break:
        self.write_indent()
    self.write_indicator(':', True)
    self.states.append(self.expect_flow_mapping_key)
    self.expect_node(mapping=True)
def expect_block_sequence(self):
    """Enter a block sequence and wait for its first item.

    The indent is left unchanged ("indentless") when the sequence is
    written in a mapping context and the last thing written was not an
    indention character.
    """
    skip_indent = self.mapping_context and not self.indention
    self.increase_indent(flow=False, indentless=skip_indent)
    self.state = self.expect_first_block_sequence_item
def check_empty_document(self):
    """Tell whether the document being started consists of one empty plain scalar.

    Returns False unless the current event is a DocumentStartEvent and
    at least one further event is queued.  Otherwise the result is
    truthy only when the next queued event is a ScalarEvent with no
    anchor, no tag, a truthy ``implicit`` and an empty value.
    """
    if not (isinstance(self.event, DocumentStartEvent) and self.events):
        return False
    upcoming = self.events[0]
    return (isinstance(upcoming, ScalarEvent) and upcoming.anchor is None
            and upcoming.tag is None and upcoming.implicit
            and upcoming.value == '')
def process_anchor(self, indicator):
    """Write the current event's anchor, prefixed with *indicator*.

    indicator -- the character(s) placed in front of the anchor text.

    An already prepared anchor is reused; otherwise it is produced with
    ``prepare_anchor``.  The prepared anchor is always reset to None on
    the way out, including when the event has no anchor at all.
    """
    anchor = self.event.anchor
    if anchor is not None:
        if self.prepared_anchor is None:
            self.prepared_anchor = self.prepare_anchor(anchor)
        text = self.prepared_anchor
        if text:
            self.write_indicator(indicator+text, True)
    self.prepared_anchor = None
def prepare_tag_prefix(self, prefix):
    """Return *prefix* with every disallowed character percent-escaped.

    ASCII letters, digits and the characters ``-;/?!:@&=+$,_.~*'()[]``
    are copied unchanged.  Any other character is replaced by the
    ``%XX`` escapes of its UTF-8 bytes.

    Raises EmitterError if *prefix* is empty.
    """
    if not prefix:
        raise EmitterError("tag prefix must not be empty")
    chunks = []
    start = end = 0
    # A leading '!' is accepted without inspection.
    if prefix[0] == '!':
        end = 1
    while end < len(prefix):
        ch = prefix[end]
        if '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \
                or ch in '-;/?!:@&=+$,_.~*\'()[]':
            end += 1
        else:
            # Flush the run of unescaped characters collected so far.
            if start < end:
                chunks.append(prefix[start:end])
            start = end = end+1
            # Iterating over bytes yields ints, which '%02X' formats
            # directly; calling ord() on them raises TypeError.
            for byte in ch.encode('utf-8'):
                chunks.append('%%%02X' % byte)
    if start < end:
        chunks.append(prefix[start:end])
    return ''.join(chunks)
def prepare_anchor(self, anchor):
    """Validate *anchor* and return it unchanged.

    Only ASCII letters, digits, '-' and '_' are accepted.

    Raises EmitterError if *anchor* is empty or contains any other
    character; the first offending character is reported.
    """
    if not anchor:
        raise EmitterError("anchor must not be empty")
    offending = [ch for ch in anchor
            if not ('0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z'
                or ch in '-_')]
    if offending:
        raise EmitterError("invalid character %r in the anchor: %r"
                % (offending[0], anchor))
    return anchor
+ preceded_by_whitespace = True + + # Last character or followed by a whitespace. + followed_by_whitespace = (len(scalar) == 1 or + scalar[1] in '\0 \t\r\n\x85\u2028\u2029') + + # The previous character is a space. + previous_space = False + + # The previous character is a break. + previous_break = False + + index = 0 + while index < len(scalar): + ch = scalar[index] + + # Check for indicators. + if index == 0: + # Leading indicators are special characters. + if ch in '#,[]{}&*!|>\'\"%@`': + flow_indicators = True + block_indicators = True + if ch in '?:': + flow_indicators = True + if followed_by_whitespace: + block_indicators = True + if ch == '-' and followed_by_whitespace: + flow_indicators = True + block_indicators = True + else: + # Some indicators cannot appear within a scalar as well. + if ch in ',?[]{}': + flow_indicators = True + if ch == ':': + flow_indicators = True + if followed_by_whitespace: + block_indicators = True + if ch == '#' and preceded_by_whitespace: + flow_indicators = True + block_indicators = True + + # Check for line breaks, special, and unicode characters. + if ch in '\n\x85\u2028\u2029': + line_breaks = True + if not (ch == '\n' or '\x20' <= ch <= '\x7E'): + if (ch == '\x85' or '\xA0' <= ch <= '\uD7FF' + or '\uE000' <= ch <= '\uFFFD' + or '\U00010000' <= ch < '\U0010ffff') and ch != '\uFEFF': + unicode_characters = True + if not self.allow_unicode: + special_characters = True + else: + special_characters = True + + # Detect important whitespace combinations. 
+ if ch == ' ': + if index == 0: + leading_space = True + if index == len(scalar)-1: + trailing_space = True + if previous_break: + break_space = True + previous_space = True + previous_break = False + elif ch in '\n\x85\u2028\u2029': + if index == 0: + leading_break = True + if index == len(scalar)-1: + trailing_break = True + if previous_space: + space_break = True + previous_space = False + previous_break = True + else: + previous_space = False + previous_break = False + + # Prepare for the next character. + index += 1 + preceded_by_whitespace = (ch in '\0 \t\r\n\x85\u2028\u2029') + followed_by_whitespace = (index+1 >= len(scalar) or + scalar[index+1] in '\0 \t\r\n\x85\u2028\u2029') + + # Let's decide what styles are allowed. + allow_flow_plain = True + allow_block_plain = True + allow_single_quoted = True + allow_double_quoted = True + allow_block = True + + # Leading and trailing whitespaces are bad for plain scalars. + if (leading_space or leading_break + or trailing_space or trailing_break): + allow_flow_plain = allow_block_plain = False + + # We do not permit trailing spaces for block scalars. + if trailing_space: + allow_block = False + + # Spaces at the beginning of a new line are only acceptable for block + # scalars. + if break_space: + allow_flow_plain = allow_block_plain = allow_single_quoted = False + + # Spaces followed by breaks, as well as special character are only + # allowed for double quoted scalars. + if space_break or special_characters: + allow_flow_plain = allow_block_plain = \ + allow_single_quoted = allow_block = False + + # Although the plain scalar writer supports breaks, we never emit + # multiline plain scalars. + if line_breaks: + allow_flow_plain = allow_block_plain = False + + # Flow indicators are forbidden for flow plain scalars. + if flow_indicators: + allow_flow_plain = False + + # Block indicators are forbidden for block plain scalars. 
def write_indicator(self, indicator, need_whitespace,
        whitespace=False, indention=False):
    """Write *indicator* to the stream, preceded by a space if required.

    A single space is prepended when *need_whitespace* is true and the
    emitter's ``whitespace`` flag is not already set.  Afterwards the
    ``whitespace`` flag is set to *whitespace*, ``indention`` stays true
    only if it was true and *indention* is true, ``column`` advances by
    the number of characters written, and ``open_ended`` is cleared.
    """
    needs_separator = not self.whitespace and need_whitespace
    text = ' '+indicator if needs_separator else indicator
    self.whitespace = whitespace
    self.indention = self.indention and indention
    self.column += len(text)
    self.open_ended = False
    # The stream receives bytes only when an encoding is configured.
    payload = text.encode(self.encoding) if self.encoding else text
    self.stream.write(payload)
self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + self.write_line_break() + + # Scalar streams. + + def write_single_quoted(self, text, split=True): + self.write_indicator('\'', True) + spaces = False + breaks = False + start = end = 0 + while end <= len(text): + ch = None + if end < len(text): + ch = text[end] + if spaces: + if ch is None or ch != ' ': + if start+1 == end and self.column > self.best_width and split \ + and start != 0 and end != len(text): + self.write_indent() + else: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + elif breaks: + if ch is None or ch not in '\n\x85\u2028\u2029': + if text[start] == '\n': + self.write_line_break() + for br in text[start:end]: + if br == '\n': + self.write_line_break() + else: + self.write_line_break(br) + self.write_indent() + start = end + else: + if ch is None or ch in ' \n\x85\u2028\u2029' or ch == '\'': + if start < end: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + if ch == '\'': + data = '\'\'' + self.column += 2 + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + 1 + if ch is not None: + spaces = (ch == ' ') + breaks = (ch in '\n\x85\u2028\u2029') + end += 1 + self.write_indicator('\'', False) + + ESCAPE_REPLACEMENTS = { + '\0': '0', + '\x07': 'a', + '\x08': 'b', + '\x09': 't', + '\x0A': 'n', + '\x0B': 'v', + '\x0C': 'f', + '\x0D': 'r', + '\x1B': 'e', + '\"': '\"', + '\\': '\\', + '\x85': 'N', + '\xA0': '_', + '\u2028': 'L', + '\u2029': 'P', + } + + def write_double_quoted(self, text, split=True): + self.write_indicator('"', True) + start = end = 0 + while end <= len(text): + ch = None + if end < len(text): + ch = text[end] + if ch is None or ch in '"\\\x85\u2028\u2029\uFEFF' \ + or not ('\x20' <= ch <= '\x7E' + or 
def determine_block_hints(self, text):
    """Return the block scalar header hints for *text*.

    The result is the concatenation of an indentation indicator
    (``self.best_indent``) when the text starts with a space or a line
    break, and a chomping indicator: '-' when the text does not end with
    a line break, '+' when it is a single break or ends with two breaks.
    Empty text yields an empty string.
    """
    if not text:
        return ''
    line_breaks = '\n\x85\u2028\u2029'
    parts = []
    if text[0] in ' ' + line_breaks:
        parts.append(str(self.best_indent))
    if text[-1] not in line_breaks:
        parts.append('-')
    elif len(text) == 1 or text[-2] in line_breaks:
        parts.append('+')
    return ''.join(parts)
def write_literal(self, text):
    """Write *text* as a literal block scalar.

    Emits the '|' indicator followed by the hints computed by
    determine_block_hints(), then a line break, then the text itself.
    Runs of line-break characters are written through write_line_break();
    runs of other characters are written to the stream as-is (encoded
    when self.encoding is set).
    """
    hints = self.determine_block_hints(text)
    self.write_indicator('|'+hints, True)
    # A trailing '+' hint marks the scalar as open ended.
    if hints[-1:] == '+':
        self.open_ended = True
    self.write_line_break()
    # The scan starts in the "breaks" state so that the first content
    # character is preceded by write_indent().
    breaks = True
    start = end = 0
    # Iterate one position past the end; ch is None at that sentinel
    # position so the final pending run gets flushed.
    while end <= len(text):
        ch = None
        if end < len(text):
            ch = text[end]
        if breaks:
            # End of a run of line breaks: emit each one, '\n' as the
            # emitter's default break and any other break verbatim.
            if ch is None or ch not in '\n\x85\u2028\u2029':
                for br in text[start:end]:
                    if br == '\n':
                        self.write_line_break()
                    else:
                        self.write_line_break(br)
                if ch is not None:
                    self.write_indent()
                start = end
        else:
            # End of a run of ordinary characters: write it out.
            # self.column is not advanced here.
            if ch is None or ch in '\n\x85\u2028\u2029':
                data = text[start:end]
                if self.encoding:
                    data = data.encode(self.encoding)
                self.stream.write(data)
                if ch is None:
                    self.write_line_break()
                start = end
        if ch is not None:
            breaks = (ch in '\n\x85\u2028\u2029')
        end += 1
len(text): + ch = None + if end < len(text): + ch = text[end] + if spaces: + if ch != ' ': + if start+1 == end and self.column > self.best_width and split: + self.write_indent() + self.whitespace = False + self.indention = False + else: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + elif breaks: + if ch not in '\n\x85\u2028\u2029': + if text[start] == '\n': + self.write_line_break() + for br in text[start:end]: + if br == '\n': + self.write_line_break() + else: + self.write_line_break(br) + self.write_indent() + self.whitespace = False + self.indention = False + start = end + else: + if ch is None or ch in ' \n\x85\u2028\u2029': + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + if ch is not None: + spaces = (ch == ' ') + breaks = (ch in '\n\x85\u2028\u2029') + end += 1 diff --git a/libs/yaml/error.py b/libs/yaml/error.py new file mode 100644 index 000000000..b796b4dc5 --- /dev/null +++ b/libs/yaml/error.py @@ -0,0 +1,75 @@ + +__all__ = ['Mark', 'YAMLError', 'MarkedYAMLError'] + +class Mark: + + def __init__(self, name, index, line, column, buffer, pointer): + self.name = name + self.index = index + self.line = line + self.column = column + self.buffer = buffer + self.pointer = pointer + + def get_snippet(self, indent=4, max_length=75): + if self.buffer is None: + return None + head = '' + start = self.pointer + while start > 0 and self.buffer[start-1] not in '\0\r\n\x85\u2028\u2029': + start -= 1 + if self.pointer-start > max_length/2-1: + head = ' ... ' + start += 5 + break + tail = '' + end = self.pointer + while end < len(self.buffer) and self.buffer[end] not in '\0\r\n\x85\u2028\u2029': + end += 1 + if end-self.pointer > max_length/2-1: + tail = ' ... 
class YAMLError(Exception):
    """Base class for the errors defined in this module."""


class MarkedYAMLError(YAMLError):
    """A YAML error that can carry context/problem text and their marks."""

    def __init__(self, context=None, context_mark=None,
            problem=None, problem_mark=None, note=None):
        self.context = context
        self.context_mark = context_mark
        self.problem = problem
        self.problem_mark = problem_mark
        self.note = note

    def __str__(self):
        """Join the parts that are set, one per line.

        The context mark is left out when a problem and a problem mark
        are both present and the two marks have the same name, line and
        column.
        """
        cmark, pmark = self.context_mark, self.problem_mark
        parts = []
        if self.context is not None:
            parts.append(self.context)
        if cmark is not None:
            same_place = (self.problem is not None and pmark is not None
                    and cmark.name == pmark.name
                    and cmark.line == pmark.line
                    and cmark.column == pmark.column)
            if not same_place:
                parts.append(str(cmark))
        if self.problem is not None:
            parts.append(self.problem)
        if pmark is not None:
            parts.append(str(pmark))
        if self.note is not None:
            parts.append(self.note)
        return '\n'.join(parts)
class Event(object):
    """Base event carrying optional start and end marks."""

    def __init__(self, start_mark=None, end_mark=None):
        self.start_mark, self.end_mark = start_mark, end_mark

    def __repr__(self):
        # Only these attributes are shown, and only when the instance
        # actually has them.
        shown = []
        for field in ('anchor', 'tag', 'implicit', 'value'):
            if hasattr(self, field):
                shown.append('%s=%r' % (field, getattr(self, field)))
        return '%s(%s)' % (self.__class__.__name__, ', '.join(shown))
# UnsafeLoader is the same as Loader (which is and was always unsafe on
# untrusted input). Use of either Loader or UnsafeLoader should be rare, since
# FullLoader should be able to load almost all YAML safely. Loader is left
# intact to ensure backwards compatibility.
class UnsafeLoader(Reader, Scanner, Parser, Composer, Constructor, Resolver):

    def __init__(self, stream):
        # Each base class is initialised explicitly; only Reader takes
        # the input stream.
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        Constructor.__init__(self)
        Resolver.__init__(self)
explicit_document* STREAM-END +# implicit_document ::= block_node DOCUMENT-END* +# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* +# block_node_or_indentless_sequence ::= +# ALIAS +# | properties (block_content | indentless_sequence)? +# | block_content +# | indentless_sequence +# block_node ::= ALIAS +# | properties block_content? +# | block_content +# flow_node ::= ALIAS +# | properties flow_content? +# | flow_content +# properties ::= TAG ANCHOR? | ANCHOR TAG? +# block_content ::= block_collection | flow_collection | SCALAR +# flow_content ::= flow_collection | SCALAR +# block_collection ::= block_sequence | block_mapping +# flow_collection ::= flow_sequence | flow_mapping +# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END +# indentless_sequence ::= (BLOCK-ENTRY block_node?)+ +# block_mapping ::= BLOCK-MAPPING-START +# ((KEY block_node_or_indentless_sequence?)? +# (VALUE block_node_or_indentless_sequence?)?)* +# BLOCK-END +# flow_sequence ::= FLOW-SEQUENCE-START +# (flow_sequence_entry FLOW-ENTRY)* +# flow_sequence_entry? +# FLOW-SEQUENCE-END +# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +# flow_mapping ::= FLOW-MAPPING-START +# (flow_mapping_entry FLOW-ENTRY)* +# flow_mapping_entry? +# FLOW-MAPPING-END +# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? 
+# +# FIRST sets: +# +# stream: { STREAM-START } +# explicit_document: { DIRECTIVE DOCUMENT-START } +# implicit_document: FIRST(block_node) +# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START } +# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START } +# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } +# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } +# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START } +# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } +# block_sequence: { BLOCK-SEQUENCE-START } +# block_mapping: { BLOCK-MAPPING-START } +# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START BLOCK-ENTRY } +# indentless_sequence: { ENTRY } +# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } +# flow_sequence: { FLOW-SEQUENCE-START } +# flow_mapping: { FLOW-MAPPING-START } +# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } +# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } + +__all__ = ['Parser', 'ParserError'] + +from .error import MarkedYAMLError +from .tokens import * +from .events import * +from .scanner import * + +class ParserError(MarkedYAMLError): + pass + +class Parser: + # Since writing a recursive-descendant parser is a straightforward task, we + # do not give many comments here. 
def check_event(self, *choices):
    """Report whether a next event exists and matches *choices*.

    If no event is buffered and a state function is set, the state
    function is called to produce one.  With no *choices* any available
    event matches; otherwise the event must be an instance of at least
    one of the given classes.  The event stays buffered.
    """
    if self.current_event is None and self.state:
        self.current_event = self.state()
    pending = self.current_event
    if pending is None:
        return False
    if not choices:
        return True
    return any(isinstance(pending, kind) for kind in choices)
+ self.states.append(self.parse_document_end) + self.state = self.parse_block_node + + return event + + else: + return self.parse_document_start() + + def parse_document_start(self): + + # Parse any extra document end indicators. + while self.check_token(DocumentEndToken): + self.get_token() + + # Parse an explicit document. + if not self.check_token(StreamEndToken): + token = self.peek_token() + start_mark = token.start_mark + version, tags = self.process_directives() + if not self.check_token(DocumentStartToken): + raise ParserError(None, None, + "expected '<document start>', but found %r" + % self.peek_token().id, + self.peek_token().start_mark) + token = self.get_token() + end_mark = token.end_mark + event = DocumentStartEvent(start_mark, end_mark, + explicit=True, version=version, tags=tags) + self.states.append(self.parse_document_end) + self.state = self.parse_document_content + else: + # Parse the end of the stream. + token = self.get_token() + event = StreamEndEvent(token.start_mark, token.end_mark) + assert not self.states + assert not self.marks + self.state = None + return event + + def parse_document_end(self): + + # Parse the document end. + token = self.peek_token() + start_mark = end_mark = token.start_mark + explicit = False + if self.check_token(DocumentEndToken): + token = self.get_token() + end_mark = token.end_mark + explicit = True + event = DocumentEndEvent(start_mark, end_mark, + explicit=explicit) + + # Prepare the next state. 
def process_directives(self):
    """Consume directive tokens and return ``(version, tags)``.

    Resets ``self.yaml_version`` and ``self.tag_handles``, then reads
    every leading DirectiveToken.  A YAML directive sets the version
    (major version must be 1, and it may appear only once); a TAG
    directive registers a handle (each handle may appear only once).
    The returned ``tags`` is a copy of the explicitly declared handles,
    or None when there were none.  Before returning, the handles from
    DEFAULT_TAGS that were not declared are added to
    ``self.tag_handles``.

    Raises ParserError on a duplicate or incompatible directive.
    """
    self.yaml_version = None
    self.tag_handles = {}
    while self.check_token(DirectiveToken):
        directive = self.get_token()
        if directive.name == 'YAML':
            if self.yaml_version is not None:
                raise ParserError(None, None,
                        "found duplicate YAML directive", directive.start_mark)
            major, _minor = directive.value
            if major != 1:
                raise ParserError(None, None,
                        "found incompatible YAML document (version 1.* is required)",
                        directive.start_mark)
            self.yaml_version = directive.value
        elif directive.name == 'TAG':
            handle, prefix = directive.value
            if handle in self.tag_handles:
                raise ParserError(None, None,
                        "duplicate tag handle %r" % handle,
                        directive.start_mark)
            self.tag_handles[handle] = prefix
    # Snapshot the declared handles before the defaults are merged in.
    declared = self.tag_handles.copy() if self.tag_handles else None
    for handle, prefix in self.DEFAULT_TAGS.items():
        self.tag_handles.setdefault(handle, prefix)
    return self.yaml_version, declared
# NOTE(review): the functions below take `self` and call `self.check_token`,
# `self.get_token`, etc.; they appear to be methods of a parser class whose
# header lies before this chunk -- confirm the enclosing indentation.

# block_content     ::= block_collection | flow_collection | SCALAR
# flow_content      ::= flow_collection | SCALAR
# block_collection  ::= block_sequence | block_mapping
# flow_collection   ::= flow_sequence | flow_mapping

def parse_block_node(self):
    """Parse a node with block collections allowed (`block=True`)."""
    return self.parse_node(block=True)

def parse_flow_node(self):
    """Parse a node with the default flags (no block collections)."""
    return self.parse_node()

def parse_block_node_or_indentless_sequence(self):
    """Parse a block node, additionally accepting an indentless sequence."""
    return self.parse_node(block=True, indentless_sequence=True)

def parse_node(self, block=False, indentless_sequence=False):
    """Return the event that starts (or fully describes) the next node.

    An alias token yields an AliasEvent.  Otherwise an optional anchor and
    tag (in either order) are consumed, the tag handle is expanded through
    `self.tag_handles`, and the following token decides which event is
    produced: a sequence/mapping start event, a scalar event, or an empty
    scalar when only an anchor or tag was given.

    `block` enables BLOCK-SEQUENCE-START / BLOCK-MAPPING-START content;
    `indentless_sequence` additionally accepts a bare BLOCK-ENTRY.

    Raises ParserError for an undefined tag handle or when no node content
    can be found.
    """
    if self.check_token(AliasToken):
        token = self.get_token()
        event = AliasEvent(token.value, token.start_mark, token.end_mark)
        self.state = self.states.pop()
    else:
        anchor = None
        tag = None
        start_mark = end_mark = tag_mark = None
        # Properties may come as "anchor [tag]" or "tag [anchor]"; the first
        # one seen fixes start_mark, the last one seen fixes end_mark.
        if self.check_token(AnchorToken):
            token = self.get_token()
            start_mark = token.start_mark
            end_mark = token.end_mark
            anchor = token.value
            if self.check_token(TagToken):
                token = self.get_token()
                tag_mark = token.start_mark
                end_mark = token.end_mark
                tag = token.value
        elif self.check_token(TagToken):
            token = self.get_token()
            start_mark = tag_mark = token.start_mark
            end_mark = token.end_mark
            tag = token.value
            if self.check_token(AnchorToken):
                token = self.get_token()
                end_mark = token.end_mark
                anchor = token.value
        if tag is not None:
            # The tag token value is a (handle, suffix) pair; a handle is
            # replaced by the prefix registered in self.tag_handles.
            handle, suffix = tag
            if handle is not None:
                if handle not in self.tag_handles:
                    raise ParserError("while parsing a node", start_mark,
                            "found undefined tag handle %r" % handle,
                            tag_mark)
                tag = self.tag_handles[handle]+suffix
            else:
                tag = suffix
        #if tag == '!':
        #    raise ParserError("while parsing a node", start_mark,
        #            "found non-specific tag '!'", tag_mark,
        #            "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' and share your opinion.")
        if start_mark is None:
            # No anchor/tag was consumed: the node starts at the next token.
            start_mark = end_mark = self.peek_token().start_mark
        event = None
        implicit = (tag is None or tag == '!')
        if indentless_sequence and self.check_token(BlockEntryToken):
            end_mark = self.peek_token().end_mark
            event = SequenceStartEvent(anchor, tag, implicit,
                    start_mark, end_mark)
            self.state = self.parse_indentless_sequence_entry
        else:
            if self.check_token(ScalarToken):
                token = self.get_token()
                end_mark = token.end_mark
                # For scalars `implicit` becomes a pair of flags, rebuilt
                # here from the token's `plain` attribute and the tag.
                if (token.plain and tag is None) or tag == '!':
                    implicit = (True, False)
                elif tag is None:
                    implicit = (False, True)
                else:
                    implicit = (False, False)
                event = ScalarEvent(anchor, tag, implicit, token.value,
                        start_mark, end_mark, style=token.style)
                self.state = self.states.pop()
            elif self.check_token(FlowSequenceStartToken):
                end_mark = self.peek_token().end_mark
                event = SequenceStartEvent(anchor, tag, implicit,
                        start_mark, end_mark, flow_style=True)
                self.state = self.parse_flow_sequence_first_entry
            elif self.check_token(FlowMappingStartToken):
                end_mark = self.peek_token().end_mark
                event = MappingStartEvent(anchor, tag, implicit,
                        start_mark, end_mark, flow_style=True)
                self.state = self.parse_flow_mapping_first_key
            elif block and self.check_token(BlockSequenceStartToken):
                end_mark = self.peek_token().start_mark
                event = SequenceStartEvent(anchor, tag, implicit,
                        start_mark, end_mark, flow_style=False)
                self.state = self.parse_block_sequence_first_entry
            elif block and self.check_token(BlockMappingStartToken):
                end_mark = self.peek_token().start_mark
                event = MappingStartEvent(anchor, tag, implicit,
                        start_mark, end_mark, flow_style=False)
                self.state = self.parse_block_mapping_first_key
            elif anchor is not None or tag is not None:
                # Empty scalars are allowed even if a tag or an anchor is
                # specified.
                event = ScalarEvent(anchor, tag, (implicit, False), '',
                        start_mark, end_mark)
                self.state = self.states.pop()
            else:
                if block:
                    node = 'block'
                else:
                    node = 'flow'
                token = self.peek_token()
                raise ParserError("while parsing a %s node" % node, start_mark,
                        "expected the node content, but found %r" % token.id,
                        token.start_mark)
    return event

# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END

def parse_block_sequence_first_entry(self):
    """Consume the sequence-start token, remember its mark, parse an entry."""
    token = self.get_token()
    self.marks.append(token.start_mark)
    return self.parse_block_sequence_entry()

def parse_block_sequence_entry(self):
    """Parse one block sequence entry, or close the sequence on BLOCK-END.

    Raises ParserError when the next token is neither an entry nor the
    block end.
    """
    if self.check_token(BlockEntryToken):
        token = self.get_token()
        if not self.check_token(BlockEntryToken, BlockEndToken):
            self.states.append(self.parse_block_sequence_entry)
            return self.parse_block_node()
        else:
            # "-" immediately followed by another entry or the block end:
            # the entry is an empty scalar.
            self.state = self.parse_block_sequence_entry
            return self.process_empty_scalar(token.end_mark)
    if not self.check_token(BlockEndToken):
        token = self.peek_token()
        raise ParserError("while parsing a block collection", self.marks[-1],
                "expected <block end>, but found %r" % token.id, token.start_mark)
    token = self.get_token()
    event = SequenceEndEvent(token.start_mark, token.end_mark)
    self.state = self.states.pop()
    self.marks.pop()
    return event

# indentless_sequence ::= (BLOCK-ENTRY block_node?)+

def parse_indentless_sequence_entry(self):
    """Parse one entry of an indentless sequence, or end the sequence.

    No token closes an indentless sequence: the end event is zero-width and
    placed at the start of the next (unconsumed) token.
    """
    if self.check_token(BlockEntryToken):
        token = self.get_token()
        if not self.check_token(BlockEntryToken,
                KeyToken, ValueToken, BlockEndToken):
            self.states.append(self.parse_indentless_sequence_entry)
            return self.parse_block_node()
        else:
            self.state = self.parse_indentless_sequence_entry
            return self.process_empty_scalar(token.end_mark)
    token = self.peek_token()
    event = SequenceEndEvent(token.start_mark, token.start_mark)
    self.state = self.states.pop()
    return event

# block_mapping     ::= BLOCK-MAPPING_START
#                       ((KEY
#                       block_node_or_indentless_sequence?)?
#                       (VALUE block_node_or_indentless_sequence?)?)*
#                       BLOCK-END

def parse_block_mapping_first_key(self):
    """Consume the mapping-start token, record its mark, parse the first key."""
    opener = self.get_token()
    self.marks.append(opener.start_mark)
    return self.parse_block_mapping_key()

def parse_block_mapping_key(self):
    """Parse the next key of a block mapping, or close the mapping.

    Raises ParserError when the next token is neither KEY nor BLOCK-END.
    """
    if self.check_token(KeyToken):
        key_token = self.get_token()
        if self.check_token(KeyToken, ValueToken, BlockEndToken):
            # Nothing follows the key indicator: the key is an empty scalar.
            self.state = self.parse_block_mapping_value
            return self.process_empty_scalar(key_token.end_mark)
        self.states.append(self.parse_block_mapping_value)
        return self.parse_block_node_or_indentless_sequence()

    if not self.check_token(BlockEndToken):
        unexpected = self.peek_token()
        raise ParserError("while parsing a block mapping", self.marks[-1],
                "expected <block end>, but found %r" % unexpected.id,
                unexpected.start_mark)

    closer = self.get_token()
    end_event = MappingEndEvent(closer.start_mark, closer.end_mark)
    self.state = self.states.pop()
    self.marks.pop()
    return end_event

def parse_block_mapping_value(self):
    """Parse the value that belongs to the key just produced.

    A missing VALUE token, or one with nothing after it, yields an empty
    scalar event.
    """
    if not self.check_token(ValueToken):
        self.state = self.parse_block_mapping_key
        upcoming = self.peek_token()
        return self.process_empty_scalar(upcoming.start_mark)

    value_token = self.get_token()
    if self.check_token(KeyToken, ValueToken, BlockEndToken):
        self.state = self.parse_block_mapping_key
        return self.process_empty_scalar(value_token.end_mark)
    self.states.append(self.parse_block_mapping_key)
    return self.parse_block_node_or_indentless_sequence()

# flow_sequence     ::= FLOW-SEQUENCE-START
#                       (flow_sequence_entry FLOW-ENTRY)*
#                       flow_sequence_entry?
#                       FLOW-SEQUENCE-END
# flow_sequence_entry   ::= flow_node | KEY flow_node? (VALUE flow_node?)?
#
# Note that while production rules for both flow_sequence_entry and
# flow_mapping_entry are equal, their interpretations are different.
# For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?`
# generate an inline mapping (set syntax).

def parse_flow_sequence_first_entry(self):
    """Consume the flow-sequence start token, record its mark, parse an entry."""
    opener = self.get_token()
    self.marks.append(opener.start_mark)
    return self.parse_flow_sequence_entry(first=True)

def parse_flow_sequence_entry(self, first=False):
    """Parse the next flow sequence entry, or close the sequence.

    Every entry except the first must be preceded by a FLOW-ENTRY token;
    otherwise ParserError is raised.  A KEY token inside the sequence opens
    an inline single-pair mapping.
    """
    if not self.check_token(FlowSequenceEndToken):
        if not first:
            if not self.check_token(FlowEntryToken):
                offender = self.peek_token()
                raise ParserError("while parsing a flow sequence", self.marks[-1],
                        "expected ',' or ']', but got %r" % offender.id,
                        offender.start_mark)
            self.get_token()

        if self.check_token(KeyToken):
            # The KEY token is only peeked here; the mapping-key state
            # consumes it.
            key_token = self.peek_token()
            start_event = MappingStartEvent(None, None, True,
                    key_token.start_mark, key_token.end_mark,
                    flow_style=True)
            self.state = self.parse_flow_sequence_entry_mapping_key
            return start_event
        if not self.check_token(FlowSequenceEndToken):
            self.states.append(self.parse_flow_sequence_entry)
            return self.parse_flow_node()

    closer = self.get_token()
    end_event = SequenceEndEvent(closer.start_mark, closer.end_mark)
    self.state = self.states.pop()
    self.marks.pop()
    return end_event

def parse_flow_sequence_entry_mapping_key(self):
    """Consume the KEY token of an inline mapping and parse its key."""
    key_token = self.get_token()
    if self.check_token(ValueToken,
            FlowEntryToken, FlowSequenceEndToken):
        self.state = self.parse_flow_sequence_entry_mapping_value
        return self.process_empty_scalar(key_token.end_mark)
    self.states.append(self.parse_flow_sequence_entry_mapping_value)
    return self.parse_flow_node()

def parse_flow_sequence_entry_mapping_value(self):
    """Parse the value of an inline mapping (empty scalar when absent)."""
    if not self.check_token(ValueToken):
        self.state = self.parse_flow_sequence_entry_mapping_end
        upcoming = self.peek_token()
        return self.process_empty_scalar(upcoming.start_mark)

    value_token = self.get_token()
    if self.check_token(FlowEntryToken, FlowSequenceEndToken):
        self.state = self.parse_flow_sequence_entry_mapping_end
        return self.process_empty_scalar(value_token.end_mark)
    self.states.append(self.parse_flow_sequence_entry_mapping_end)
    return self.parse_flow_node()

def parse_flow_sequence_entry_mapping_end(self):
    """Emit the zero-width end event of an inline mapping."""
    self.state = self.parse_flow_sequence_entry
    upcoming = self.peek_token()
    return MappingEndEvent(upcoming.start_mark, upcoming.start_mark)

# flow_mapping  ::= FLOW-MAPPING-START
#                   (flow_mapping_entry FLOW-ENTRY)*
#                   flow_mapping_entry?
#                   FLOW-MAPPING-END
# flow_mapping_entry    ::= flow_node | KEY flow_node? (VALUE flow_node?)?

def parse_flow_mapping_first_key(self):
    """Consume the flow-mapping start token, record its mark, parse a key."""
    opener = self.get_token()
    self.marks.append(opener.start_mark)
    return self.parse_flow_mapping_key(first=True)

def parse_flow_mapping_key(self, first=False):
    """Parse the next flow mapping key, or close the mapping.

    Every entry except the first must be preceded by a FLOW-ENTRY token;
    otherwise ParserError is raised.  An entry without a KEY token is parsed
    as a key whose value is an empty scalar.
    """
    if not self.check_token(FlowMappingEndToken):
        if not first:
            if not self.check_token(FlowEntryToken):
                offender = self.peek_token()
                raise ParserError("while parsing a flow mapping", self.marks[-1],
                        "expected ',' or '}', but got %r" % offender.id,
                        offender.start_mark)
            self.get_token()
        if self.check_token(KeyToken):
            key_token = self.get_token()
            if self.check_token(ValueToken,
                    FlowEntryToken, FlowMappingEndToken):
                self.state = self.parse_flow_mapping_value
                return self.process_empty_scalar(key_token.end_mark)
            self.states.append(self.parse_flow_mapping_value)
            return self.parse_flow_node()
        if not self.check_token(FlowMappingEndToken):
            self.states.append(self.parse_flow_mapping_empty_value)
            return self.parse_flow_node()

    closer = self.get_token()
    end_event = MappingEndEvent(closer.start_mark, closer.end_mark)
    self.state = self.states.pop()
    self.marks.pop()
    return end_event

def parse_flow_mapping_value(self):
    """Parse the value of the current flow mapping key (empty when absent)."""
    if not self.check_token(ValueToken):
        self.state = self.parse_flow_mapping_key
        upcoming = self.peek_token()
        return self.process_empty_scalar(upcoming.start_mark)

    value_token = self.get_token()
    if self.check_token(FlowEntryToken, FlowMappingEndToken):
        self.state = self.parse_flow_mapping_key
        return self.process_empty_scalar(value_token.end_mark)
    self.states.append(self.parse_flow_mapping_key)
    return self.parse_flow_node()

def parse_flow_mapping_empty_value(self):
    """Supply the empty value for a flow mapping entry that had no KEY token."""
    self.state = self.parse_flow_mapping_key
    return self.process_empty_scalar(self.peek_token().start_mark)

def process_empty_scalar(self, mark):
    """Build a zero-width, empty-string scalar event located at `mark`."""
    return ScalarEvent(None, None, (True, False), '', mark, mark)

# diff --git a/libs/yaml/reader.py b/libs/yaml/reader.py new file mode 100644 index 000000000..774b0219b --- /dev/null +++ b/libs/yaml/reader.py @@ -0,0 +1,185 @@
# This module contains abstractions for the input stream. You don't have to
# look any further; there is no pretty code here.
#
# We define two classes here.
#
#   Mark(source, line, column)
# It's just a record and its only use is producing nice error messages.
# Parser does not use it for any other purposes.
#
#   Reader(source, data)
# Reader determines the encoding of `data` and converts it to unicode.
# Reader provides the following methods and attributes:
#   reader.peek(index=0) - return the character `index` positions ahead.
#   reader.prefix(length=1) - return the next `length` characters.
#   reader.forward(length=1) - move the current position forward by `length` characters.
#   reader.index - the number of the current character.
#   reader.line, reader.column - the line and the column of the current character.

__all__ = ['Reader', 'ReaderError']

from .error import YAMLError, Mark

import codecs, re

class ReaderError(YAMLError):
    """Raised when the input cannot be decoded or contains a disallowed character."""

    def __init__(self, name, position, character, encoding, reason):
        # `character` is either a `bytes` object (see __str__, which calls
        # ord() on it) or an integer code point.
        self.name = name
        self.character = character
        self.position = position
        self.encoding = encoding
        self.reason = reason

    def __str__(self):
        """Format the error as a decoding failure or an unacceptable character."""
        if isinstance(self.character, bytes):
            return "'%s' codec can't decode byte #x%02x: %s\n" \
                    " in \"%s\", position %d" \
                    % (self.encoding, ord(self.character), self.reason,
                            self.name, self.position)
        else:
            return "unacceptable character #x%04x: %s\n" \
                    " in \"%s\", position %d" \
                    % (self.character, self.reason,
                            self.name, self.position)

class Reader(object):
    # Reader:
    # - determines the data encoding and converts it to a unicode string,
    # - checks if characters are in allowed range,
    # - adds '\0' to the end.

    # Reader accepts
    #  - a `bytes` object,
    #  - a `str` object,
    #  - a file-like object with its `read` method returning `bytes`,
    #  - a file-like object with its `read` method returning `str`.

    # Yeah, it's ugly and slow.

    def __init__(self, stream):
        """Set up buffers for `stream` (a `str`, `bytes`, or object with `read`)."""
        self.name = None
        self.stream = None
        self.stream_pointer = 0     # number of items read from `stream` so far
        self.eof = True
        self.buffer = ''            # decoded, validated text
        self.pointer = 0            # current position inside `buffer`
        self.raw_buffer = None      # data not yet moved into `buffer`
        self.raw_decode = None      # codecs decoder, or None for `str` input
        self.encoding = None
        self.index = 0
        self.line = 0
        self.column = 0
        if isinstance(stream, str):
            self.name = "<unicode string>"
            self.check_printable(stream)
            self.buffer = stream+'\0'
        elif isinstance(stream, bytes):
            self.name = "<byte string>"
            self.raw_buffer = stream
            self.determine_encoding()
        else:
            self.stream = stream
            self.name = getattr(stream, 'name', "<file>")
            self.eof = False
            self.raw_buffer = None
            self.determine_encoding()

    def peek(self, index=0):
        """Return the character `index` positions ahead without consuming it."""
        try:
            return self.buffer[self.pointer+index]
        except IndexError:
            # Not buffered yet: pull in more data and retry once.
            self.update(index+1)
            return self.buffer[self.pointer+index]

    def prefix(self, length=1):
        """Return up to `length` upcoming characters without consuming them."""
        if self.pointer+length >= len(self.buffer):
            self.update(length)
        return self.buffer[self.pointer:self.pointer+length]

    def forward(self, length=1):
        """Consume `length` characters, updating index, line and column."""
        # One extra character is buffered so the '\r\n' look-ahead below
        # never indexes past the end of the buffer.
        if self.pointer+length+1 >= len(self.buffer):
            self.update(length+1)
        while length:
            ch = self.buffer[self.pointer]
            self.pointer += 1
            self.index += 1
            # A '\r' directly followed by '\n' is not counted as a line
            # break by itself; the '\n' that follows is.
            if ch in '\n\x85\u2028\u2029' \
                    or (ch == '\r' and self.buffer[self.pointer] != '\n'):
                self.line += 1
                self.column = 0
            elif ch != '\uFEFF':
                # A byte order mark does not advance the column.
                self.column += 1
            length -= 1

    def get_mark(self):
        """Return a Mark for the current position.

        The buffer and pointer are attached only when the input was given
        as a string or bytes object (no underlying stream).
        """
        if self.stream is None:
            return Mark(self.name, self.index, self.line, self.column,
                    self.buffer, self.pointer)
        else:
            return Mark(self.name, self.index, self.line, self.column,
                    None, None)

    def determine_encoding(self):
        """Choose the decoder from a leading BOM, defaulting to UTF-8.

        For input that is not `bytes`, `raw_decode` stays None and the data
        is used as-is by `update`.
        """
        while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2):
            self.update_raw()
        if isinstance(self.raw_buffer, bytes):
            if self.raw_buffer.startswith(codecs.BOM_UTF16_LE):
                self.raw_decode = codecs.utf_16_le_decode
                self.encoding = 'utf-16-le'
            elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE):
                self.raw_decode = codecs.utf_16_be_decode
                self.encoding = 'utf-16-be'
            else:
                self.raw_decode = codecs.utf_8_decode
                self.encoding = 'utf-8'
        self.update(1)

    # Matches any single character outside the ranges the reader accepts.
    NON_PRINTABLE = re.compile('[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]')
    def check_printable(self, data):
        """Raise ReaderError if `data` contains a character matched by NON_PRINTABLE."""
        match = self.NON_PRINTABLE.search(data)
        if match:
            character = match.group()
            position = self.index+(len(self.buffer)-self.pointer)+match.start()
            raise ReaderError(self.name, position, ord(character),
                    'unicode', "special characters are not allowed")

    def update(self, length):
        """Ensure at least `length` characters are buffered, or EOF is reached.

        Raises ReaderError when decoding fails or a disallowed character is
        found.
        """
        if self.raw_buffer is None:
            # Everything has already been moved into `buffer`.
            return
        # Drop the already-consumed prefix so the buffer does not grow forever.
        self.buffer = self.buffer[self.pointer:]
        self.pointer = 0
        while len(self.buffer) < length:
            if not self.eof:
                self.update_raw()
            if self.raw_decode is not None:
                try:
                    data, converted = self.raw_decode(self.raw_buffer,
                            'strict', self.eof)
                except UnicodeDecodeError as exc:
                    # NOTE(review): indexing a `bytes` object yields an int
                    # on Python 3, so ReaderError.__str__ takes its
                    # "unacceptable character" branch for this error --
                    # confirm whether the "can't decode byte" message was
                    # intended.
                    character = self.raw_buffer[exc.start]
                    if self.stream is not None:
                        position = self.stream_pointer-len(self.raw_buffer)+exc.start
                    else:
                        position = exc.start
                    raise ReaderError(self.name, position, character,
                            exc.encoding, exc.reason)
            else:
                data = self.raw_buffer
                converted = len(data)
            self.check_printable(data)
            self.buffer += data
            self.raw_buffer = self.raw_buffer[converted:]
            if self.eof:
                # Terminate the text with NUL and mark the raw side as drained.
                self.buffer += '\0'
                self.raw_buffer = None
                break

    def update_raw(self, size=4096):
        """Read up to `size` more items from the stream into `raw_buffer`."""
        data = self.stream.read(size)
        if self.raw_buffer is None:
            self.raw_buffer = data
        else:
            self.raw_buffer += data
        self.stream_pointer += len(data)
        if not data:
            self.eof = True

# diff --git a/libs/yaml/representer.py b/libs/yaml/representer.py new file mode 100644 index 000000000..dd144017b --- /dev/null +++ b/libs/yaml/representer.py @@ -0,0 +1,389 @@

__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer',
    'RepresenterError']

from .error import *
from .nodes import *

import datetime, sys, copyreg, types, base64, collections

class RepresenterError(YAMLError):
    pass


class BaseRepresenter:
    """Turn native Python objects into a graph of YAML nodes."""

    # Class-level registries: exact type -> representer, and base type ->
    # representer (looked up along the MRO).  A `None` key is the fallback.
    yaml_representers = {}
    yaml_multi_representers = {}

    def __init__(self, default_style=None, default_flow_style=False, sort_keys=True):
        self.default_style = default_style
        self.sort_keys = sort_keys
        self.default_flow_style = default_flow_style
        self.represented_objects = {}
        self.object_keeper = []
        self.alias_key = None

    def represent(self, data):
        """Represent `data`, serialize the node, then reset per-document state."""
        node = self.represent_data(data)
        self.serialize(node)
        self.represented_objects = {}
        self.object_keeper = []
        self.alias_key = None

    def represent_data(self, data):
        """Return the node for `data`, reusing the node of an already seen object."""
        if self.ignore_aliases(data):
            self.alias_key = None
        else:
            self.alias_key = id(data)
        if self.alias_key is not None:
            if self.alias_key in self.represented_objects:
                node = self.represented_objects[self.alias_key]
                #if node is None:
                #    raise RepresenterError("recursive objects are not allowed: %r" % data)
                return node
            #self.represented_objects[alias_key] = None
            # Keep a reference so id(data) stays unique while representing.
            self.object_keeper.append(data)
        data_types = type(data).__mro__
        if data_types[0] in self.yaml_representers:
            node = self.yaml_representers[data_types[0]](self, data)
        else:
            for data_type in data_types:
                if data_type in self.yaml_multi_representers:
                    node = self.yaml_multi_representers[data_type](self, data)
                    break
            else:
                if None in self.yaml_multi_representers:
                    node = self.yaml_multi_representers[None](self, data)
                elif None in self.yaml_representers:
                    node = self.yaml_representers[None](self, data)
                else:
                    node = ScalarNode(None, str(data))
        #if alias_key is not None:
        #    self.represented_objects[alias_key] = node
        return node

    @classmethod
    def add_representer(cls, data_type, representer):
        """Register `representer` for exactly `data_type` on this class."""
        # Copy on first write so a subclass does not mutate its parent's table.
        if 'yaml_representers' not in cls.__dict__:
            cls.yaml_representers = cls.yaml_representers.copy()
        cls.yaml_representers[data_type] = representer

    @classmethod
    def add_multi_representer(cls, data_type, representer):
        """Register `representer` for `data_type` and its subclasses."""
        if 'yaml_multi_representers' not in cls.__dict__:
            cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
        cls.yaml_multi_representers[data_type] = representer

    def represent_scalar(self, tag, value, style=None):
        """Create a ScalarNode, recording it for alias detection."""
        if style is None:
            style = self.default_style
        node = ScalarNode(tag, value, style=style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        return node

    def represent_sequence(self, tag, sequence, flow_style=None):
        """Create a SequenceNode whose items are represented recursively."""
        value = []
        node = SequenceNode(tag, value, flow_style=flow_style)
        # Registered before the items are represented, so a reference back
        # to this object finds the node.
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        best_style = True
        for item in sequence:
            node_item = self.represent_data(item)
            if not (isinstance(node_item, ScalarNode) and not node_item.style):
                best_style = False
            value.append(node_item)
        if flow_style is None:
            if self.default_flow_style is not None:
                node.flow_style = self.default_flow_style
            else:
                node.flow_style = best_style
        return node

    def represent_mapping(self, tag, mapping, flow_style=None):
        """Create a MappingNode from a mapping or an iterable of pairs."""
        value = []
        node = MappingNode(tag, value, flow_style=flow_style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        best_style = True
        if hasattr(mapping, 'items'):
            mapping = list(mapping.items())
            if self.sort_keys:
                try:
                    mapping = sorted(mapping)
                except TypeError:
                    # Unorderable keys: keep the original order.
                    pass
        for item_key, item_value in mapping:
            node_key = self.represent_data(item_key)
            node_value = self.represent_data(item_value)
            if not (isinstance(node_key, ScalarNode) and not node_key.style):
                best_style = False
            if not (isinstance(node_value, ScalarNode) and not node_value.style):
                best_style = False
            value.append((node_key, node_value))
        if flow_style is None:
            if self.default_flow_style is not None:
                node.flow_style = self.default_flow_style
            else:
                node.flow_style = best_style
        return node

    def ignore_aliases(self, data):
        """Return True when `data` must never be emitted as an alias."""
        return False

class SafeRepresenter(BaseRepresenter):
    """Representer limited to standard YAML tags for basic Python types."""

    def ignore_aliases(self, data):
        # Falls through (returning None, i.e. falsy) for everything else.
        if data is None:
            return True
        if isinstance(data, tuple) and data == ():
            return True
        if isinstance(data, (str, bytes, bool, int, float)):
            return True

    def represent_none(self, data):
        return self.represent_scalar('tag:yaml.org,2002:null', 'null')

    def represent_str(self, data):
        return self.represent_scalar('tag:yaml.org,2002:str', data)

    def represent_binary(self, data):
        """Represent `bytes` as a base64 literal block scalar."""
        if hasattr(base64, 'encodebytes'):
            data = base64.encodebytes(data).decode('ascii')
        else:
            data = base64.encodestring(data).decode('ascii')
        return self.represent_scalar('tag:yaml.org,2002:binary', data, style='|')

    def represent_bool(self, data):
        if data:
            value = 'true'
        else:
            value = 'false'
        return self.represent_scalar('tag:yaml.org,2002:bool', value)

    def represent_int(self, data):
        return self.represent_scalar('tag:yaml.org,2002:int', str(data))

    # Square 1e300 until its repr stops changing, i.e. it has overflowed to
    # the float infinity.
    inf_value = 1e300
    while repr(inf_value) != repr(inf_value*inf_value):
        inf_value *= inf_value

    def represent_float(self, data):
        """Represent a float, mapping NaN and the infinities to YAML spellings."""
        # NOTE(review): the second clause cannot hold for a well-behaved
        # float; presumably it guards against broken NaN comparisons --
        # confirm before simplifying.
        if data != data or (data == 0.0 and data == 1.0):
            value = '.nan'
        elif data == self.inf_value:
            value = '.inf'
        elif data == -self.inf_value:
            value = '-.inf'
        else:
            value = repr(data).lower()
            # Note that in some cases `repr(data)` represents a float number
            # without the decimal parts.  For instance:
            #   >>> repr(1e17)
            #   '1e17'
            # Unfortunately, this is not a valid float representation according
            # to the definition of the `!!float` tag.  We fix this by adding
            # '.0' before the 'e' symbol.
            if '.' not in value and 'e' in value:
                value = value.replace('e', '.0e', 1)
        return self.represent_scalar('tag:yaml.org,2002:float', value)

    def represent_list(self, data):
        #pairs = (len(data) > 0 and isinstance(data, list))
        #if pairs:
        #    for item in data:
        #        if not isinstance(item, tuple) or len(item) != 2:
        #            pairs = False
        #            break
        #if not pairs:
            return self.represent_sequence('tag:yaml.org,2002:seq', data)
        #value = []
        #for item_key, item_value in data:
        #    value.append(self.represent_mapping(u'tag:yaml.org,2002:map',
        #        [(item_key, item_value)]))
        #return SequenceNode(u'tag:yaml.org,2002:pairs', value)

    def represent_dict(self, data):
        return self.represent_mapping('tag:yaml.org,2002:map', data)

    def represent_set(self, data):
        """Represent a set as a mapping whose values are all None."""
        value = {}
        for key in data:
            value[key] = None
        return self.represent_mapping('tag:yaml.org,2002:set', value)

    def represent_date(self, data):
        value = data.isoformat()
        return self.represent_scalar('tag:yaml.org,2002:timestamp', value)

    def represent_datetime(self, data):
        value = data.isoformat(' ')
        return self.represent_scalar('tag:yaml.org,2002:timestamp', value)

    def represent_yaml_object(self, tag, data, cls, flow_style=None):
        """Represent `data` as a mapping of its `__getstate__()` or `__dict__`."""
        if hasattr(data, '__getstate__'):
            state = data.__getstate__()
        else:
            state = data.__dict__.copy()
        return self.represent_mapping(tag, state, flow_style=flow_style)

    def represent_undefined(self, data):
        raise RepresenterError("cannot represent an object", data)

SafeRepresenter.add_representer(type(None),
        SafeRepresenter.represent_none)

SafeRepresenter.add_representer(str,
        SafeRepresenter.represent_str)

SafeRepresenter.add_representer(bytes,
        SafeRepresenter.represent_binary)

SafeRepresenter.add_representer(bool,
        SafeRepresenter.represent_bool)

SafeRepresenter.add_representer(int,
        SafeRepresenter.represent_int)

SafeRepresenter.add_representer(float,
        SafeRepresenter.represent_float)

SafeRepresenter.add_representer(list,
        SafeRepresenter.represent_list)

SafeRepresenter.add_representer(tuple,
        SafeRepresenter.represent_list)

SafeRepresenter.add_representer(dict,
        SafeRepresenter.represent_dict)

SafeRepresenter.add_representer(set,
        SafeRepresenter.represent_set)

SafeRepresenter.add_representer(datetime.date,
        SafeRepresenter.represent_date)

SafeRepresenter.add_representer(datetime.datetime,
        SafeRepresenter.represent_datetime)

SafeRepresenter.add_representer(None,
        SafeRepresenter.represent_undefined)

class Representer(SafeRepresenter):
    """Representer that can also emit Python-specific `python/...` tags."""

    def represent_complex(self, data):
        if data.imag == 0.0:
            data = '%r' % data.real
        elif data.real == 0.0:
            data = '%rj' % data.imag
        elif data.imag > 0:
            data = '%r+%rj' % (data.real, data.imag)
        else:
            data = '%r%rj' % (data.real, data.imag)
        return self.represent_scalar('tag:yaml.org,2002:python/complex', data)

    def represent_tuple(self, data):
        return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)

    def represent_name(self, data):
        name = '%s.%s' % (data.__module__, data.__name__)
        return self.represent_scalar('tag:yaml.org,2002:python/name:'+name, '')

    def represent_module(self, data):
        return self.represent_scalar(
                'tag:yaml.org,2002:python/module:'+data.__name__, '')

    def represent_object(self, data):
        # We use __reduce__ API to save the data. data.__reduce__ returns
        # a tuple of length 2-5:
        #   (function, args, state, listitems, dictitems)

        # For reconstructing, we calls function(*args), then set its state,
        # listitems, and dictitems if they are not None.

        # A special case is when function.__name__ == '__newobj__'. In this
        # case we create the object with args[0].__new__(*args).

        # Another special case is when __reduce__ returns a string - we don't
        # support it.

        # We produce a !!python/object, !!python/object/new or
        # !!python/object/apply node.

        cls = type(data)
        if cls in copyreg.dispatch_table:
            reduce = copyreg.dispatch_table[cls](data)
        elif hasattr(data, '__reduce_ex__'):
            reduce = data.__reduce_ex__(2)
        elif hasattr(data, '__reduce__'):
            reduce = data.__reduce__()
        else:
            raise RepresenterError("cannot represent an object", data)
        # Pad the reduce tuple to exactly five fields.
        reduce = (list(reduce)+[None]*5)[:5]
        function, args, state, listitems, dictitems = reduce
        args = list(args)
        if state is None:
            state = {}
        if listitems is not None:
            listitems = list(listitems)
        if dictitems is not None:
            dictitems = dict(dictitems)
        if function.__name__ == '__newobj__':
            function = args[0]
            args = args[1:]
            tag = 'tag:yaml.org,2002:python/object/new:'
            newobj = True
        else:
            tag = 'tag:yaml.org,2002:python/object/apply:'
            newobj = False
        function_name = '%s.%s' % (function.__module__, function.__name__)
        if not args and not listitems and not dictitems \
                and isinstance(state, dict) and newobj:
            return self.represent_mapping(
                    'tag:yaml.org,2002:python/object:'+function_name, state)
        if not listitems and not dictitems  \
                and isinstance(state, dict) and not state:
            return self.represent_sequence(tag+function_name, args)
        value = {}
        if args:
            value['args'] = args
        if state or not isinstance(state, dict):
            value['state'] = state
        if listitems:
            value['listitems'] = listitems
        if dictitems:
            value['dictitems'] = dictitems
        return self.represent_mapping(tag+function_name, value)

    def represent_ordered_dict(self, data):
        # Provide uniform representation across different Python versions.
        data_type = type(data)
        tag = 'tag:yaml.org,2002:python/object/apply:%s.%s' \
                % (data_type.__module__, data_type.__name__)
        items = [[key, value] for key, value in data.items()]
        return self.represent_sequence(tag, [items])

Representer.add_representer(complex,
        Representer.represent_complex)

Representer.add_representer(tuple,
        Representer.represent_tuple)

Representer.add_representer(type,
        Representer.represent_name)

Representer.add_representer(collections.OrderedDict,
        Representer.represent_ordered_dict)

Representer.add_representer(types.FunctionType,
        Representer.represent_name)

Representer.add_representer(types.BuiltinFunctionType,
        Representer.represent_name)

Representer.add_representer(types.ModuleType,
        Representer.represent_module)

Representer.add_multi_representer(object,
        Representer.represent_object)

# diff --git a/libs/yaml/resolver.py b/libs/yaml/resolver.py new file mode 100644 index 000000000..02b82e73e --- /dev/null +++ b/libs/yaml/resolver.py @@ -0,0 +1,227 @@

__all__ = ['BaseResolver', 'Resolver']

from .error import *
from .nodes import *

import re

class ResolverError(YAMLError):
    pass

class BaseResolver:
    """Pick the tag of a node from implicit (regexp) and path resolvers."""

    DEFAULT_SCALAR_TAG = 'tag:yaml.org,2002:str'
    DEFAULT_SEQUENCE_TAG = 'tag:yaml.org,2002:seq'
    DEFAULT_MAPPING_TAG = 'tag:yaml.org,2002:map'

    # first character (or '' / None) -> list of (tag, regexp)
    yaml_implicit_resolvers = {}
    # (path tuple, kind) -> tag
    yaml_path_resolvers = {}

    def __init__(self):
        self.resolver_exact_paths = []
        self.resolver_prefix_paths = []

    @classmethod
    def add_implicit_resolver(cls, tag, regexp, first):
        """Register `regexp` -> `tag` for scalars starting with a char in `first`.

        `first=None` registers the resolver under the `None` key, which
        `resolve` consults for every scalar.
        """
        # Copy on first write (including the per-key lists) so a subclass
        # does not mutate its parent's table.
        if 'yaml_implicit_resolvers' not in cls.__dict__:
            implicit_resolvers = {}
            for key in cls.yaml_implicit_resolvers:
                implicit_resolvers[key] = cls.yaml_implicit_resolvers[key][:]
            cls.yaml_implicit_resolvers = implicit_resolvers
        if first is None:
            first = [None]
        for ch in first:
            cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))

    @classmethod
    def add_path_resolver(cls, tag, path, kind=None):
        # Note: `add_path_resolver` is experimental.  The API could be changed.
        # `new_path` is a pattern that is matched against the path from the
        # root to the node that is being considered.  `node_path` elements are
        # tuples `(node_check, index_check)`.  `node_check` is a node class:
        # `ScalarNode`, `SequenceNode`, `MappingNode` or `None`.  `None`
        # matches any kind of a node.  `index_check` could be `None`, a boolean
        # value, a string value, or a number.  `None` and `False` match against
        # any _value_ of sequence and mapping nodes.  `True` matches against
        # any _key_ of a mapping node.  A string `index_check` matches against
        # a mapping value that corresponds to a scalar key which content is
        # equal to the `index_check` value.  An integer `index_check` matches
        # against a sequence value with the index equal to `index_check`.
        if 'yaml_path_resolvers' not in cls.__dict__:
            cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy()
        new_path = []
        for element in path:
            if isinstance(element, (list, tuple)):
                if len(element) == 2:
                    node_check, index_check = element
                elif len(element) == 1:
                    node_check = element[0]
                    index_check = True
                else:
                    raise ResolverError("Invalid path element: %s" % element)
            else:
                node_check = None
                index_check = element
            if node_check is str:
                node_check = ScalarNode
            elif node_check is list:
                node_check = SequenceNode
            elif node_check is dict:
                node_check = MappingNode
            elif node_check not in [ScalarNode, SequenceNode, MappingNode]  \
                    and not isinstance(node_check, str) \
                    and node_check is not None:
                raise ResolverError("Invalid node checker: %s" % node_check)
            if not isinstance(index_check, (str, int))  \
                    and index_check is not None:
                raise ResolverError("Invalid index checker: %s" % index_check)
            new_path.append((node_check, index_check))
        if kind is str:
            kind = ScalarNode
        elif kind is list:
            kind = SequenceNode
        elif kind is dict:
            kind = MappingNode
        elif kind not in [ScalarNode, SequenceNode, MappingNode]    \
                and kind is not None:
            raise ResolverError("Invalid node kind: %s" % kind)
        cls.yaml_path_resolvers[tuple(new_path), kind] = tag

    def descend_resolver(self, current_node, current_index):
        """Push the path-resolver state for descending into a child node."""
        if not self.yaml_path_resolvers:
            return
        exact_paths = {}
        prefix_paths = []
        if current_node:
            depth = len(self.resolver_prefix_paths)
            for path, kind in self.resolver_prefix_paths[-1]:
                if self.check_resolver_prefix(depth, path, kind,
                        current_node, current_index):
                    if len(path) > depth:
                        prefix_paths.append((path, kind))
                    else:
                        exact_paths[kind] = self.yaml_path_resolvers[path, kind]
        else:
            for path, kind in self.yaml_path_resolvers:
                if not path:
                    exact_paths[kind] = self.yaml_path_resolvers[path, kind]
                else:
                    prefix_paths.append((path, kind))
        self.resolver_exact_paths.append(exact_paths)
        self.resolver_prefix_paths.append(prefix_paths)

    def ascend_resolver(self):
        """Pop the state pushed by the matching `descend_resolver` call."""
        if not self.yaml_path_resolvers:
            return
        self.resolver_exact_paths.pop()
        self.resolver_prefix_paths.pop()

    def check_resolver_prefix(self, depth, path, kind,
            current_node, current_index):
        """Return True when `path[depth-1]` matches the current node and index.

        Returns None (falsy) on any mismatch.
        """
        node_check, index_check = path[depth-1]
        if isinstance(node_check, str):
            if current_node.tag != node_check:
                return
        elif node_check is not None:
            if not isinstance(current_node, node_check):
                return
        if index_check is True and current_index is not None:
            return
        if (index_check is False or index_check is None)    \
                and current_index is None:
            return
        if isinstance(index_check, str):
            if not (isinstance(current_index, ScalarNode)
                    and index_check == current_index.value):
                return
        elif isinstance(index_check, int) and not isinstance(index_check, bool):
            if index_check != current_index:
                return
        return True

    def resolve(self, kind, value, implicit):
        """Return the tag for a node of `kind` with content `value`.

        For implicit scalars the regexps registered for the first character
        of `value` (or for '' when it is empty) are tried first, followed by
        the wildcard resolvers registered under `None`.  Path resolvers and
        finally the per-kind default tags are the fallback.
        """
        if kind is ScalarNode and implicit[0]:
            if value == '':
                resolvers = self.yaml_implicit_resolvers.get('', [])
            else:
                resolvers = self.yaml_implicit_resolvers.get(value[0], [])
            # Build a new list here: `resolvers` may be the very list stored
            # in the class-level table, and extending it in place with `+=`
            # would append the wildcard resolvers to it again on every call.
            wildcard_resolvers = self.yaml_implicit_resolvers.get(None, [])
            for tag, regexp in resolvers + wildcard_resolvers:
                if regexp.match(value):
                    return tag
            implicit = implicit[1]
        if self.yaml_path_resolvers:
            exact_paths = self.resolver_exact_paths[-1]
            if kind in exact_paths:
                return exact_paths[kind]
            if None in exact_paths:
                return exact_paths[None]
        if kind is ScalarNode:
            return self.DEFAULT_SCALAR_TAG
        elif kind is SequenceNode:
            return self.DEFAULT_SEQUENCE_TAG
        elif kind is MappingNode:
            return self.DEFAULT_MAPPING_TAG

class Resolver(BaseResolver):
    pass

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:bool',
        re.compile(r'''^(?:yes|Yes|YES|no|No|NO
                    |true|True|TRUE|false|False|FALSE
                    |on|On|ON|off|Off|OFF)$''', re.X),
        list('yYnNtTfFoO'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:float',
        re.compile(r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)?
                    |\.[0-9_]+(?:[eE][-+][0-9]+)?
                    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
                    |[-+]?\.(?:inf|Inf|INF)
                    |\.(?:nan|NaN|NAN))$''', re.X),
        list('-+0123456789.'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:int',
        re.compile(r'''^(?:[-+]?0b[0-1_]+
                    |[-+]?0[0-7_]+
                    |[-+]?(?:0|[1-9][0-9_]*)
                    |[-+]?0x[0-9a-fA-F_]+
                    |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X),
        list('-+0123456789'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:merge',
        re.compile(r'^(?:<<)$'),
        ['<'])

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:null',
        re.compile(r'''^(?: ~
                    |null|Null|NULL
                    | )$''', re.X),
        ['~', 'n', 'N', ''])

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:timestamp',
        re.compile(r'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]
                    |[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]?
                     (?:[Tt]|[ \t]+)[0-9][0-9]?
                     :[0-9][0-9] :[0-9][0-9] (?:\.[0-9]*)?
                     (?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X),
        list('0123456789'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:value',
        re.compile(r'^(?:=)$'),
        ['='])

# The following resolver is only for documentation purposes. It cannot work
# because plain scalars cannot start with '!', '&', or '*'.
Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:yaml',
        re.compile(r'^(?:!|&|\*)$'),
        list('!&*'))

# diff --git a/libs/yaml/scanner.py b/libs/yaml/scanner.py new file mode 100644 index 000000000..775dbcc6d --- /dev/null +++ b/libs/yaml/scanner.py @@ -0,0 +1,1435 @@

# Scanner produces tokens of the following types:
# STREAM-START
# STREAM-END
# DIRECTIVE(name, value)
# DOCUMENT-START
# DOCUMENT-END
# BLOCK-SEQUENCE-START
# BLOCK-MAPPING-START
# BLOCK-END
# FLOW-SEQUENCE-START
# FLOW-MAPPING-START
# FLOW-SEQUENCE-END
# FLOW-MAPPING-END
# BLOCK-ENTRY
# FLOW-ENTRY
# KEY
# VALUE
# ALIAS(value)
# ANCHOR(value)
# TAG(value)
# SCALAR(value, plain, style)
#
# Read comments in the Scanner code for more details.
#

__all__ = ['Scanner', 'ScannerError']

from .error import MarkedYAMLError
from .tokens import *

class ScannerError(MarkedYAMLError):
    pass

class SimpleKey:
    # See below simple keys treatment.

    def __init__(self, token_number, required, index, line, column, mark):
        self.token_number = token_number
        self.required = required
        self.index = index
        self.line = line
        self.column = column
        self.mark = mark

class Scanner:

    def __init__(self):
        """Initialize the scanner."""
        # It is assumed that Scanner and Reader will have a common descendant.
        # Reader do the dirty work of checking for BOM and converting the
        # input data to Unicode. It also adds NUL to the end.
        #
        # Reader supports the following methods
        #   self.peek(i=0)       # peek the next i-th character
        #   self.prefix(l=1)     # peek the next l characters
        #   self.forward(l=1)    # read the next l characters and move the pointer.
+ + # Had we reached the end of the stream? + self.done = False + + # The number of unclosed '{' and '['. `flow_level == 0` means block + # context. + self.flow_level = 0 + + # List of processed tokens that are not yet emitted. + self.tokens = [] + + # Add the STREAM-START token. + self.fetch_stream_start() + + # Number of tokens that were emitted through the `get_token` method. + self.tokens_taken = 0 + + # The current indentation level. + self.indent = -1 + + # Past indentation levels. + self.indents = [] + + # Variables related to simple keys treatment. + + # A simple key is a key that is not denoted by the '?' indicator. + # Example of simple keys: + # --- + # block simple key: value + # ? not a simple key: + # : { flow simple key: value } + # We emit the KEY token before all keys, so when we find a potential + # simple key, we try to locate the corresponding ':' indicator. + # Simple keys should be limited to a single line and 1024 characters. + + # Can a simple key start at the current position? A simple key may + # start: + # - at the beginning of the line, not counting indentation spaces + # (in block context), + # - after '{', '[', ',' (in the flow context), + # - after '?', ':', '-' (in the block context). + # In the block context, this flag also signifies if a block collection + # may start at the current position. + self.allow_simple_key = True + + # Keep track of possible simple keys. This is a dictionary. The key + # is `flow_level`; there can be no more that one possible simple key + # for each level. The value is a SimpleKey record: + # (token_number, required, index, line, column, mark) + # A simple key may start with ALIAS, ANCHOR, TAG, SCALAR(flow), + # '[', or '{' tokens. + self.possible_simple_keys = {} + + # Public methods. + + def check_token(self, *choices): + # Check if the next token is one of the given types. 
+ while self.need_more_tokens(): + self.fetch_more_tokens() + if self.tokens: + if not choices: + return True + for choice in choices: + if isinstance(self.tokens[0], choice): + return True + return False + + def peek_token(self): + # Return the next token, but do not delete if from the queue. + # Return None if no more tokens. + while self.need_more_tokens(): + self.fetch_more_tokens() + if self.tokens: + return self.tokens[0] + else: + return None + + def get_token(self): + # Return the next token. + while self.need_more_tokens(): + self.fetch_more_tokens() + if self.tokens: + self.tokens_taken += 1 + return self.tokens.pop(0) + + # Private methods. + + def need_more_tokens(self): + if self.done: + return False + if not self.tokens: + return True + # The current token may be a potential simple key, so we + # need to look further. + self.stale_possible_simple_keys() + if self.next_possible_simple_key() == self.tokens_taken: + return True + + def fetch_more_tokens(self): + + # Eat whitespaces and comments until we reach the next token. + self.scan_to_next_token() + + # Remove obsolete possible simple keys. + self.stale_possible_simple_keys() + + # Compare the current indentation and column. It may add some tokens + # and decrease the current indentation level. + self.unwind_indent(self.column) + + # Peek the next character. + ch = self.peek() + + # Is it the end of stream? + if ch == '\0': + return self.fetch_stream_end() + + # Is it a directive? + if ch == '%' and self.check_directive(): + return self.fetch_directive() + + # Is it the document start? + if ch == '-' and self.check_document_start(): + return self.fetch_document_start() + + # Is it the document end? + if ch == '.' and self.check_document_end(): + return self.fetch_document_end() + + # TODO: support for BOM within a stream. + #if ch == '\uFEFF': + # return self.fetch_bom() <-- issue BOMToken + + # Note: the order of the following checks is NOT significant. 
+ + # Is it the flow sequence start indicator? + if ch == '[': + return self.fetch_flow_sequence_start() + + # Is it the flow mapping start indicator? + if ch == '{': + return self.fetch_flow_mapping_start() + + # Is it the flow sequence end indicator? + if ch == ']': + return self.fetch_flow_sequence_end() + + # Is it the flow mapping end indicator? + if ch == '}': + return self.fetch_flow_mapping_end() + + # Is it the flow entry indicator? + if ch == ',': + return self.fetch_flow_entry() + + # Is it the block entry indicator? + if ch == '-' and self.check_block_entry(): + return self.fetch_block_entry() + + # Is it the key indicator? + if ch == '?' and self.check_key(): + return self.fetch_key() + + # Is it the value indicator? + if ch == ':' and self.check_value(): + return self.fetch_value() + + # Is it an alias? + if ch == '*': + return self.fetch_alias() + + # Is it an anchor? + if ch == '&': + return self.fetch_anchor() + + # Is it a tag? + if ch == '!': + return self.fetch_tag() + + # Is it a literal scalar? + if ch == '|' and not self.flow_level: + return self.fetch_literal() + + # Is it a folded scalar? + if ch == '>' and not self.flow_level: + return self.fetch_folded() + + # Is it a single quoted scalar? + if ch == '\'': + return self.fetch_single() + + # Is it a double quoted scalar? + if ch == '\"': + return self.fetch_double() + + # It must be a plain scalar then. + if self.check_plain(): + return self.fetch_plain() + + # No? It's an error. Let's produce a nice error message. + raise ScannerError("while scanning for the next token", None, + "found character %r that cannot start any token" % ch, + self.get_mark()) + + # Simple keys treatment. + + def next_possible_simple_key(self): + # Return the number of the nearest possible simple key. Actually we + # don't need to loop through the whole dictionary. 
We may replace it + # with the following code: + # if not self.possible_simple_keys: + # return None + # return self.possible_simple_keys[ + # min(self.possible_simple_keys.keys())].token_number + min_token_number = None + for level in self.possible_simple_keys: + key = self.possible_simple_keys[level] + if min_token_number is None or key.token_number < min_token_number: + min_token_number = key.token_number + return min_token_number + + def stale_possible_simple_keys(self): + # Remove entries that are no longer possible simple keys. According to + # the YAML specification, simple keys + # - should be limited to a single line, + # - should be no longer than 1024 characters. + # Disabling this procedure will allow simple keys of any length and + # height (may cause problems if indentation is broken though). + for level in list(self.possible_simple_keys): + key = self.possible_simple_keys[level] + if key.line != self.line \ + or self.index-key.index > 1024: + if key.required: + raise ScannerError("while scanning a simple key", key.mark, + "could not find expected ':'", self.get_mark()) + del self.possible_simple_keys[level] + + def save_possible_simple_key(self): + # The next token may start a simple key. We check if it's possible + # and save its position. This function is called for + # ALIAS, ANCHOR, TAG, SCALAR(flow), '[', and '{'. + + # Check if a simple key is required at the current position. + required = not self.flow_level and self.indent == self.column + + # The next token might be a simple key. Let's save it's number and + # position. + if self.allow_simple_key: + self.remove_possible_simple_key() + token_number = self.tokens_taken+len(self.tokens) + key = SimpleKey(token_number, required, + self.index, self.line, self.column, self.get_mark()) + self.possible_simple_keys[self.flow_level] = key + + def remove_possible_simple_key(self): + # Remove the saved possible key position at the current flow level. 
+ if self.flow_level in self.possible_simple_keys: + key = self.possible_simple_keys[self.flow_level] + + if key.required: + raise ScannerError("while scanning a simple key", key.mark, + "could not find expected ':'", self.get_mark()) + + del self.possible_simple_keys[self.flow_level] + + # Indentation functions. + + def unwind_indent(self, column): + + ## In flow context, tokens should respect indentation. + ## Actually the condition should be `self.indent >= column` according to + ## the spec. But this condition will prohibit intuitively correct + ## constructions such as + ## key : { + ## } + #if self.flow_level and self.indent > column: + # raise ScannerError(None, None, + # "invalid intendation or unclosed '[' or '{'", + # self.get_mark()) + + # In the flow context, indentation is ignored. We make the scanner less + # restrictive then specification requires. + if self.flow_level: + return + + # In block context, we may need to issue the BLOCK-END tokens. + while self.indent > column: + mark = self.get_mark() + self.indent = self.indents.pop() + self.tokens.append(BlockEndToken(mark, mark)) + + def add_indent(self, column): + # Check if we need to increase indentation. + if self.indent < column: + self.indents.append(self.indent) + self.indent = column + return True + return False + + # Fetchers. + + def fetch_stream_start(self): + # We always add STREAM-START as the first token and STREAM-END as the + # last token. + + # Read the token. + mark = self.get_mark() + + # Add STREAM-START. + self.tokens.append(StreamStartToken(mark, mark, + encoding=self.encoding)) + + + def fetch_stream_end(self): + + # Set the current intendation to -1. + self.unwind_indent(-1) + + # Reset simple keys. + self.remove_possible_simple_key() + self.allow_simple_key = False + self.possible_simple_keys = {} + + # Read the token. + mark = self.get_mark() + + # Add STREAM-END. + self.tokens.append(StreamEndToken(mark, mark)) + + # The steam is finished. 
+ self.done = True + + def fetch_directive(self): + + # Set the current intendation to -1. + self.unwind_indent(-1) + + # Reset simple keys. + self.remove_possible_simple_key() + self.allow_simple_key = False + + # Scan and add DIRECTIVE. + self.tokens.append(self.scan_directive()) + + def fetch_document_start(self): + self.fetch_document_indicator(DocumentStartToken) + + def fetch_document_end(self): + self.fetch_document_indicator(DocumentEndToken) + + def fetch_document_indicator(self, TokenClass): + + # Set the current intendation to -1. + self.unwind_indent(-1) + + # Reset simple keys. Note that there could not be a block collection + # after '---'. + self.remove_possible_simple_key() + self.allow_simple_key = False + + # Add DOCUMENT-START or DOCUMENT-END. + start_mark = self.get_mark() + self.forward(3) + end_mark = self.get_mark() + self.tokens.append(TokenClass(start_mark, end_mark)) + + def fetch_flow_sequence_start(self): + self.fetch_flow_collection_start(FlowSequenceStartToken) + + def fetch_flow_mapping_start(self): + self.fetch_flow_collection_start(FlowMappingStartToken) + + def fetch_flow_collection_start(self, TokenClass): + + # '[' and '{' may start a simple key. + self.save_possible_simple_key() + + # Increase the flow level. + self.flow_level += 1 + + # Simple keys are allowed after '[' and '{'. + self.allow_simple_key = True + + # Add FLOW-SEQUENCE-START or FLOW-MAPPING-START. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(TokenClass(start_mark, end_mark)) + + def fetch_flow_sequence_end(self): + self.fetch_flow_collection_end(FlowSequenceEndToken) + + def fetch_flow_mapping_end(self): + self.fetch_flow_collection_end(FlowMappingEndToken) + + def fetch_flow_collection_end(self, TokenClass): + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Decrease the flow level. + self.flow_level -= 1 + + # No simple keys after ']' or '}'. 
+ self.allow_simple_key = False + + # Add FLOW-SEQUENCE-END or FLOW-MAPPING-END. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(TokenClass(start_mark, end_mark)) + + def fetch_flow_entry(self): + + # Simple keys are allowed after ','. + self.allow_simple_key = True + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add FLOW-ENTRY. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(FlowEntryToken(start_mark, end_mark)) + + def fetch_block_entry(self): + + # Block context needs additional checks. + if not self.flow_level: + + # Are we allowed to start a new entry? + if not self.allow_simple_key: + raise ScannerError(None, None, + "sequence entries are not allowed here", + self.get_mark()) + + # We may need to add BLOCK-SEQUENCE-START. + if self.add_indent(self.column): + mark = self.get_mark() + self.tokens.append(BlockSequenceStartToken(mark, mark)) + + # It's an error for the block entry to occur in the flow context, + # but we let the parser detect this. + else: + pass + + # Simple keys are allowed after '-'. + self.allow_simple_key = True + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add BLOCK-ENTRY. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(BlockEntryToken(start_mark, end_mark)) + + def fetch_key(self): + + # Block context needs additional checks. + if not self.flow_level: + + # Are we allowed to start a key (not necessary a simple)? + if not self.allow_simple_key: + raise ScannerError(None, None, + "mapping keys are not allowed here", + self.get_mark()) + + # We may need to add BLOCK-MAPPING-START. + if self.add_indent(self.column): + mark = self.get_mark() + self.tokens.append(BlockMappingStartToken(mark, mark)) + + # Simple keys are allowed after '?' in the block context. 
+ self.allow_simple_key = not self.flow_level + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add KEY. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(KeyToken(start_mark, end_mark)) + + def fetch_value(self): + + # Do we determine a simple key? + if self.flow_level in self.possible_simple_keys: + + # Add KEY. + key = self.possible_simple_keys[self.flow_level] + del self.possible_simple_keys[self.flow_level] + self.tokens.insert(key.token_number-self.tokens_taken, + KeyToken(key.mark, key.mark)) + + # If this key starts a new block mapping, we need to add + # BLOCK-MAPPING-START. + if not self.flow_level: + if self.add_indent(key.column): + self.tokens.insert(key.token_number-self.tokens_taken, + BlockMappingStartToken(key.mark, key.mark)) + + # There cannot be two simple keys one after another. + self.allow_simple_key = False + + # It must be a part of a complex key. + else: + + # Block context needs additional checks. + # (Do we really need them? They will be caught by the parser + # anyway.) + if not self.flow_level: + + # We are allowed to start a complex value if and only if + # we can start a simple key. + if not self.allow_simple_key: + raise ScannerError(None, None, + "mapping values are not allowed here", + self.get_mark()) + + # If this value starts a new block mapping, we need to add + # BLOCK-MAPPING-START. It will be detected as an error later by + # the parser. + if not self.flow_level: + if self.add_indent(self.column): + mark = self.get_mark() + self.tokens.append(BlockMappingStartToken(mark, mark)) + + # Simple keys are allowed after ':' in the block context. + self.allow_simple_key = not self.flow_level + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add VALUE. 
+ start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(ValueToken(start_mark, end_mark)) + + def fetch_alias(self): + + # ALIAS could be a simple key. + self.save_possible_simple_key() + + # No simple keys after ALIAS. + self.allow_simple_key = False + + # Scan and add ALIAS. + self.tokens.append(self.scan_anchor(AliasToken)) + + def fetch_anchor(self): + + # ANCHOR could start a simple key. + self.save_possible_simple_key() + + # No simple keys after ANCHOR. + self.allow_simple_key = False + + # Scan and add ANCHOR. + self.tokens.append(self.scan_anchor(AnchorToken)) + + def fetch_tag(self): + + # TAG could start a simple key. + self.save_possible_simple_key() + + # No simple keys after TAG. + self.allow_simple_key = False + + # Scan and add TAG. + self.tokens.append(self.scan_tag()) + + def fetch_literal(self): + self.fetch_block_scalar(style='|') + + def fetch_folded(self): + self.fetch_block_scalar(style='>') + + def fetch_block_scalar(self, style): + + # A simple key may follow a block scalar. + self.allow_simple_key = True + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Scan and add SCALAR. + self.tokens.append(self.scan_block_scalar(style)) + + def fetch_single(self): + self.fetch_flow_scalar(style='\'') + + def fetch_double(self): + self.fetch_flow_scalar(style='"') + + def fetch_flow_scalar(self, style): + + # A flow scalar could be a simple key. + self.save_possible_simple_key() + + # No simple keys after flow scalars. + self.allow_simple_key = False + + # Scan and add SCALAR. + self.tokens.append(self.scan_flow_scalar(style)) + + def fetch_plain(self): + + # A plain scalar could be a simple key. + self.save_possible_simple_key() + + # No simple keys after plain scalars. But note that `scan_plain` will + # change this flag if the scan is finished at the beginning of the + # line. + self.allow_simple_key = False + + # Scan and add SCALAR. 
May change `allow_simple_key`. + self.tokens.append(self.scan_plain()) + + # Checkers. + + def check_directive(self): + + # DIRECTIVE: ^ '%' ... + # The '%' indicator is already checked. + if self.column == 0: + return True + + def check_document_start(self): + + # DOCUMENT-START: ^ '---' (' '|'\n') + if self.column == 0: + if self.prefix(3) == '---' \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return True + + def check_document_end(self): + + # DOCUMENT-END: ^ '...' (' '|'\n') + if self.column == 0: + if self.prefix(3) == '...' \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return True + + def check_block_entry(self): + + # BLOCK-ENTRY: '-' (' '|'\n') + return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + + def check_key(self): + + # KEY(flow context): '?' + if self.flow_level: + return True + + # KEY(block context): '?' (' '|'\n') + else: + return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + + def check_value(self): + + # VALUE(flow context): ':' + if self.flow_level: + return True + + # VALUE(block context): ':' (' '|'\n') + else: + return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + + def check_plain(self): + + # A plain scalar may start with any non-space character except: + # '-', '?', ':', ',', '[', ']', '{', '}', + # '#', '&', '*', '!', '|', '>', '\'', '\"', + # '%', '@', '`'. + # + # It may also start with + # '-', '?', ':' + # if it is followed by a non-space character. + # + # Note that we limit the last rule to the block context (except the + # '-' character) because we want the flow context to be space + # independent. + ch = self.peek() + return ch not in '\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>\'\"%@`' \ + or (self.peek(1) not in '\0 \t\r\n\x85\u2028\u2029' + and (ch == '-' or (not self.flow_level and ch in '?:'))) + + # Scanners. + + def scan_to_next_token(self): + # We ignore spaces, line breaks and comments. + # If we find a line break in the block context, we set the flag + # `allow_simple_key` on. 
+ # The byte order mark is stripped if it's the first character in the + # stream. We do not yet support BOM inside the stream as the + # specification requires. Any such mark will be considered as a part + # of the document. + # + # TODO: We need to make tab handling rules more sane. A good rule is + # Tabs cannot precede tokens + # BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END, + # KEY(block), VALUE(block), BLOCK-ENTRY + # So the checking code is + # if <TAB>: + # self.allow_simple_keys = False + # We also need to add the check for `allow_simple_keys == True` to + # `unwind_indent` before issuing BLOCK-END. + # Scanners for block, flow, and plain scalars need to be modified. + + if self.index == 0 and self.peek() == '\uFEFF': + self.forward() + found = False + while not found: + while self.peek() == ' ': + self.forward() + if self.peek() == '#': + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + if self.scan_line_break(): + if not self.flow_level: + self.allow_simple_key = True + else: + found = True + + def scan_directive(self): + # See the specification for details. + start_mark = self.get_mark() + self.forward() + name = self.scan_directive_name(start_mark) + value = None + if name == 'YAML': + value = self.scan_yaml_directive_value(start_mark) + end_mark = self.get_mark() + elif name == 'TAG': + value = self.scan_tag_directive_value(start_mark) + end_mark = self.get_mark() + else: + end_mark = self.get_mark() + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + self.scan_directive_ignored_line(start_mark) + return DirectiveToken(name, value, start_mark, end_mark) + + def scan_directive_name(self, start_mark): + # See the specification for details. 
+ length = 0 + ch = self.peek(length) + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_': + length += 1 + ch = self.peek(length) + if not length: + raise ScannerError("while scanning a directive", start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + value = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + return value + + def scan_yaml_directive_value(self, start_mark): + # See the specification for details. + while self.peek() == ' ': + self.forward() + major = self.scan_yaml_directive_number(start_mark) + if self.peek() != '.': + raise ScannerError("while scanning a directive", start_mark, + "expected a digit or '.', but found %r" % self.peek(), + self.get_mark()) + self.forward() + minor = self.scan_yaml_directive_number(start_mark) + if self.peek() not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected a digit or ' ', but found %r" % self.peek(), + self.get_mark()) + return (major, minor) + + def scan_yaml_directive_number(self, start_mark): + # See the specification for details. + ch = self.peek() + if not ('0' <= ch <= '9'): + raise ScannerError("while scanning a directive", start_mark, + "expected a digit, but found %r" % ch, self.get_mark()) + length = 0 + while '0' <= self.peek(length) <= '9': + length += 1 + value = int(self.prefix(length)) + self.forward(length) + return value + + def scan_tag_directive_value(self, start_mark): + # See the specification for details. 
+ while self.peek() == ' ': + self.forward() + handle = self.scan_tag_directive_handle(start_mark) + while self.peek() == ' ': + self.forward() + prefix = self.scan_tag_directive_prefix(start_mark) + return (handle, prefix) + + def scan_tag_directive_handle(self, start_mark): + # See the specification for details. + value = self.scan_tag_handle('directive', start_mark) + ch = self.peek() + if ch != ' ': + raise ScannerError("while scanning a directive", start_mark, + "expected ' ', but found %r" % ch, self.get_mark()) + return value + + def scan_tag_directive_prefix(self, start_mark): + # See the specification for details. + value = self.scan_tag_uri('directive', start_mark) + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected ' ', but found %r" % ch, self.get_mark()) + return value + + def scan_directive_ignored_line(self, start_mark): + # See the specification for details. + while self.peek() == ' ': + self.forward() + if self.peek() == '#': + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + ch = self.peek() + if ch not in '\0\r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected a comment or a line break, but found %r" + % ch, self.get_mark()) + self.scan_line_break() + + def scan_anchor(self, TokenClass): + # The specification does not restrict characters for anchors and + # aliases. This may lead to problems, for instance, the document: + # [ *alias, value ] + # can be interpreted in two ways, as + # [ "value" ] + # and + # [ *alias , "value" ] + # Therefore we restrict aliases to numbers and ASCII letters. 
+ start_mark = self.get_mark() + indicator = self.peek() + if indicator == '*': + name = 'alias' + else: + name = 'anchor' + self.forward() + length = 0 + ch = self.peek(length) + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_': + length += 1 + ch = self.peek(length) + if not length: + raise ScannerError("while scanning an %s" % name, start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + value = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch not in '\0 \t\r\n\x85\u2028\u2029?:,]}%@`': + raise ScannerError("while scanning an %s" % name, start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + end_mark = self.get_mark() + return TokenClass(value, start_mark, end_mark) + + def scan_tag(self): + # See the specification for details. + start_mark = self.get_mark() + ch = self.peek(1) + if ch == '<': + handle = None + self.forward(2) + suffix = self.scan_tag_uri('tag', start_mark) + if self.peek() != '>': + raise ScannerError("while parsing a tag", start_mark, + "expected '>', but found %r" % self.peek(), + self.get_mark()) + self.forward() + elif ch in '\0 \t\r\n\x85\u2028\u2029': + handle = None + suffix = '!' + self.forward() + else: + length = 1 + use_handle = False + while ch not in '\0 \r\n\x85\u2028\u2029': + if ch == '!': + use_handle = True + break + length += 1 + ch = self.peek(length) + handle = '!' + if use_handle: + handle = self.scan_tag_handle('tag', start_mark) + else: + handle = '!' + self.forward() + suffix = self.scan_tag_uri('tag', start_mark) + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a tag", start_mark, + "expected ' ', but found %r" % ch, self.get_mark()) + value = (handle, suffix) + end_mark = self.get_mark() + return TagToken(value, start_mark, end_mark) + + def scan_block_scalar(self, style): + # See the specification for details. 
+ + if style == '>': + folded = True + else: + folded = False + + chunks = [] + start_mark = self.get_mark() + + # Scan the header. + self.forward() + chomping, increment = self.scan_block_scalar_indicators(start_mark) + self.scan_block_scalar_ignored_line(start_mark) + + # Determine the indentation level and go to the first non-empty line. + min_indent = self.indent+1 + if min_indent < 1: + min_indent = 1 + if increment is None: + breaks, max_indent, end_mark = self.scan_block_scalar_indentation() + indent = max(min_indent, max_indent) + else: + indent = min_indent+increment-1 + breaks, end_mark = self.scan_block_scalar_breaks(indent) + line_break = '' + + # Scan the inner part of the block scalar. + while self.column == indent and self.peek() != '\0': + chunks.extend(breaks) + leading_non_space = self.peek() not in ' \t' + length = 0 + while self.peek(length) not in '\0\r\n\x85\u2028\u2029': + length += 1 + chunks.append(self.prefix(length)) + self.forward(length) + line_break = self.scan_line_break() + breaks, end_mark = self.scan_block_scalar_breaks(indent) + if self.column == indent and self.peek() != '\0': + + # Unfortunately, folding rules are ambiguous. + # + # This is the folding according to the specification: + + if folded and line_break == '\n' \ + and leading_non_space and self.peek() not in ' \t': + if not breaks: + chunks.append(' ') + else: + chunks.append(line_break) + + # This is Clark Evans's interpretation (also in the spec + # examples): + # + #if folded and line_break == '\n': + # if not breaks: + # if self.peek() not in ' \t': + # chunks.append(' ') + # else: + # chunks.append(line_break) + #else: + # chunks.append(line_break) + else: + break + + # Chomp the tail. + if chomping is not False: + chunks.append(line_break) + if chomping is True: + chunks.extend(breaks) + + # We are done. 
+ return ScalarToken(''.join(chunks), False, start_mark, end_mark, + style) + + def scan_block_scalar_indicators(self, start_mark): + # See the specification for details. + chomping = None + increment = None + ch = self.peek() + if ch in '+-': + if ch == '+': + chomping = True + else: + chomping = False + self.forward() + ch = self.peek() + if ch in '0123456789': + increment = int(ch) + if increment == 0: + raise ScannerError("while scanning a block scalar", start_mark, + "expected indentation indicator in the range 1-9, but found 0", + self.get_mark()) + self.forward() + elif ch in '0123456789': + increment = int(ch) + if increment == 0: + raise ScannerError("while scanning a block scalar", start_mark, + "expected indentation indicator in the range 1-9, but found 0", + self.get_mark()) + self.forward() + ch = self.peek() + if ch in '+-': + if ch == '+': + chomping = True + else: + chomping = False + self.forward() + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a block scalar", start_mark, + "expected chomping or indentation indicators, but found %r" + % ch, self.get_mark()) + return chomping, increment + + def scan_block_scalar_ignored_line(self, start_mark): + # See the specification for details. + while self.peek() == ' ': + self.forward() + if self.peek() == '#': + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + ch = self.peek() + if ch not in '\0\r\n\x85\u2028\u2029': + raise ScannerError("while scanning a block scalar", start_mark, + "expected a comment or a line break, but found %r" % ch, + self.get_mark()) + self.scan_line_break() + + def scan_block_scalar_indentation(self): + # See the specification for details. 
+ chunks = [] + max_indent = 0 + end_mark = self.get_mark() + while self.peek() in ' \r\n\x85\u2028\u2029': + if self.peek() != ' ': + chunks.append(self.scan_line_break()) + end_mark = self.get_mark() + else: + self.forward() + if self.column > max_indent: + max_indent = self.column + return chunks, max_indent, end_mark + + def scan_block_scalar_breaks(self, indent): + # See the specification for details. + chunks = [] + end_mark = self.get_mark() + while self.column < indent and self.peek() == ' ': + self.forward() + while self.peek() in '\r\n\x85\u2028\u2029': + chunks.append(self.scan_line_break()) + end_mark = self.get_mark() + while self.column < indent and self.peek() == ' ': + self.forward() + return chunks, end_mark + + def scan_flow_scalar(self, style): + # See the specification for details. + # Note that we loose indentation rules for quoted scalars. Quoted + # scalars don't need to adhere indentation because " and ' clearly + # mark the beginning and the end of them. Therefore we are less + # restrictive then the specification requires. We only need to check + # that document separators are not included in scalars. 
+ if style == '"': + double = True + else: + double = False + chunks = [] + start_mark = self.get_mark() + quote = self.peek() + self.forward() + chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) + while self.peek() != quote: + chunks.extend(self.scan_flow_scalar_spaces(double, start_mark)) + chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) + self.forward() + end_mark = self.get_mark() + return ScalarToken(''.join(chunks), False, start_mark, end_mark, + style) + + ESCAPE_REPLACEMENTS = { + '0': '\0', + 'a': '\x07', + 'b': '\x08', + 't': '\x09', + '\t': '\x09', + 'n': '\x0A', + 'v': '\x0B', + 'f': '\x0C', + 'r': '\x0D', + 'e': '\x1B', + ' ': '\x20', + '\"': '\"', + '\\': '\\', + '/': '/', + 'N': '\x85', + '_': '\xA0', + 'L': '\u2028', + 'P': '\u2029', + } + + ESCAPE_CODES = { + 'x': 2, + 'u': 4, + 'U': 8, + } + + def scan_flow_scalar_non_spaces(self, double, start_mark): + # See the specification for details. + chunks = [] + while True: + length = 0 + while self.peek(length) not in '\'\"\\\0 \t\r\n\x85\u2028\u2029': + length += 1 + if length: + chunks.append(self.prefix(length)) + self.forward(length) + ch = self.peek() + if not double and ch == '\'' and self.peek(1) == '\'': + chunks.append('\'') + self.forward(2) + elif (double and ch == '\'') or (not double and ch in '\"\\'): + chunks.append(ch) + self.forward() + elif double and ch == '\\': + self.forward() + ch = self.peek() + if ch in self.ESCAPE_REPLACEMENTS: + chunks.append(self.ESCAPE_REPLACEMENTS[ch]) + self.forward() + elif ch in self.ESCAPE_CODES: + length = self.ESCAPE_CODES[ch] + self.forward() + for k in range(length): + if self.peek(k) not in '0123456789ABCDEFabcdef': + raise ScannerError("while scanning a double-quoted scalar", start_mark, + "expected escape sequence of %d hexdecimal numbers, but found %r" % + (length, self.peek(k)), self.get_mark()) + code = int(self.prefix(length), 16) + chunks.append(chr(code)) + self.forward(length) + elif ch in 
'\r\n\x85\u2028\u2029': + self.scan_line_break() + chunks.extend(self.scan_flow_scalar_breaks(double, start_mark)) + else: + raise ScannerError("while scanning a double-quoted scalar", start_mark, + "found unknown escape character %r" % ch, self.get_mark()) + else: + return chunks + + def scan_flow_scalar_spaces(self, double, start_mark): + # See the specification for details. + chunks = [] + length = 0 + while self.peek(length) in ' \t': + length += 1 + whitespaces = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch == '\0': + raise ScannerError("while scanning a quoted scalar", start_mark, + "found unexpected end of stream", self.get_mark()) + elif ch in '\r\n\x85\u2028\u2029': + line_break = self.scan_line_break() + breaks = self.scan_flow_scalar_breaks(double, start_mark) + if line_break != '\n': + chunks.append(line_break) + elif not breaks: + chunks.append(' ') + chunks.extend(breaks) + else: + chunks.append(whitespaces) + return chunks + + def scan_flow_scalar_breaks(self, double, start_mark): + # See the specification for details. + chunks = [] + while True: + # Instead of checking indentation, we check for document + # separators. + prefix = self.prefix(3) + if (prefix == '---' or prefix == '...') \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + raise ScannerError("while scanning a quoted scalar", start_mark, + "found unexpected document separator", self.get_mark()) + while self.peek() in ' \t': + self.forward() + if self.peek() in '\r\n\x85\u2028\u2029': + chunks.append(self.scan_line_break()) + else: + return chunks + + def scan_plain(self): + # See the specification for details. + # We add an additional restriction for the flow context: + # plain scalars in the flow context cannot contain ',' or '?'. + # We also keep track of the `allow_simple_key` flag here. + # Indentation rules are loosed for the flow context. 
+ chunks = [] + start_mark = self.get_mark() + end_mark = start_mark + indent = self.indent+1 + # We allow zero indentation for scalars, but then we need to check for + # document separators at the beginning of the line. + #if indent == 0: + # indent = 1 + spaces = [] + while True: + length = 0 + if self.peek() == '#': + break + while True: + ch = self.peek(length) + if ch in '\0 \t\r\n\x85\u2028\u2029' \ + or (ch == ':' and + self.peek(length+1) in '\0 \t\r\n\x85\u2028\u2029' + + (u',[]{}' if self.flow_level else u''))\ + or (self.flow_level and ch in ',?[]{}'): + break + length += 1 + if length == 0: + break + self.allow_simple_key = False + chunks.extend(spaces) + chunks.append(self.prefix(length)) + self.forward(length) + end_mark = self.get_mark() + spaces = self.scan_plain_spaces(indent, start_mark) + if not spaces or self.peek() == '#' \ + or (not self.flow_level and self.column < indent): + break + return ScalarToken(''.join(chunks), True, start_mark, end_mark) + + def scan_plain_spaces(self, indent, start_mark): + # See the specification for details. + # The specification is really confusing about tabs in plain scalars. + # We just forbid them completely. Do not use tabs in YAML! 
+ chunks = [] + length = 0 + while self.peek(length) in ' ': + length += 1 + whitespaces = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch in '\r\n\x85\u2028\u2029': + line_break = self.scan_line_break() + self.allow_simple_key = True + prefix = self.prefix(3) + if (prefix == '---' or prefix == '...') \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return + breaks = [] + while self.peek() in ' \r\n\x85\u2028\u2029': + if self.peek() == ' ': + self.forward() + else: + breaks.append(self.scan_line_break()) + prefix = self.prefix(3) + if (prefix == '---' or prefix == '...') \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return + if line_break != '\n': + chunks.append(line_break) + elif not breaks: + chunks.append(' ') + chunks.extend(breaks) + elif whitespaces: + chunks.append(whitespaces) + return chunks + + def scan_tag_handle(self, name, start_mark): + # See the specification for details. + # For some strange reasons, the specification does not allow '_' in + # tag handles. I have allowed it anyway. + ch = self.peek() + if ch != '!': + raise ScannerError("while scanning a %s" % name, start_mark, + "expected '!', but found %r" % ch, self.get_mark()) + length = 1 + ch = self.peek(length) + if ch != ' ': + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_': + length += 1 + ch = self.peek(length) + if ch != '!': + self.forward(length) + raise ScannerError("while scanning a %s" % name, start_mark, + "expected '!', but found %r" % ch, self.get_mark()) + length += 1 + value = self.prefix(length) + self.forward(length) + return value + + def scan_tag_uri(self, name, start_mark): + # See the specification for details. + # Note: we do not check if URI is well-formed. 
+ chunks = [] + length = 0 + ch = self.peek(length) + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-;/?:@&=+$,_.!~*\'()[]%': + if ch == '%': + chunks.append(self.prefix(length)) + self.forward(length) + length = 0 + chunks.append(self.scan_uri_escapes(name, start_mark)) + else: + length += 1 + ch = self.peek(length) + if length: + chunks.append(self.prefix(length)) + self.forward(length) + length = 0 + if not chunks: + raise ScannerError("while parsing a %s" % name, start_mark, + "expected URI, but found %r" % ch, self.get_mark()) + return ''.join(chunks) + + def scan_uri_escapes(self, name, start_mark): + # See the specification for details. + codes = [] + mark = self.get_mark() + while self.peek() == '%': + self.forward() + for k in range(2): + if self.peek(k) not in '0123456789ABCDEFabcdef': + raise ScannerError("while scanning a %s" % name, start_mark, + "expected URI escape sequence of 2 hexdecimal numbers, but found %r" + % self.peek(k), self.get_mark()) + codes.append(int(self.prefix(2), 16)) + self.forward(2) + try: + value = bytes(codes).decode('utf-8') + except UnicodeDecodeError as exc: + raise ScannerError("while scanning a %s" % name, start_mark, str(exc), mark) + return value + + def scan_line_break(self): + # Transforms: + # '\r\n' : '\n' + # '\r' : '\n' + # '\n' : '\n' + # '\x85' : '\n' + # '\u2028' : '\u2028' + # '\u2029 : '\u2029' + # default : '' + ch = self.peek() + if ch in '\r\n\x85': + if self.prefix(2) == '\r\n': + self.forward(2) + else: + self.forward() + return '\n' + elif ch in '\u2028\u2029': + self.forward() + return ch + return '' diff --git a/libs/yaml/serializer.py b/libs/yaml/serializer.py new file mode 100644 index 000000000..fe911e67a --- /dev/null +++ b/libs/yaml/serializer.py @@ -0,0 +1,111 @@ + +__all__ = ['Serializer', 'SerializerError'] + +from .error import YAMLError +from .events import * +from .nodes import * + +class SerializerError(YAMLError): + pass + +class Serializer: + + 
ANCHOR_TEMPLATE = 'id%03d' + + def __init__(self, encoding=None, + explicit_start=None, explicit_end=None, version=None, tags=None): + self.use_encoding = encoding + self.use_explicit_start = explicit_start + self.use_explicit_end = explicit_end + self.use_version = version + self.use_tags = tags + self.serialized_nodes = {} + self.anchors = {} + self.last_anchor_id = 0 + self.closed = None + + def open(self): + if self.closed is None: + self.emit(StreamStartEvent(encoding=self.use_encoding)) + self.closed = False + elif self.closed: + raise SerializerError("serializer is closed") + else: + raise SerializerError("serializer is already opened") + + def close(self): + if self.closed is None: + raise SerializerError("serializer is not opened") + elif not self.closed: + self.emit(StreamEndEvent()) + self.closed = True + + #def __del__(self): + # self.close() + + def serialize(self, node): + if self.closed is None: + raise SerializerError("serializer is not opened") + elif self.closed: + raise SerializerError("serializer is closed") + self.emit(DocumentStartEvent(explicit=self.use_explicit_start, + version=self.use_version, tags=self.use_tags)) + self.anchor_node(node) + self.serialize_node(node, None, None) + self.emit(DocumentEndEvent(explicit=self.use_explicit_end)) + self.serialized_nodes = {} + self.anchors = {} + self.last_anchor_id = 0 + + def anchor_node(self, node): + if node in self.anchors: + if self.anchors[node] is None: + self.anchors[node] = self.generate_anchor(node) + else: + self.anchors[node] = None + if isinstance(node, SequenceNode): + for item in node.value: + self.anchor_node(item) + elif isinstance(node, MappingNode): + for key, value in node.value: + self.anchor_node(key) + self.anchor_node(value) + + def generate_anchor(self, node): + self.last_anchor_id += 1 + return self.ANCHOR_TEMPLATE % self.last_anchor_id + + def serialize_node(self, node, parent, index): + alias = self.anchors[node] + if node in self.serialized_nodes: + 
self.emit(AliasEvent(alias)) + else: + self.serialized_nodes[node] = True + self.descend_resolver(parent, index) + if isinstance(node, ScalarNode): + detected_tag = self.resolve(ScalarNode, node.value, (True, False)) + default_tag = self.resolve(ScalarNode, node.value, (False, True)) + implicit = (node.tag == detected_tag), (node.tag == default_tag) + self.emit(ScalarEvent(alias, node.tag, implicit, node.value, + style=node.style)) + elif isinstance(node, SequenceNode): + implicit = (node.tag + == self.resolve(SequenceNode, node.value, True)) + self.emit(SequenceStartEvent(alias, node.tag, implicit, + flow_style=node.flow_style)) + index = 0 + for item in node.value: + self.serialize_node(item, node, index) + index += 1 + self.emit(SequenceEndEvent()) + elif isinstance(node, MappingNode): + implicit = (node.tag + == self.resolve(MappingNode, node.value, True)) + self.emit(MappingStartEvent(alias, node.tag, implicit, + flow_style=node.flow_style)) + for key, value in node.value: + self.serialize_node(key, node, None) + self.serialize_node(value, node, key) + self.emit(MappingEndEvent()) + self.ascend_resolver() + diff --git a/libs/yaml/tokens.py b/libs/yaml/tokens.py new file mode 100644 index 000000000..4d0b48a39 --- /dev/null +++ b/libs/yaml/tokens.py @@ -0,0 +1,104 @@ + +class Token(object): + def __init__(self, start_mark, end_mark): + self.start_mark = start_mark + self.end_mark = end_mark + def __repr__(self): + attributes = [key for key in self.__dict__ + if not key.endswith('_mark')] + attributes.sort() + arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) + for key in attributes]) + return '%s(%s)' % (self.__class__.__name__, arguments) + +#class BOMToken(Token): +# id = '<byte order mark>' + +class DirectiveToken(Token): + id = '<directive>' + def __init__(self, name, value, start_mark, end_mark): + self.name = name + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class DocumentStartToken(Token): + id = '<document 
start>' + +class DocumentEndToken(Token): + id = '<document end>' + +class StreamStartToken(Token): + id = '<stream start>' + def __init__(self, start_mark=None, end_mark=None, + encoding=None): + self.start_mark = start_mark + self.end_mark = end_mark + self.encoding = encoding + +class StreamEndToken(Token): + id = '<stream end>' + +class BlockSequenceStartToken(Token): + id = '<block sequence start>' + +class BlockMappingStartToken(Token): + id = '<block mapping start>' + +class BlockEndToken(Token): + id = '<block end>' + +class FlowSequenceStartToken(Token): + id = '[' + +class FlowMappingStartToken(Token): + id = '{' + +class FlowSequenceEndToken(Token): + id = ']' + +class FlowMappingEndToken(Token): + id = '}' + +class KeyToken(Token): + id = '?' + +class ValueToken(Token): + id = ':' + +class BlockEntryToken(Token): + id = '-' + +class FlowEntryToken(Token): + id = ',' + +class AliasToken(Token): + id = '<alias>' + def __init__(self, value, start_mark, end_mark): + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class AnchorToken(Token): + id = '<anchor>' + def __init__(self, value, start_mark, end_mark): + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class TagToken(Token): + id = '<tag>' + def __init__(self, value, start_mark, end_mark): + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class ScalarToken(Token): + id = '<scalar>' + def __init__(self, value, plain, start_mark, end_mark, style=None): + self.value = value + self.plain = plain + self.start_mark = start_mark + self.end_mark = end_mark + self.style = style + |