From b87422d7b810d44c33379bd84b64229979acbaa3 Mon Sep 17 00:00:00 2001 From: miro Date: Tue, 23 Jul 2024 18:52:53 +0100 Subject: [PATCH] refactor/decorators make code more portable and easier to understand and maintain by using decorators for the autotranslation steps move class properties to init method add configurable language detector plugin and new util methods for translation and detection --- ovos_plugin_manager/templates/solvers.py | 369 ++++++++++++++--------- 1 file changed, 233 insertions(+), 136 deletions(-) diff --git a/ovos_plugin_manager/templates/solvers.py b/ovos_plugin_manager/templates/solvers.py index 92522b9d..7269b8d6 100644 --- a/ovos_plugin_manager/templates/solvers.py +++ b/ovos_plugin_manager/templates/solvers.py @@ -2,69 +2,155 @@ # QuestionSolver Improvements and other solver classes are OVOS originals licensed under Apache 2.0 import abc -from typing import Optional, List, Iterable, Tuple +from functools import wraps +from typing import Optional, List, Iterable, Tuple, Dict, Union from json_database import JsonStorageXDG -from ovos_plugin_manager.language import OVOSLangTranslationFactory -from ovos_utils.log import LOG +from ovos_utils.log import LOG, log_deprecation from ovos_utils.xdg_utils import xdg_cache_home from quebra_frases import sentence_tokenize +from ovos_plugin_manager.language import OVOSLangTranslationFactory, OVOSLangDetectionFactory +from ovos_plugin_manager.templates.language import LanguageTranslator, LanguageDetector + + +def _deprecate_context2lang(): + def func_decorator(func): + + @wraps(func) + def func_wrapper(*args, **kwargs): + if "context" in kwargs: + # NOTE: deprecate this at same time we + # standardize plugin namespaces to opm.XXX + log_deprecation("'context' kwarg has been deprecated, " + "please pass 'lang' as it's own kwarg instead", "0.1.0") + if "lang" in kwargs["context"] and "lang" not in kwargs: + kwargs["lang"] = kwargs["context"]["lang"] + return func(*args, **kwargs) + + return func_wrapper + + return func_decorator + + +def _auto_translate(translate_keys: List[str]): + """ Decorator to ensure all kwargs in 'translate_keys' are translated to self.default_lang + NOTE: not meant to be used outside solver plugins""" + + def func_decorator(func): + + @wraps(func) + def func_wrapper(*args, **kwargs): + solver: AbstractSolver = args[0] + # check if translation is enabled + if not solver.enable_tx: + return func(*args, **kwargs) + + # detect language if needed + lang = kwargs.get("lang") + if lang is None: + for k in translate_keys: + v = kwargs.get(k) + if isinstance(v, str): + lang = solver.detect_language(v) + break + + # check if translation can be skipped + if any([lang is None, + lang == solver.default_lang, + lang in solver.supported_langs]): + return func(*args, **kwargs) + + # translate input keys + for k in translate_keys: + v = kwargs.get(k) + if isinstance(v, str): + kwargs[k] = solver.translate(v, + source_lang=lang, + target_lang=solver.default_lang) + elif isinstance(v, list): + kwargs[k] = solver.translate_list(v, + source_lang=lang, + target_lang=solver.default_lang) + elif isinstance(v, dict): + kwargs[k] = solver.translate_dict(v, + source_lang=lang, + target_lang=solver.default_lang) + + output = func(*args, **kwargs) + + # reverse translate + if isinstance(output, str): + return solver.translate(output, + source_lang=solver.default_lang, + target_lang=lang) + elif isinstance(output, list): + return solver.translate_list(output, + source_lang=solver.default_lang, + target_lang=lang) + elif isinstance(output, dict): + return solver.translate_dict(output, + source_lang=solver.default_lang, + target_lang=lang) + return output + + return func_wrapper + + return func_decorator + class AbstractSolver: - # these are defined by the plugin developer - priority = 50 - enable_tx = False - enable_cache = False - - def __init__(self, config=None, translator=None, *args, **kwargs): - if args or kwargs: - LOG.warning("solver plugins init signature changed, please update to accept config=None, translator=None. " - "an exception will be raised in next stable release") - for arg in args: - if isinstance(arg, str): - kwargs["name"] = arg - if isinstance(arg, int): - kwargs["priority"] = arg - if "priority" in kwargs: - self.priority = kwargs["priority"] - if "enable_tx" in kwargs: - self.enable_tx = kwargs["enable_tx"] - if "enable_cache" in kwargs: - self.enable_cache = kwargs["enable_cache"] + + def __init__(self, config=None, + translator: Optional[LanguageTranslator] = None, + detector: Optional[LanguageDetector] = None, + priority=50, + enable_tx=False, + enable_cache=False, + *args, **kwargs): + self.priority = priority + self.enable_tx = enable_tx + self.enable_cache = enable_cache self.config = config or {} self.supported_langs = self.config.get("supported_langs") or [] self.default_lang = self.config.get("lang", "en") if self.default_lang not in self.supported_langs: self.supported_langs.insert(0, self.default_lang) self.translator = translator or OVOSLangTranslationFactory.create() + self.detector = detector or OVOSLangDetectionFactory.create() + LOG.debug(f"{self.__class__.__name__} default language: {self.default_lang}") @staticmethod - def sentence_split(text: str, max_sentences: int=25) -> List[str]: + def sentence_split(text: str, max_sentences: int = 25) -> List[str]: return sentence_tokenize(text)[:max_sentences] - def _get_user_lang(self, context: Optional[dict] = None, - lang: Optional[str] = None) -> str: - context = context or {} - lang = lang or context.get("lang") or self.default_lang - lang = lang.split("-")[0] - return lang - - def _tx_query(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None): - if not self.enable_tx: - return query, context, lang - context = context or {} - lang = user_lang = self._get_user_lang(context, lang) - - # translate input to default lang - if user_lang not in self.supported_langs: - lang = self.default_lang - query = self.translator.translate(query, lang, user_lang) - - context["lang"] = lang - - return query, context, lang + def detect_language(self, text: str) -> str: + return self.detector.detect(text) + + def translate(self, text: str, + target_lang: Optional[str] = None, + source_lang: Optional[str] = None) -> str: + source_lang = source_lang or self.detect_language(text) + target_lang = target_lang or self.default_lang + if source_lang.split("-")[0] == target_lang.split("-")[0]: + return text # skip translation + return self.translator.translate(text, + target=target_lang, + source=source_lang) + + def translate_list(self, data: List[str], + target_lang: Optional[str] = None, + source_lang: Optional[str] = None) -> str: + return self.translator.translate_list(data, + lang_tgt=target_lang, + lang_src=source_lang) + + def translate_dict(self, data: Dict[str, str], + target_lang: Optional[str] = None, + source_lang: Optional[str] = None) -> str: + return self.translator.translate_dict(data, + lang_tgt=target_lang, + lang_src=source_lang) def shutdown(self): """ module specific shutdown method """ @@ -75,58 +161,72 @@ class QuestionSolver(AbstractSolver): """free form unscontrained spoken question solver handling automatic translation back and forth as needed""" - def __init__(self, config=None, translator=None, *args, **kwargs): - super().__init__(config, translator, *args, **kwargs) + def __init__(self, config=None, + translator: Optional[LanguageTranslator] = None, + detector: Optional[LanguageDetector] = None, + priority=50, + enable_tx=False, + enable_cache=False, + *args, **kwargs): + super().__init__(config, translator, detector, + priority, enable_tx, enable_cache, + *args, **kwargs) name = kwargs.get("name") or self.__class__.__name__ if self.enable_cache: # cache contains raw data self.cache = JsonStorageXDG(name + "_data", xdg_folder=xdg_cache_home(), - subfolder="neon_solvers") + subfolder="ovos_solvers") # spoken cache contains dialogs self.spoken_cache = JsonStorageXDG(name, xdg_folder=xdg_cache_home(), - subfolder="neon_solvers") + subfolder="ovos_solvers") else: self.cache = self.spoken_cache = {} # plugin methods to override @abc.abstractmethod def get_spoken_answer(self, query: str, - context: Optional[dict] = None) -> str: + context: Optional[dict] = None, + lang: Optional[str] = None) -> str: """ query assured to be in self.default_lang return a single sentence text response """ raise NotImplementedError + @_deprecate_context2lang() def stream_utterances(self, query: str, - context: Optional[dict] = None) -> Iterable[str]: + context: Optional[dict] = None, + lang: Optional[str] = None) -> Iterable[str]: """streaming api, yields utterances as they become available each utterance can be sent to TTS before we have a full answer this is particularly helpful with LLMs""" - ans = self.get_spoken_answer(query, context) + ans = self.get_spoken_answer(query, context=context, lang=lang) for utt in self.sentence_split(ans): yield utt - def get_data(self, query: str, - context: Optional[dict] = None) -> dict: + @_deprecate_context2lang() + def get_data(self, query: str, context: Optional[dict] = None, + lang: Optional[str] = None) -> dict: """ query assured to be in self.default_lang return a dict response """ return {"answer": self.get_spoken_answer(query, context)} - def get_image(self, query: str, - context: Optional[dict] = None) -> str: + @_deprecate_context2lang() + def get_image(self, query: str, context: Optional[dict] = None, + lang: Optional[str] = None) -> str: """ query assured to be in self.default_lang return path/url to a single image to acompany spoken_answer """ return None - def get_expanded_answer(self, query: str, - context: Optional[dict] = None) -> List[dict]: + @_deprecate_context2lang() + def get_expanded_answer(self, query: str, context: Optional[dict] = None, + lang: Optional[str] = None) -> List[dict]: """ query assured to be in self.default_lang return a list of ordered steps to expand the answer, eg, "tell me more" @@ -142,21 +242,21 @@ def get_expanded_answer(self, query: str, "img": self.get_image(query, context)}] # user facing methods - def search(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> dict: + @_deprecate_context2lang() + @_auto_translate(translate_keys=["query"]) + def search(self, query: str, context: Optional[dict] = None, lang: Optional[str] = None) -> dict: """ cache and auto translate query if needed returns translated response from self.get_data """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) # read from cache if self.enable_cache and query in self.cache: data = self.cache[query] else: # search data try: - data = self.get_data(query, context) + data = self.get_data(query, + context=context, lang=lang) except: return {} @@ -164,51 +264,43 @@ def search(self, query: str, if self.enable_cache: self.cache[query] = data self.cache.store() - - # translate english output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate_dict(data, user_lang, lang) return data - def visual_answer(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> str: + @_deprecate_context2lang() + @_auto_translate(translate_keys=["query"]) + def visual_answer(self, query: str, context: Optional[dict] = None, lang: Optional[str] = None) -> str: """ cache and auto translate query if needed returns image that answers query """ - query, context, lang = self._tx_query(query, context, lang) - return self.get_image(query, context) + return self.get_image(query, + context=context, lang=lang) - def spoken_answer(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> str: + @_deprecate_context2lang() + @_auto_translate(translate_keys=["query"]) + def spoken_answer(self, query: str, context: Optional[dict] = None, lang: Optional[str] = None) -> str: """ cache and auto translate query if needed returns chunked and translated response from self.get_spoken_answer """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) - # get answer if self.enable_cache and query in self.spoken_cache: # read from cache summary = self.spoken_cache[query] else: - summary = self.get_spoken_answer(query, context) + summary = self.get_spoken_answer(query, + context=context, lang=lang) # save to cache if self.enable_cache: self.spoken_cache[query] = summary self.spoken_cache.store() - # summarize - if summary: - # translate english output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate(summary, user_lang, lang) - else: - return summary + return summary - def long_answer(self, query: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> List[dict]: + @_deprecate_context2lang() + @_auto_translate(translate_keys=["query"]) + def long_answer(self, query: str, context: Optional[dict] = None, + lang: Optional[str] = None) -> List[dict]: """ return a list of ordered steps to expand the answer, eg, "tell me more" step0 is always self.spoken_answer and self.get_image @@ -219,21 +311,19 @@ def long_answer(self, query: str, } :return: """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) - steps = self.get_expanded_answer(query, context) + steps = self.get_expanded_answer(query, + context=context, lang=lang) # use spoken_answer as last resort if not steps: - summary = self.get_spoken_answer(query, context) + summary = self.get_spoken_answer(query, + context=context, lang=lang) if summary: - img = self.get_image(query, context) + img = self.get_image(query, + context=context, lang=lang) steps = [{"title": query, "summary": step0, "img": img} for step0 in self.sentence_split(summary, -1)] - # translate english output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate_list(steps, user_lang, lang) return steps @@ -245,7 +335,8 @@ class TldrSolver(AbstractSolver): @abc.abstractmethod def get_tldr(self, document: str, - context: Optional[dict] = None) -> str: + context: Optional[dict] = None, + lang: Optional[str] = None) -> str: """ document assured to be in self.default_lang returns summary of provided document @@ -253,22 +344,18 @@ def get_tldr(self, document: str, raise NotImplementedError # user facing methods + + @_deprecate_context2lang() + @_auto_translate(translate_keys=["document"]) def tldr(self, document: str, context: Optional[dict] = None, lang: Optional[str] = None) -> str: """ cache and auto translate query if needed returns summary of provided document """ - user_lang = self._get_user_lang(context, lang) - document, context, lang = self._tx_query(document, context, lang) - # summarize - tldr = self.get_tldr(document, context) - - # translate output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate(tldr, user_lang, lang) - return tldr + return self.get_tldr(document, + context=context, lang=lang) class EvidenceSolver(AbstractSolver): @@ -279,7 +366,8 @@ class EvidenceSolver(AbstractSolver): @abc.abstractmethod def get_best_passage(self, evidence: str, question: str, - context: Optional[dict] = None) -> str: + context: Optional[dict] = None, + lang: Optional[str] = None) -> str: """ evidence and question assured to be in self.default_lang returns summary of provided document @@ -287,23 +375,18 @@ def get_best_passage(self, evidence: str, question: str, raise NotImplementedError # user facing methods + @_deprecate_context2lang() + @_auto_translate(translate_keys=["evidence", "question"]) def extract_answer(self, evidence: str, question: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> str: + context: Optional[dict] = None, + lang: Optional[str] = None) -> str: """ cache and auto translate evidence and question if needed returns passage from evidence that answers question """ - user_lang = self._get_user_lang(context, lang) - evidence, context, lang = self._tx_query(evidence, context, lang) - question, context, lang = self._tx_query(question, context, lang) - # extract answer from doc - ans = self.get_best_passage(evidence, question, context) - - # translate output to user lang - if self.enable_tx and user_lang not in self.supported_langs: - return self.translator.translate(ans, user_lang, lang) - return ans + return self.get_best_passage(evidence, question, + context=context, lang=lang) class MultipleChoiceSolver(AbstractSolver): @@ -315,38 +398,49 @@ class MultipleChoiceSolver(AbstractSolver): # TODO - make abstract in the future, # just giving some time buffer to update existing # plugins in the wild missing this method - #@abc.abstractmethod + # @abc.abstractmethod def rerank(self, query: str, options: List[str], - context: Optional[dict] = None) -> List[Tuple[float, str]]: + context: Optional[dict] = None, + lang: Optional[str] = None) -> List[Tuple[float, str]]: """ + query and options assured to be in self.default_lang + rank options list, returning a list of tuples (score, text) """ raise NotImplementedError + @_deprecate_context2lang() + @_auto_translate(translate_keys=["query", "options"]) def select_answer(self, query: str, options: List[str], - context: Optional[dict] = None) -> str: + context: Optional[dict] = None, + lang: Optional[str] = None, + return_index=False) -> Union[str, int]: """ query and options assured to be in self.default_lang return best answer from options list """ - return self.rerank(query, options, context)[0][1] + best = self.rerank(query, options, + context=context, lang=lang)[0][1] + if return_index: + return options.index(best) + return best # user facing methods + + @_deprecate_context2lang() def solve(self, query: str, options: List[str], - context: Optional[dict] = None, lang: Optional[str] = None) -> str: + context: Optional[dict] = None, + lang: Optional[str] = None) -> str: """ cache and auto translate query and options if needed returns best answer from provided options """ - user_lang = self._get_user_lang(context, lang) - query, context, lang = self._tx_query(query, context, lang) - opts = [self.translator.translate(opt, lang, user_lang) - for opt in options] - # select best answer - ans = self.select_answer(query, opts, context) - - idx = opts.index(ans) + # NOTE: use index so we return exactly the source text + # there may have been an auto-translation step in self.select_answer + idx = self.select_answer(query, options, + context=context, lang=lang, + return_index=True) return options[idx] @@ -358,21 +452,24 @@ class EntailmentSolver(AbstractSolver): @abc.abstractmethod def check_entailment(self, premise: str, hypothesis: str, - context: Optional[dict] = None) -> bool: + context: Optional[dict] = None, + lang: Optional[str] = None) -> bool: """ - premise and hyopithesis assured to be in self.default_lang + premise and hypothesis assured to be in self.default_lang return Bool, True if premise entails the hypothesis False otherwise """ raise NotImplementedError # user facing methods + @_deprecate_context2lang() + @_auto_translate(translate_keys=["premise", "hypothesis"]) def entails(self, premise: str, hypothesis: str, - context: Optional[dict] = None, lang: Optional[str] = None) -> bool: + context: Optional[dict] = None, + lang: Optional[str] = None) -> bool: """ cache and auto translate premise and hypothesis if needed return Bool, True if premise entails the hypothesis False otherwise """ - premise, context, lang = self._tx_query(premise, context, lang) - hypothesis, context, lang = self._tx_query(hypothesis, context, lang) # check for entailment - return self.check_entailment(premise, hypothesis) + return self.check_entailment(premise, hypothesis, + context=context, lang=lang)