summaryrefslogtreecommitdiffhomepage
path: root/libs
diff options
context:
space:
mode:
Diffstat (limited to 'libs')
-rw-r--r--libs/subliminal_patch/core.py56
1 files changed, 52 insertions, 4 deletions
diff --git a/libs/subliminal_patch/core.py b/libs/subliminal_patch/core.py
index c31d5ecd0..b649e4288 100644
--- a/libs/subliminal_patch/core.py
+++ b/libs/subliminal_patch/core.py
@@ -153,9 +153,52 @@ class _Blacklist(list):
return not blacklisted
+class _LanguageEquals(list):
+ """ An optional config field for the pool. It will treat a couple of languages as equal for
+ list-subtitles operations. It's optional; its methods won't do anything if an empy list
+ is set.
+
+ Example usage: [(language_instance, language_instance), ...]"""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for item in self:
+ if len(item) != 2 or not any(isinstance(i, Language) for i in item):
+ raise ValueError(f"Not a valid equal tuple: {item}")
+
+ def check_set(self, items: set):
+ """ Check a set of languages. For example, if the set is {Language('es')} and one of the
+ equals of the instance is (Language('es'), Language('es', 'MX')), the set will now have
+ to {Language('es'), Language('es', 'MX')}.
+
+ It will return a copy of the original set to avoid messing up outside its scope.
+
+ Note that hearing_impaired and forced language attributes are not yet tested.
+ """
+ to_add = []
+ for equals in self:
+ from_, to_ = equals
+ if from_ in items:
+ logger.debug("Adding %s to %s", to_, items)
+ to_add.append(to_)
+
+ new_items = items.copy()
+ new_items.update(to_add)
+ logger.debug("New set: %s", new_items)
+ return new_items
+
+ def update_subtitle(self, subtitle):
+ for equals in self:
+ from_, to_ = equals
+ if from_ == subtitle.language:
+ logger.debug("Updating language for %s (to %s)", subtitle, to_)
+ subtitle.language = to_
+ break
+
+
class SZProviderPool(ProviderPool):
def __init__(self, providers=None, provider_configs=None, blacklist=None, ban_list=None, throttle_callback=None,
- pre_download_hook=None, post_download_hook=None, language_hook=None):
+ pre_download_hook=None, post_download_hook=None, language_hook=None, language_equals=None):
#: Name of providers to use
self.providers = set(providers or [])
@@ -170,6 +213,8 @@ class SZProviderPool(ProviderPool):
#: Should be a dict of 2 lists of strings
self.ban_list = _Banlist(**(ban_list or {'must_contain': [], 'must_not_contain': []}))
+ self.lang_equals = _LanguageEquals(language_equals or [])
+
self.throttle_callback = throttle_callback
self.pre_download_hook = pre_download_hook
@@ -185,7 +230,7 @@ class SZProviderPool(ProviderPool):
self.provider_configs = _ProviderConfigs(self)
self.provider_configs.update(provider_configs or {})
- def update(self, providers, provider_configs, blacklist, ban_list):
+ def update(self, providers, provider_configs, blacklist, ban_list, language_equals=None):
# Check if the pool was initialized enough hours ago
self._check_lifetime()
@@ -222,6 +267,7 @@ class SZProviderPool(ProviderPool):
self.blacklist = _Blacklist(blacklist or [])
self.ban_list = _Banlist(**ban_list or {'must_contain': [], 'must_not_contain': []})
+ self.lang_equals = _LanguageEquals(language_equals or [])
return updated
@@ -299,7 +345,7 @@ class SZProviderPool(ProviderPool):
return []
# check supported languages
- provider_languages = provider_registry[provider].languages & use_languages
+ provider_languages = self.lang_equals.check_set(set(provider_registry[provider].languages)) & use_languages
if not provider_languages:
logger.info('Skipping provider %r: no language to search for', provider)
return []
@@ -312,6 +358,8 @@ class SZProviderPool(ProviderPool):
seen = []
out = []
for s in results:
+ self.lang_equals.update_subtitle(s)
+
if not self.blacklist.is_valid(provider, s):
continue
@@ -569,7 +617,7 @@ class SZProviderPool(ProviderPool):
continue
# add the languages for this provider
- languages.append({'provider': name, 'languages': provider_languages})
+ languages.append({'provider': name, 'languages': self.lang_equals.check_set(set(provider_languages))})
return languages