diff options
Diffstat (limited to 'libs')
-rw-r--r-- | libs/subliminal_patch/core.py | 56 |
1 files changed, 52 insertions, 4 deletions
diff --git a/libs/subliminal_patch/core.py b/libs/subliminal_patch/core.py index c31d5ecd0..b649e4288 100644 --- a/libs/subliminal_patch/core.py +++ b/libs/subliminal_patch/core.py @@ -153,9 +153,52 @@ class _Blacklist(list): return not blacklisted +class _LanguageEquals(list): + """ An optional config field for the pool. It will treat a couple of languages as equal for + list-subtitles operations. It's optional; its methods won't do anything if an empy list + is set. + + Example usage: [(language_instance, language_instance), ...]""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for item in self: + if len(item) != 2 or not any(isinstance(i, Language) for i in item): + raise ValueError(f"Not a valid equal tuple: {item}") + + def check_set(self, items: set): + """ Check a set of languages. For example, if the set is {Language('es')} and one of the + equals of the instance is (Language('es'), Language('es', 'MX')), the set will now have + to {Language('es'), Language('es', 'MX')}. + + It will return a copy of the original set to avoid messing up outside its scope. + + Note that hearing_impaired and forced language attributes are not yet tested. + """ + to_add = [] + for equals in self: + from_, to_ = equals + if from_ in items: + logger.debug("Adding %s to %s", to_, items) + to_add.append(to_) + + new_items = items.copy() + new_items.update(to_add) + logger.debug("New set: %s", new_items) + return new_items + + def update_subtitle(self, subtitle): + for equals in self: + from_, to_ = equals + if from_ == subtitle.language: + logger.debug("Updating language for %s (to %s)", subtitle, to_) + subtitle.language = to_ + break + + class SZProviderPool(ProviderPool): def __init__(self, providers=None, provider_configs=None, blacklist=None, ban_list=None, throttle_callback=None, - pre_download_hook=None, post_download_hook=None, language_hook=None): + pre_download_hook=None, post_download_hook=None, language_hook=None, language_equals=None): #: Name of providers to use self.providers = set(providers or []) @@ -170,6 +213,8 @@ class SZProviderPool(ProviderPool): #: Should be a dict of 2 lists of strings self.ban_list = _Banlist(**(ban_list or {'must_contain': [], 'must_not_contain': []})) + self.lang_equals = _LanguageEquals(language_equals or []) + self.throttle_callback = throttle_callback self.pre_download_hook = pre_download_hook @@ -185,7 +230,7 @@ class SZProviderPool(ProviderPool): self.provider_configs = _ProviderConfigs(self) self.provider_configs.update(provider_configs or {}) - def update(self, providers, provider_configs, blacklist, ban_list): + def update(self, providers, provider_configs, blacklist, ban_list, language_equals=None): # Check if the pool was initialized enough hours ago self._check_lifetime() @@ -222,6 +267,7 @@ class SZProviderPool(ProviderPool): self.blacklist = _Blacklist(blacklist or []) self.ban_list = _Banlist(**ban_list or {'must_contain': [], 'must_not_contain': []}) + self.lang_equals = _LanguageEquals(language_equals or []) return updated @@ -299,7 +345,7 @@ class SZProviderPool(ProviderPool): return [] # check supported languages - provider_languages = provider_registry[provider].languages & use_languages + provider_languages = self.lang_equals.check_set(set(provider_registry[provider].languages)) & use_languages if not provider_languages: logger.info('Skipping provider %r: no language to search for', provider) return [] @@ -312,6 +358,8 @@ class SZProviderPool(ProviderPool): seen = [] out = [] for s in results: + self.lang_equals.update_subtitle(s) + if not self.blacklist.is_valid(provider, s): continue @@ -569,7 +617,7 @@ class SZProviderPool(ProviderPool): continue # add the languages for this provider - languages.append({'provider': name, 'languages': provider_languages}) + languages.append({'provider': name, 'languages': self.lang_equals.check_set(set(provider_languages))}) return languages |