diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 795c5632f..d1be485f8 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -6225,9 +6225,14 @@ def load_plugins(name, suffix, namespace): def traverse_obj( - obj, *key_list, default=None, expected_type=None, + obj, *path_list, default=None, expected_type=None, casesense=True, is_user_input=False, traverse_string=False): ''' Traverse nested list/dict/tuple + @param path_list A list of paths which are checked one by one. + Each path is a list of keys where each key is a string, + a tuple of strings or "...". When a tuple is given, + all the keys given in the tuple are traversed, and + "..." traverses all the keys in the object @param default Default value to return @param expected_type Only accept final value of this type @param casesense Whether to consider dictionary keys as case sensitive @@ -6235,23 +6240,38 @@ def traverse_obj( strings are converted to int/slice if necessary @param traverse_string Whether to traverse inside strings. If True, any non-compatible object will also be converted into a string + # TODO: Write tests ''' if not casesense: _lower = lambda k: k.lower() if isinstance(k, str) else k - key_list = ((_lower(k) for k in keys) for keys in key_list) + path_list = (map(_lower, variadic(path)) for path in path_list) - def _traverse_obj(obj, keys): - for key in list(keys): - if isinstance(obj, dict): + def _traverse_obj(obj, path, _current_depth=0): + nonlocal depth + path = tuple(variadic(path)) + for i, key in enumerate(path): + if isinstance(key, (list, tuple)): + obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key] + key = ... + if key is ...: + obj = (obj.values() if isinstance(obj, dict) + else obj if isinstance(obj, (list, tuple, LazyList)) + else str(obj) if traverse_string else []) + _current_depth += 1 + depth = max(depth, _current_depth) + return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj] + elif isinstance(obj, dict): obj = (obj.get(key) if casesense or (key in obj) else next((v for k, v in obj.items() if _lower(k) == key), None)) else: if is_user_input: key = (int_or_none(key) if ':' not in key else slice(*map(int_or_none, key.split(':')))) + if key == slice(None): + return _traverse_obj(obj, (..., *path[i + 1:])) if not isinstance(key, (int, slice)): return None - if not isinstance(obj, (list, tuple)): + if not isinstance(obj, (list, tuple, LazyList)): if not traverse_string: return None obj = str(obj) @@ -6261,10 +6281,18 @@ def traverse_obj( return None return obj - for keys in key_list: - val = _traverse_obj(obj, keys) + for path in path_list: + depth = 0 + val = _traverse_obj(obj, path) if val is not None: - if expected_type is None or isinstance(val, expected_type): + if depth: + for _ in range(depth - 1): + val = itertools.chain.from_iterable(filter(None, val)) + val = (list(filter(None, val)) if expected_type is None + else [v for v in val if isinstance(v, expected_type)]) + if val: + return val + elif expected_type is None or isinstance(val, expected_type): return val return default