Improve provider configs updates

This commit is contained in:
Vitiko 2022-05-13 01:10:10 -04:00
parent 1dff555fc8
commit 23a5ab9b0e
2 changed files with 54 additions and 28 deletions

View File

@ -269,5 +269,4 @@ def _get_language_obj(profile_id):
def _set_forced_providers(also_forced, pool):
if also_forced:
pool.provider_configs['podnapisi']['also_foreign'] = True
pool.provider_configs['opensubtitles']['also_foreign'] = True
pool.provider_configs.update({'podnapisi': {'also_foreign': True}, 'opensubtitles': {'also_foreign': True}})

View File

@ -70,15 +70,53 @@ def remove_crap_from_fn(fn):
return REMOVE_CRAP_FROM_FILENAME.sub(repl, fn)
class _ProviderConfigs(dict):
def __init__(self, pool, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pool = pool
def update(self, items):
updated = set()
# Restart providers with new configs
for key, val in items.items():
# Don't restart providers that are not enabled
if key not in self._pool.providers:
continue
# key: provider's name; val: config dict
registered_val = self.get(key)
if registered_val is None or registered_val == val:
continue
updated.add(key)
# The new dict might be a partial dict
registered_val.update(val)
logger.debug("Config changed. Restarting provider: %s", key)
try:
provider = provider_registry[key](**registered_val) # type: ignore
provider.initialize()
except Exception as error:
self._pool.throttle_callback(key, error)
else:
self._pool.initialized_providers[key] = provider
if updated:
logger.debug("Providers with config updates: %s", updated)
else:
logger.debug("No provider config updates")
return super().update(items)
class SZProviderPool(ProviderPool):
def __init__(self, providers=None, provider_configs=None, blacklist=None, ban_list=None, throttle_callback=None,
pre_download_hook=None, post_download_hook=None, language_hook=None):
#: Name of providers to use
self.providers = set(providers or [])
#: Provider configuration
self.provider_configs = provider_configs or {}
#: Initialized providers
self.initialized_providers = {}
@ -101,6 +139,10 @@ class SZProviderPool(ProviderPool):
if not self.throttle_callback:
self.throttle_callback = lambda x, y: x
#: Provider configuration
self.provider_configs = _ProviderConfigs(self)
self.provider_configs.update(provider_configs or {})
def update(self, providers, provider_configs, blacklist, ban_list):
# Check if the pool was initialized enough hours ago
self._check_lifetime()
@ -130,29 +172,8 @@ class SZProviderPool(ProviderPool):
self.providers.difference_update(removed_providers)
self.providers.update(list(providers))
# Restart providers with new configs
for key, val in provider_configs.items():
# Don't restart providers that are not enabled
if key not in self.providers:
continue
# key: provider's name; val: config dict
old_val = self.provider_configs.get(key)
if old_val == val:
continue
logger.debug("Restarting provider: %s", key)
try:
provider = provider_registry[key](**val)
provider.initialize()
except Exception as error:
self.throttle_callback(key, error)
else:
self.initialized_providers[key] = provider
updated = True
self.provider_configs = provider_configs
# self.provider_configs = provider_configs
self.provider_configs.update(provider_configs)
self.blacklist = blacklist or []
self.ban_list = ban_list or {'must_contain': [], 'must_not_contain': []}
@ -540,6 +561,12 @@ class SZProviderPool(ProviderPool):
return video_types
def __repr__(self):
return (
f"{self.__class__.__name__} [{len(self.providers)} providers ({len(self.initialized_providers)} "
f"initialized; {len(self.discarded_providers)} discarded)]"
)
class SZAsyncProviderPool(SZProviderPool):
"""Subclass of :class:`ProviderPool` with asynchronous support for :meth:`~ProviderPool.list_subtitles`.