diff options
author | Andrei Lebedev <[email protected]> | 2022-11-03 11:09:37 +0100 |
---|---|---|
committer | GitHub <[email protected]> | 2022-11-03 10:09:37 +0000 |
commit | 27ed77aabba8c9eb08d66f34092b1bfcc22c482e (patch) | |
tree | 7cc41fc5e398009a5cf8e7e4156afb0246aa34d3 | |
parent | c4b19a88169fa76c5eb665d274e7270a0fe452c4 (diff) | |
download | youtube-dl-27ed77aabba8c9eb08d66f34092b1bfcc22c482e.tar.gz youtube-dl-27ed77aabba8c9eb08d66f34092b1bfcc22c482e.zip |
[utils] Backport traverse_obj (etc) from yt-dlp (#31156)
* Backport traverse_obj and closely related functions from yt-dlp (code by pukkandan)
* Backport LazyList, variadic(), try_call (code by pukkandan)
* Recast using yt-dlp's newer traverse_obj() implementation and tests (code by grub4k)
* Add tests for Unicode case folding support matching Py3.5+ (requires f102e3d)
* Improve/add tests for variadic, try_call, join_nonempty
Co-authored-by: dirkf <[email protected]>
-rw-r--r-- | test/test_utils.py | 323 | ||||
-rw-r--r-- | youtube_dl/utils.py | 339 |
2 files changed, 662 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index f1a748dde..9d364c863 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,9 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # Various small unit tests import io +import itertools import json +import re import xml.etree.ElementTree from youtube_dl.utils import ( @@ -40,11 +42,14 @@ from youtube_dl.utils import ( get_element_by_attribute, get_elements_by_class, get_elements_by_attribute, + get_first, InAdvancePagedList, int_or_none, intlist_to_bytes, is_html, + join_nonempty, js_to_json, + LazyList, limit_length, merge_dicts, mimetype2ext, @@ -79,6 +84,8 @@ from youtube_dl.utils import ( strip_or_none, subtitles_filename, timeconvert, + traverse_obj, + try_call, unescapeHTML, unified_strdate, unified_timestamp, @@ -92,6 +99,7 @@ from youtube_dl.utils import ( urlencode_postdata, urshift, update_url_query, + variadic, version_tuple, xpath_with_ns, xpath_element, @@ -112,12 +120,18 @@ from youtube_dl.compat import ( compat_getenv, compat_os_name, compat_setenv, + compat_str, compat_urlparse, compat_parse_qs, ) class TestUtil(unittest.TestCase): + + # yt-dlp shim + def assertCountEqual(self, expected, got, msg='count should be the same'): + return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg) + def test_timeconvert(self): self.assertTrue(timeconvert('') is None) self.assertTrue(timeconvert('bougrg') is None) @@ -1478,6 +1492,315 @@ Line 1 self.assertEqual(clean_podcast_url('https://www.podtrac.com/pts/redirect.mp3/chtbl.com/track/5899E/traffic.megaphone.fm/HSW7835899191.mp3'), 'https://traffic.megaphone.fm/HSW7835899191.mp3') self.assertEqual(clean_podcast_url('https://play.podtrac.com/npr-344098539/edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3'), 
'https://edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3') + def test_LazyList(self): + it = list(range(10)) + + self.assertEqual(list(LazyList(it)), it) + self.assertEqual(LazyList(it).exhaust(), it) + self.assertEqual(LazyList(it)[5], it[5]) + + self.assertEqual(LazyList(it)[5:], it[5:]) + self.assertEqual(LazyList(it)[:5], it[:5]) + self.assertEqual(LazyList(it)[::2], it[::2]) + self.assertEqual(LazyList(it)[1::2], it[1::2]) + self.assertEqual(LazyList(it)[5::-1], it[5::-1]) + self.assertEqual(LazyList(it)[6:2:-2], it[6:2:-2]) + self.assertEqual(LazyList(it)[::-1], it[::-1]) + + self.assertTrue(LazyList(it)) + self.assertFalse(LazyList(range(0))) + self.assertEqual(len(LazyList(it)), len(it)) + self.assertEqual(repr(LazyList(it)), repr(it)) + self.assertEqual(compat_str(LazyList(it)), compat_str(it)) + + self.assertEqual(list(LazyList(it, reverse=True)), it[::-1]) + self.assertEqual(list(reversed(LazyList(it))[::-1]), it) + self.assertEqual(list(reversed(LazyList(it))[1:3:7]), it[::-1][1:3:7]) + + def test_LazyList_laziness(self): + + def test(ll, idx, val, cache): + self.assertEqual(ll[idx], val) + self.assertEqual(ll._cache, list(cache)) + + ll = LazyList(range(10)) + test(ll, 0, 0, range(1)) + test(ll, 5, 5, range(6)) + test(ll, -3, 7, range(10)) + + ll = LazyList(range(10), reverse=True) + test(ll, -1, 0, range(1)) + test(ll, 3, 6, range(10)) + + ll = LazyList(itertools.count()) + test(ll, 10, 10, range(11)) + ll = reversed(ll) + test(ll, -15, 14, range(15)) + + def test_try_call(self): + def total(*x, **kwargs): + return sum(x) + sum(kwargs.values()) + + self.assertEqual(try_call(None), None, + msg='not a fn should give None') + self.assertEqual(try_call(lambda: 1), 1, + msg='int fn with no expected_type should give int') + self.assertEqual(try_call(lambda: 1, expected_type=int), 1, + msg='int fn with expected_type int should give int') + 
self.assertEqual(try_call(lambda: 1, expected_type=dict), None, + msg='int fn with wrong expected_type should give None') + self.assertEqual(try_call(total, args=(0, 1, 0, ), expected_type=int), 1, + msg='fn should accept arglist') + self.assertEqual(try_call(total, kwargs={'a': 0, 'b': 1, 'c': 0}, expected_type=int), 1, + msg='fn should accept kwargs') + self.assertEqual(try_call(lambda: 1, expected_type=dict), None, + msg='int fn with no expected_type should give None') + self.assertEqual(try_call(lambda x: {}, total, args=(42, ), expected_type=int), 42, + msg='expect first int result with expected_type int') + + def test_variadic(self): + self.assertEqual(variadic(None), (None, )) + self.assertEqual(variadic('spam'), ('spam', )) + self.assertEqual(variadic('spam', allowed_types=dict), 'spam') + + def test_traverse_obj(self): + _TEST_DATA = { + 100: 100, + 1.2: 1.2, + 'str': 'str', + 'None': None, + '...': Ellipsis, + 'urls': [ + {'index': 0, 'url': 'https://www.example.com/0'}, + {'index': 1, 'url': 'https://www.example.com/1'}, + ], + 'data': ( + {'index': 2}, + {'index': 3}, + ), + 'dict': {}, + } + + # Test base functionality + self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', + msg='allow tuple path') + self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str', + msg='allow list path') + self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str', + msg='allow iterable path') + self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str', + msg='single items should be treated as a path') + self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA) + self.assertEqual(traverse_obj(_TEST_DATA, 100), 100) + self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2) + + # Test Ellipsis behavior + self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), + (item for item in _TEST_DATA.values() if item is not None), + msg='`...` should give all values except `None`') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), 
_TEST_DATA['urls'][0].values(), + msg='`...` selection for dicts should select all values') + self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='nested `...` queries should work') + self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), + msg='`...` query result should be flattened') + + # Test function as key + self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), + [_TEST_DATA['urls']], + msg='function as query key should perform a filter based on (key, value)') + self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), {'str'}, + msg='exceptions in the query function should be caught') + + # Test alternative paths + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', + msg='multiple `paths` should be treated as alternative paths') + self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str', + msg='alternatives should exit early') + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None, + msg='alternatives should return `default` if exhausted') + self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, 'fail'), 100), 100, + msg='alternatives should track their own branching return') + self.assertEqual(traverse_obj(_TEST_DATA, ('dict', Ellipsis), ('data', Ellipsis)), list(_TEST_DATA['data']), + msg='alternatives on empty objects should search further') + + # Test branch and path nesting + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'], + msg='tuple as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'], + msg='list as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'], + msg='double nesting in path should be treated as 
paths') + self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1], + msg='do not fail early on branching') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='triple nesting in path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (Ellipsis, 'url')))), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='ellipsis as branch path start gets flattened') + + # Test dictionary as key + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2}, + msg='dict key should result in a dict with the same keys') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}), + {0: 'https://www.example.com/0'}, + msg='dict key should allow paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}), + {0: ['https://www.example.com/0']}, + msg='tuple in dict path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}), + {0: ['https://www.example.com/0']}, + msg='double nesting in dict path should be treated as paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}), + {0: ['https://www.example.com/1', 'https://www.example.com/0']}, + msg='triple nesting in dict path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, + msg='remove `None` values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis}, + msg='do not remove `None` values if `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, + msg='do not remove empty values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}}, + msg='do not remove empty values when dict key and a default') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 
('dict', Ellipsis)}), {0: []}, + msg='if branch in dict key not successful, return `[]`') + + # Testing default parameter behavior + _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None, + msg='default value should be `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=Ellipsis), Ellipsis, + msg='chained fails should result in default') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0, + msg='should not short cirquit on `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1, + msg='invalid dict key should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1, + msg='`None` is a deliberate sentinel and should become `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None, + msg='`IndexError` should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail'), default=1), 1, + msg='if branched but not successful return `default` if defined, not `[]`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail'), default=None), None, + msg='if branched but not successful return `default` even if `default` is `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail')), [], + msg='if branched but not successful return `[]`, not `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [], + msg='if branched but object is empty return `[]`, not `default`') + + # Testing expected_type behavior + _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str', + msg='accept matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, + msg='reject non matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0', + 
msg='transform type using type function') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', + expected_type=lambda _: 1 / 0), None, + msg='wrap expected_type function in try_call') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'], + msg='eliminate items that expected_type fails on') + + # Test get_all behavior + _GET_ALL_DATA = {'key': [0, 1, 2]} + self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', Ellipsis), get_all=False), 0, + msg='if not `get_all`, return only first matching value') + self.assertEqual(traverse_obj(_GET_ALL_DATA, Ellipsis, get_all=False), [0, 1, 2], + msg='do not overflatten if not `get_all`') + + # Test casesense behavior + _CASESENSE_DATA = { + 'KeY': 'value0', + 0: { + 'KeY': 'value1', + 0: {'KeY': 'value2'}, + }, + # FULLWIDTH LATIN CAPITAL LETTER K + '\uff2bey': 'value3', + } + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None, + msg='dict keys should be case sensitive unless `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY', + casesense=False), 'value0', + msg='allow non matching key case if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, '\uff4bey', # FULLWIDTH LATIN SMALL LETTER K + casesense=False), 'value3', + msg='allow non matching Unicode key case if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)), + casesense=False), ['value1'], + msg='allow non matching key case in branch if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)), + casesense=False), ['value2'], + msg='allow non matching key case in branch path if `casesense`') + + # Test traverse_string behavior + _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2} + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None, + msg='do not traverse into string if not `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0), + _traverse_string=True), 's', + msg='traverse into string if 
`traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1), + _traverse_string=True), '.', + msg='traverse into converted data if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis), + _traverse_string=True), list('str'), + msg='`...` branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), + _traverse_string=True), ['s', 'r'], + msg='branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), + _traverse_string=True), list('str'), + msg='function branching into string should result in list') + + # Test is_user_input behavior + _IS_USER_INPUT_DATA = {'range8': list(range(8))} + self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), + _is_user_input=True), 3, + msg='allow for string indexing if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), + _is_user_input=True), tuple(range(8))[3:], + msg='allow for string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), + _is_user_input=True), tuple(range(8))[:4:2], + msg='allow step in string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), + _is_user_input=True), range(8), + msg='`:` should be treated as `...` if `is_user_input`') + with self.assertRaises(TypeError, msg='too many params should result in error'): + traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), _is_user_input=True) + + # Test re.Match as input obj + mobj = re.match(r'^0(12)(?P<group>3)(4)?$', '0123') + self.assertEqual(traverse_obj(mobj, Ellipsis), [x for x in mobj.groups() if x is not None], + msg='`...` on a `re.Match` should give its `groups()`') + self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'], + msg='function on a `re.Match` should give groupno, value starting at 0') 
+ self.assertEqual(traverse_obj(mobj, 'group'), '3', + msg='str key on a `re.Match` should give group with that name') + self.assertEqual(traverse_obj(mobj, 2), '3', + msg='int key on a `re.Match` should give group with that name') + self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3', + msg='str key on a `re.Match` should respect casesense') + self.assertEqual(traverse_obj(mobj, 'fail'), None, + msg='failing str key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None, + msg='failing str key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, 8), None, + msg='failing int key on a `re.Match` should return `default`') + + def test_get_first(self): + self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') + + def test_join_nonempty(self): + self.assertEqual(join_nonempty('a', 'b'), 'a-b') + self.assertEqual(join_nonempty( + 'a', 'b', 'c', 'd', + from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d') + if __name__ == '__main__': unittest.main() diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 23a65a81c..e3c3ccff9 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -43,6 +43,7 @@ from .compat import ( compat_HTTPError, compat_basestring, compat_chr, + compat_collections_abc, compat_cookiejar, compat_ctypes_WINFUNCTYPE, compat_etree_fromstring, @@ -1685,6 +1686,7 @@ USER_AGENTS = { NO_DEFAULT = object() +IDENTITY = lambda x: x ENGLISH_MONTH_NAMES = [ 'January', 'February', 'March', 'April', 'May', 'June', @@ -3867,6 +3869,105 @@ def detect_exe_version(output, version_re=None, unrecognized='present'): return unrecognized +class LazyList(compat_collections_abc.Sequence): + """Lazy immutable list from an iterable + Note that slices of a LazyList are lists and not LazyList""" + + class IndexError(IndexError): + def __init__(self, cause=None): + if cause: + # reproduce `raise from` + self.__cause__ = cause + super(IndexError, 
self).__init__() + + def __init__(self, iterable, **kwargs): + # kwarg-only + reverse = kwargs.get('reverse', False) + _cache = kwargs.get('_cache') + + self._iterable = iter(iterable) + self._cache = [] if _cache is None else _cache + self._reversed = reverse + + def __iter__(self): + if self._reversed: + # We need to consume the entire iterable to iterate in reverse + for item in self.exhaust(): + yield item + return + for item in self._cache: + yield item + for item in self._iterable: + self._cache.append(item) + yield item + + def _exhaust(self): + self._cache.extend(self._iterable) + self._iterable = [] # Discard the emptied iterable to make it pickle-able + return self._cache + + def exhaust(self): + """Evaluate the entire iterable""" + return self._exhaust()[::-1 if self._reversed else 1] + + @staticmethod + def _reverse_index(x): + return None if x is None else ~x + + def __getitem__(self, idx): + if isinstance(idx, slice): + if self._reversed: + idx = slice(self._reverse_index(idx.start), self._reverse_index(idx.stop), -(idx.step or 1)) + start, stop, step = idx.start, idx.stop, idx.step or 1 + elif isinstance(idx, int): + if self._reversed: + idx = self._reverse_index(idx) + start, stop, step = idx, idx, 0 + else: + raise TypeError('indices must be integers or slices') + if ((start or 0) < 0 or (stop or 0) < 0 + or (start is None and step < 0) + or (stop is None and step > 0)): + # We need to consume the entire iterable to be able to slice from the end + # Obviously, never use this with infinite iterables + self._exhaust() + try: + return self._cache[idx] + except IndexError as e: + raise self.IndexError(e) + n = max(start or 0, stop or 0) - len(self._cache) + 1 + if n > 0: + self._cache.extend(itertools.islice(self._iterable, n)) + try: + return self._cache[idx] + except IndexError as e: + raise self.IndexError(e) + + def __bool__(self): + try: + self[-1] if self._reversed else self[0] + except self.IndexError: + return False + return True + + def 
__len__(self): + self._exhaust() + return len(self._cache) + + def __reversed__(self): + return type(self)(self._iterable, reverse=not self._reversed, _cache=self._cache) + + def __copy__(self): + return type(self)(self._iterable, reverse=self._reversed, _cache=self._cache) + + def __repr__(self): + # repr and str should mimic a list. So we exhaust the iterable + return repr(self.exhaust()) + + def __str__(self): + return repr(self.exhaust()) + + class PagedList(object): def __len__(self): # This is only useful for tests @@ -4092,6 +4193,10 @@ def multipart_encode(data, boundary=None): return out, content_type +def variadic(x, allowed_types=(compat_str, bytes, dict)): + return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,) + + def dict_get(d, key_or_keys, default=None, skip_false_values=True): if isinstance(key_or_keys, (list, tuple)): for key in key_or_keys: @@ -4102,6 +4207,23 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True): return d.get(key_or_keys, default) +def try_call(*funcs, **kwargs): + + # parameter defaults + expected_type = kwargs.get('expected_type') + fargs = kwargs.get('args', []) + fkwargs = kwargs.get('kwargs', {}) + + for f in funcs: + try: + val = f(*fargs, **fkwargs) + except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError): + pass + else: + if expected_type is None or isinstance(val, expected_type): + return val + + def try_get(src, getter, expected_type=None): if not isinstance(getter, (list, tuple)): getter = [getter] @@ -5835,3 +5957,220 @@ def clean_podcast_url(url): st\.fm # https://podsights.com/docs/ )/e )/''', '', url) + + +def traverse_obj(obj, *paths, **kwargs): + """ + Safely traverse nested `dict`s and `Sequence`s + + >>> obj = [{}, {"key": "value"}] + >>> traverse_obj(obj, (1, "key")) + "value" + + Each of the provided `paths` is tested and the first producing a valid result will be returned. 
+ The next path will also be tested if the path branched but no results could be found. + Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. + A value of None is treated as the absence of a value. + + The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. + + The keys in the path can be one of: + - `None`: Return the current object. + - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. + - `slice`: Branch out and return all values in `obj[key]`. + - `Ellipsis`: Branch out and return a list of all values. + - `tuple`/`list`: Branch out and return a list of all matching values. + Read as: `[traverse_obj(obj, branch) for branch in branches]`. + - `function`: Branch out and return values filtered by the function. + Read as: `[value for key, value in obj if function(key, value)]`. + For `Sequence`s, `key` is the index of the value. + - `dict`: Transform the current object and return a matching dict. + Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. + + `tuple`, `list`, and `dict` all support nested paths and branches. + + @param paths Paths by which to traverse. + Keyword arguments: + @param default Value to return if the paths do not match. + @param expected_type If a `type`, only accept final values of this type. + If any other callable, try to call the function on each result. + @param get_all If `False`, return the first matching result, otherwise all matching ones. + @param casesense If `False`, consider string dictionary keys as case insensitive. + + The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API + + @param _is_user_input Whether the keys are generated from user input. + If `True` strings get converted to `int`/`slice` if needed. + @param _traverse_string Whether to traverse into objects as strings. + If `True`, any non-compatible object will first be + converted into a string and then traversed into. 
+ + + @returns The result of the object traversal. + If successful, `get_all=True`, and the path branches at least once, + then a list of results is returned instead. + A list is always returned if the last path branches and no `default` is given. + """ + + # parameter defaults + default = kwargs.get('default', NO_DEFAULT) + expected_type = kwargs.get('expected_type') + get_all = kwargs.get('get_all', True) + casesense = kwargs.get('casesense', True) + _is_user_input = kwargs.get('_is_user_input', False) + _traverse_string = kwargs.get('_traverse_string', False) + + # instant compat + str = compat_str + + is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes)) + # stand-in until compat_re_Match is added + compat_re_Match = type(re.match('a', 'a')) + # stand-in until casefold.py is added + try: + ''.casefold() + compat_casefold = lambda s: s.casefold() + except AttributeError: + compat_casefold = lambda s: s.lower() + casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k + + if isinstance(expected_type, type): + type_test = lambda val: val if isinstance(val, expected_type) else None + else: + type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) + + def from_iterable(iterables): + # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F + for it in iterables: + for item in it: + yield item + + def apply_key(key, obj): + if obj is None: + return + + elif key is None: + yield obj + + elif isinstance(key, (list, tuple)): + for branch in key: + _, result = apply_path(obj, branch) + for item in result: + yield item + + elif key is Ellipsis: + result = [] + if isinstance(obj, compat_collections_abc.Mapping): + result = obj.values() + elif is_sequence(obj): + result = obj + elif isinstance(obj, compat_re_Match): + result = obj.groups() + elif _traverse_string: + result = str(obj) + for item in result: + yield item + + elif callable(key): + if is_sequence(obj): + iter_obj = enumerate(obj) + 
elif isinstance(obj, compat_collections_abc.Mapping): + iter_obj = obj.items() + elif isinstance(obj, compat_re_Match): + iter_obj = enumerate(itertools.chain([obj.group()], obj.groups())) + elif _traverse_string: + iter_obj = enumerate(str(obj)) + else: + return + for item in (v for k, v in iter_obj if try_call(key, args=(k, v))): + yield item + + elif isinstance(key, dict): + iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) + yield dict((k, v if v is not None else default) for k, v in iter_obj + if v is not None or default is not NO_DEFAULT) + + elif isinstance(obj, compat_collections_abc.Mapping): + yield (obj.get(key) if casesense or (key in obj) + else next((v for k, v in obj.items() if casefold(k) == key), None)) + + elif isinstance(obj, compat_re_Match): + if isinstance(key, int) or casesense: + try: + yield obj.group(key) + return + except IndexError: + pass + if not isinstance(key, str): + return + + yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) + + else: + if _is_user_input: + key = (int_or_none(key) if ':' not in key + else slice(*map(int_or_none, key.split(':')))) + + if not isinstance(key, (int, slice)): + return + + if not is_sequence(obj): + if not _traverse_string: + return + obj = str(obj) + + try: + yield obj[key] + except IndexError: + pass + + def apply_path(start_obj, path): + objs = (start_obj,) + has_branched = False + + for key in variadic(path): + if _is_user_input and key == ':': + key = Ellipsis + + if not casesense and isinstance(key, str): + key = compat_casefold(key) + + if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key): + has_branched = True + + key_func = functools.partial(apply_key, key) + objs = from_iterable(map(key_func, objs)) + + return has_branched, objs + + def _traverse_obj(obj, path, use_list=True): + has_branched, results = apply_path(obj, path) + results = LazyList(x for x in map(type_test, results) if x is not None) + + if get_all and has_branched: + 
return results.exhaust() if results or use_list else None + + return results[0] if results else None + + for index, path in enumerate(paths, 1): + use_list = default is NO_DEFAULT and index == len(paths) + result = _traverse_obj(obj, path, use_list) + if result is not None: + return result + + return None if default is NO_DEFAULT else default + + +def get_first(obj, keys, **kwargs): + return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs) + + +def join_nonempty(*values, **kwargs): + + # parameter defaults + delim = kwargs.get('delim', '-') + from_dict = kwargs.get('from_dict') + + if from_dict is not None: + values = (traverse_obj(from_dict, variadic(v)) for v in values) + return delim.join(map(compat_str, filter(None, values))) |