diff options
author | dirkf <[email protected]> | 2023-07-06 15:46:22 +0100 |
---|---|---|
committer | dirkf <[email protected]> | 2023-07-18 10:50:46 +0100 |
commit | f47fdb9564d3ca1c0fa70ed6031148ec908fdc7b (patch) | |
tree | d4c275d568e2a6d189a66ecdc967c0a16bbafcce /youtube_dl/utils.py | |
parent | b6dff4073d469cceadb099c00ccbf3bd6fc515a6 (diff) | |
download | youtube-dl-f47fdb9564d3ca1c0fa70ed6031148ec908fdc7b.tar.gz youtube-dl-f47fdb9564d3ca1c0fa70ed6031148ec908fdc7b.zip |
[utils] Add {expected_type} and Iterable support to traverse_obj()
Diffstat (limited to 'youtube_dl/utils.py')
-rw-r--r-- | youtube_dl/utils.py | 213 |
1 files changed, 142 insertions, 71 deletions
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 83f67bd95..dbdbe5f59 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -16,6 +16,7 @@ import email.header import errno import functools import gzip +import inspect import io import itertools import json @@ -3881,7 +3882,7 @@ def detect_exe_version(output, version_re=None, unrecognized='present'): return unrecognized -class LazyList(compat_collections_abc.Sequence): +class LazyList(compat_collections_abc.Iterable): """Lazy immutable list from an iterable Note that slices of a LazyList are lists and not LazyList""" @@ -4223,10 +4224,16 @@ def multipart_encode(data, boundary=None): return out, content_type -def variadic(x, allowed_types=(compat_str, bytes, dict)): - if not isinstance(allowed_types, tuple) and isinstance(allowed_types, compat_collections_abc.Iterable): +def is_iterable_like(x, allowed_types=compat_collections_abc.Iterable, blocked_types=NO_DEFAULT): + if blocked_types is NO_DEFAULT: + blocked_types = (compat_str, bytes, compat_collections_abc.Mapping) + return isinstance(x, allowed_types) and not isinstance(x, blocked_types) + + +def variadic(x, allowed_types=NO_DEFAULT): + if isinstance(allowed_types, compat_collections_abc.Iterable): allowed_types = tuple(allowed_types) - return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,) + return x if is_iterable_like(x, blocked_types=allowed_types) else (x,) def dict_get(d, key_or_keys, default=None, skip_false_values=True): @@ -5993,7 +6000,7 @@ def clean_podcast_url(url): def traverse_obj(obj, *paths, **kwargs): """ - Safely traverse nested `dict`s and `Sequence`s + Safely traverse nested `dict`s and `Iterable`s >>> obj = [{}, {"key": "value"}] >>> traverse_obj(obj, (1, "key")) @@ -6001,14 +6008,17 @@ def traverse_obj(obj, *paths, **kwargs): Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. - Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. - A value of None is treated as the absence of a value. + Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. + Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The keys in the path can be one of: - `None`: Return the current object. - - `str`/`int`: Return `obj[key]`. For `re.Match, return `obj.group(key)`. + - `set`: Requires the only item in the set to be a type or function, + like `{type}`/`{func}`. If a `type`, returns only values + of this type. If a function, returns `func(obj)`. + - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. - `tuple`/`list`: Branch out and return a list of all matching values. @@ -6016,6 +6026,9 @@ def traverse_obj(obj, *paths, **kwargs): - `function`: Branch out and return values filtered by the function. Read as: `[value for key, value in obj if function(key, value)]`. For `Sequence`s, `key` is the index of the value. + For `Iterable`s, `key` is the enumeration count of the value. + For `re.Match`es, `key` is the group number (0 = full match) + as well as additionally any group names, if given. - `dict` Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. @@ -6024,8 +6037,12 @@ def traverse_obj(obj, *paths, **kwargs): @params paths Paths which to traverse by. Keyword arguments: @param default Value to return if the paths do not match. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, depth first. Try to avoid if using nested `dict` keys. @param expected_type If a `type`, only accept final values of this type. If any other callable, try to call the function on each result. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, recursively. This does respect branching paths. @param get_all If `False`, return the first matching result, otherwise all matching ones. @param casesense If `False`, consider string dictionary keys as case insensitive. @@ -6036,12 +6053,15 @@ def traverse_obj(obj, *paths, **kwargs): @param _traverse_string Whether to traverse into objects as strings. If `True`, any non-compatible object will first be converted into a string and then traversed into. + The return value of that path will be a string instead, + not respecting any further branching. @returns The result of the object traversal. If successful, `get_all=True`, and the path branches at least once, then a list of results is returned instead. A list is always returned if the last path branches and no `default` is given. + If a path ends on a `dict` that result will always be a `dict`. """ # parameter defaults @@ -6055,7 +6075,6 @@ def traverse_obj(obj, *paths, **kwargs): # instant compat str = compat_str - is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes)) casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k if isinstance(expected_type, type): @@ -6063,128 +6082,180 @@ def traverse_obj(obj, *paths, **kwargs): else: type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) + def lookup_or_none(v, k, getter=None): + try: + return getter(v, k) if getter else v[k] + except IndexError: + return None + def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F for it in iterables: for item in it: yield item - def apply_key(key, obj): - if obj is None: - return + def apply_key(key, obj, is_last): + branching = False + + if obj is None and _traverse_string: + if key is Ellipsis or callable(key) or isinstance(key, slice): + branching = True + result = () + else: + result = None elif key is None: - yield obj + result = obj + + elif isinstance(key, set): + assert len(key) == 1, 'Set should only be used to wrap a single item' + item = next(iter(key)) + if isinstance(item, type): + result = obj if isinstance(obj, item) else None + else: + result = try_call(item, args=(obj,)) elif isinstance(key, (list, tuple)): - for branch in key: - _, result = apply_path(obj, branch) - for item in result: - yield item + branching = True + result = from_iterable( + apply_path(obj, branch, is_last)[0] for branch in key) elif key is Ellipsis: - result = [] + branching = True if isinstance(obj, compat_collections_abc.Mapping): result = obj.values() - elif is_sequence(obj): + elif is_iterable_like(obj): result = obj elif isinstance(obj, compat_re_Match): result = obj.groups() elif _traverse_string: + branching = False result = str(obj) - for item in result: - yield item + else: + result = () elif callable(key): - if is_sequence(obj): - iter_obj = enumerate(obj) - elif isinstance(obj, compat_collections_abc.Mapping): + branching = True + if isinstance(obj, compat_collections_abc.Mapping): iter_obj = obj.items() + elif is_iterable_like(obj): + iter_obj = enumerate(obj) elif isinstance(obj, compat_re_Match): - iter_obj = enumerate(itertools.chain([obj.group()], obj.groups())) + iter_obj = itertools.chain( + enumerate(itertools.chain((obj.group(),), obj.groups())), + obj.groupdict().items()) elif _traverse_string: + branching = False iter_obj = enumerate(str(obj)) else: - return - for item in (v for k, v in iter_obj if try_call(key, args=(k, v))): - yield item + iter_obj = () + + result = (v for k, v in iter_obj if try_call(key, args=(k, v))) + if not branching: # string traversal + result = ''.join(result) elif isinstance(key, dict): - iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) - yield dict((k, v if v is not None else default) for k, v in iter_obj - if v is not None or default is not NO_DEFAULT) + iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items()) + result = dict((k, v if v is not None else default) for k, v in iter_obj + if v is not None or default is not NO_DEFAULT) or None elif isinstance(obj, compat_collections_abc.Mapping): - yield (obj.get(key) if casesense or (key in obj) - else next((v for k, v in obj.items() if casefold(k) == key), None)) + result = (try_call(obj.get, args=(key,)) + if casesense or try_call(obj.__contains__, args=(key,)) + else next((v for k, v in obj.items() if casefold(k) == key), None)) elif isinstance(obj, compat_re_Match): + result = None if isinstance(key, int) or casesense: - try: - yield obj.group(key) - return - except IndexError: - pass - if not isinstance(key, str): - return + result = lookup_or_none(obj, key, getter=compat_re_Match.group) - yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) + elif isinstance(key, str): + result = next((v for k, v in obj.groupdict().items() + if casefold(k) == key), None) else: - if _is_user_input: - key = (int_or_none(key) if ':' not in key - else slice(*map(int_or_none, key.split(':')))) + result = None + if isinstance(key, (int, slice)): + if is_iterable_like(obj, compat_collections_abc.Sequence): + branching = isinstance(key, slice) + result = lookup_or_none(obj, key) + elif _traverse_string: + result = lookup_or_none(str(obj), key) + + return branching, result if branching else (result,) + + def lazy_last(iterable): + iterator = iter(iterable) + prev = next(iterator, NO_DEFAULT) + if prev is NO_DEFAULT: + return - if not isinstance(key, (int, slice)): - return + for item in iterator: + yield False, prev + prev = item - if not is_sequence(obj): - if not _traverse_string: - return - obj = str(obj) + yield True, prev - try: - yield obj[key] - except IndexError: - pass - - def apply_path(start_obj, path): + def apply_path(start_obj, path, test_type): objs = (start_obj,) has_branched = False - for key in variadic(path): - if _is_user_input and key == ':': - key = Ellipsis + key = None + for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): + if _is_user_input and isinstance(key, str): + if key == ':': + key = Ellipsis + elif ':' in key: + key = slice(*map(int_or_none, key.split(':'))) + elif int_or_none(key) is not None: + key = int(key) if not casesense and isinstance(key, str): key = compat_casefold(key) - if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key): - has_branched = True + if __debug__ and callable(key): + # Verify function signature + inspect.getcallargs(key, None, None) + + new_objs = [] + for obj in objs: + branching, results = apply_key(key, obj, last) + has_branched |= branching + new_objs.append(results) + + objs = from_iterable(new_objs) - key_func = functools.partial(apply_key, key) - objs = from_iterable(map(key_func, objs)) + if test_type and not isinstance(key, (dict, list, tuple)): + objs = map(type_test, objs) - return has_branched, objs + return objs, has_branched, isinstance(key, dict) - def _traverse_obj(obj, path, use_list=True): - has_branched, results = apply_path(obj, path) - results = LazyList(x for x in map(type_test, results) if x is not None) + def _traverse_obj(obj, path, allow_empty, test_type): + results, has_branched, is_dict = apply_path(obj, path, test_type) + results = LazyList(x for x in results if x not in (None, {})) if get_all and has_branched: - return results.exhaust() if results or use_list else None + if results: + return results.exhaust() + if allow_empty: + return [] if default is NO_DEFAULT else default + return None - return results[0] if results else None + return results[0] if results else {} if allow_empty and is_dict else None for index, path in enumerate(paths, 1): - use_list = default is NO_DEFAULT and index == len(paths) - result = _traverse_obj(obj, path, use_list) + result = _traverse_obj(obj, path, index == len(paths), True) if result is not None: return result return None if default is NO_DEFAULT else default +def T(x): + """ For use in yt-dl instead of {type} or set((type,)) """ + return set((x,)) + + def get_first(obj, keys, **kwargs): return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs) |