[utils] Improve `traverse_obj`

This commit is contained in:
pukkandan 2021-07-21 11:17:27 +05:30
parent 11f9be0912
commit 352d63fdb5
No known key found for this signature in database
GPG Key ID: 0F00D95A001F4698
2 changed files with 20 additions and 11 deletions

View File

@ -1929,10 +1929,11 @@ class YoutubeIE(YoutubeBaseInfoExtractor):
return sts return sts
def _mark_watched(self, video_id, player_responses): def _mark_watched(self, video_id, player_responses):
playback_url = url_or_none((traverse_obj( playback_url = traverse_obj(
player_responses, ('playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'), player_responses, (..., 'playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'),
expected_type=str) or [None])[0]) expected_type=url_or_none, get_all=False)
if not playback_url: if not playback_url:
self.report_warning('Unable to mark watched')
return return
parsed_playback_url = compat_urlparse.urlparse(playback_url) parsed_playback_url = compat_urlparse.urlparse(playback_url)
qs = compat_urlparse.parse_qs(parsed_playback_url.query) qs = compat_urlparse.parse_qs(parsed_playback_url.query)
@ -2606,8 +2607,7 @@ class YoutubeIE(YoutubeBaseInfoExtractor):
self._get_requested_clients(url, smuggled_data), self._get_requested_clients(url, smuggled_data),
video_id, webpage, master_ytcfg, player_url, identity_token)) video_id, webpage, master_ytcfg, player_url, identity_token))
get_first = lambda obj, keys, **kwargs: ( get_first = lambda obj, keys, **kwargs: traverse_obj(obj, (..., *variadic(keys)), **kwargs, get_all=False)
traverse_obj(obj, (..., *variadic(keys)), **kwargs) or [None])[0]
playability_statuses = traverse_obj( playability_statuses = traverse_obj(
player_responses, (..., 'playabilityStatus'), expected_type=dict, default=[]) player_responses, (..., 'playabilityStatus'), expected_type=dict, default=[])

View File

@ -6225,7 +6225,7 @@ def load_plugins(name, suffix, namespace):
def traverse_obj( def traverse_obj(
obj, *path_list, default=None, expected_type=None, obj, *path_list, default=None, expected_type=None, get_all=True,
casesense=True, is_user_input=False, traverse_string=False): casesense=True, is_user_input=False, traverse_string=False):
''' Traverse nested list/dict/tuple ''' Traverse nested list/dict/tuple
@param path_list A list of paths which are checked one by one. @param path_list A list of paths which are checked one by one.
@ -6234,7 +6234,8 @@ def traverse_obj(
all the keys given in the tuple are traversed, and all the keys given in the tuple are traversed, and
"..." traverses all the keys in the object "..." traverses all the keys in the object
@param default Default value to return @param default Default value to return
@param expected_type Only accept final value of this type @param expected_type Only accept final value of this type (Can also be any callable)
@param get_all Return all the values obtained from a path or only the first one
@param casesense Whether to consider dictionary keys as case sensitive @param casesense Whether to consider dictionary keys as case sensitive
@param is_user_input Whether the keys are generated from user input. If True, @param is_user_input Whether the keys are generated from user input. If True,
strings are converted to int/slice if necessary strings are converted to int/slice if necessary
@ -6281,6 +6282,13 @@ def traverse_obj(
return None return None
return obj return obj
if isinstance(expected_type, type):
type_test = lambda val: val if isinstance(val, expected_type) else None
elif expected_type is not None:
type_test = expected_type
else:
type_test = lambda val: val
for path in path_list: for path in path_list:
depth = 0 depth = 0
val = _traverse_obj(obj, path) val = _traverse_obj(obj, path)
@ -6288,12 +6296,13 @@ def traverse_obj(
if depth: if depth:
for _ in range(depth - 1): for _ in range(depth - 1):
val = itertools.chain.from_iterable(v for v in val if v is not None) val = itertools.chain.from_iterable(v for v in val if v is not None)
val = ([v for v in val if v is not None] if expected_type is None val = [v for v in map(type_test, val) if v is not None]
else [v for v in val if isinstance(v, expected_type)])
if val: if val:
return val if get_all else val[0]
else:
val = type_test(val)
if val is not None:
return val return val
elif expected_type is None or isinstance(val, expected_type):
return val
return default return default