diff --git a/ovos_plugin_manager/templates/solvers.py b/ovos_plugin_manager/templates/solvers.py index 7f8f4498..b2bd725f 100644 --- a/ovos_plugin_manager/templates/solvers.py +++ b/ovos_plugin_manager/templates/solvers.py @@ -2,6 +2,7 @@ # QuestionSolver Improvements and other solver classes are OVOS originals licensed under Apache 2.0 import abc +import inspect from functools import wraps, lru_cache from typing import Optional, List, Iterable, Tuple, Dict, Union @@ -23,6 +24,11 @@ def func_decorator(func): @wraps(func) def func_wrapper(*args, **kwargs): + + # Inspect the function signature to ensure it has both 'lang' and 'context' parameters + signature = inspect.signature(func) + params = signature.parameters + if "context" in kwargs: # NOTE: deprecate this at same time we # standardize plugin namespaces to opm.XXX @@ -30,6 +36,12 @@ def func_wrapper(*args, **kwargs): "please pass 'lang' as it's own kwarg instead", "0.1.0") if "lang" in kwargs["context"] and "lang" not in kwargs: kwargs["lang"] = kwargs["context"]["lang"] + + # ensure valid kwargs + if "lang" not in params and "lang" in kwargs: + kwargs.pop("lang") + if "context" not in params and "context" in kwargs: + kwargs.pop("context") return func(*args, **kwargs) return func_wrapper @@ -275,13 +287,14 @@ def stream_utterances(self, query: str, :param lang: Optional language code. :return: An iterable of utterances. """ - ans = self.get_spoken_answer(query, context=context, lang=lang) + ans = _call_with_sanitized_kwargs(self.get_spoken_answer, query, + lang=lang, context=context) for utt in self.sentence_split(ans): yield utt @_deprecate_context2lang() def get_data(self, query: str, context: Optional[Dict] = None, - lang: Optional[str] = None) -> dict: + lang: Optional[str] = None) -> Optional[dict]: """ Retrieve data for the given query. @@ -290,11 +303,12 @@ def get_data(self, query: str, context: Optional[Dict] = None, :param lang: Optional language code. :return: A dictionary containing the answer. """ - return {"answer": self.get_spoken_answer(query, context)} + return {"answer": _call_with_sanitized_kwargs(self.get_spoken_answer, query, + lang=lang, context=context)(query, **kwargs)} @_deprecate_context2lang() def get_image(self, query: str, context: Optional[Dict] = None, - lang: Optional[str] = None) -> str: + lang: Optional[str] = None) -> Optional[str]: """ Get the path or URL to an image associated with the query. @@ -317,8 +331,10 @@ def get_expanded_answer(self, query: str, context: Optional[Dict] = None, :return: A list of dictionaries with each step containing a title, summary, and optional image. """ return [{"title": query, - "summary": self.get_spoken_answer(query, context), - "img": self.get_image(query, context)}] + "summary": _call_with_sanitized_kwargs(self.get_spoken_answer, query, + lang=lang, context=context), + "img": _call_with_sanitized_kwargs(self.get_image, query, + lang=lang, context=context)}] # user facing methods @_deprecate_context2lang() @@ -343,8 +359,8 @@ def search(self, query: str, context: Optional[Dict] = None, lang: Optional[str] else: # search data try: - data = self.get_data(query, - context=context, lang=lang) + data = _call_with_sanitized_kwargs(self.get_data, query, + lang=lang, context=context) except: return {} @@ -370,8 +386,8 @@ def visual_answer(self, query: str, context: Optional[Dict] = None, lang: Option :param lang: Optional language code. :return: The path or URL to the image. """ - return self.get_image(query, - context=context, lang=lang) + return _call_with_sanitized_kwargs(self.get_image, query, + lang=lang, context=context) @_deprecate_context2lang() @auto_translate_inputs(translate_keys=["query"]) @@ -394,13 +410,13 @@ def spoken_answer(self, query: str, context: Optional[Dict] = None, lang: Option # read from cache summary = self.spoken_cache[query] else: - summary = self.get_spoken_answer(query, - context=context, lang=lang) + + summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, + lang=lang, context=context) # save to cache if self.enable_cache: self.spoken_cache[query] = summary self.spoken_cache.store() - return summary @_deprecate_context2lang() @@ -420,19 +436,17 @@ def long_answer(self, query: str, context: Optional[Dict] = None, :param lang: Optional language code. :return: A list of steps to elaborate on the answer, with each step containing a title, summary, and optional image. """ - steps = self.get_expanded_answer(query, - context=context, lang=lang) - + steps = _call_with_sanitized_kwargs(self.get_expanded_answer, query, + lang=lang, context=context) # use spoken_answer as last resort if not steps: - summary = self.get_spoken_answer(query, - context=context, lang=lang) + summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query, + lang=lang, context=context) if summary: - img = self.get_image(query, - context=context, lang=lang) + img = _call_with_sanitized_kwargs(self.get_image, query, + lang=lang, context=context) steps = [{"title": query, "summary": step0, "img": img} for step0 in self.sentence_split(summary, -1)] - return steps @@ -477,8 +491,19 @@ def tldr(self, document: str, :return: A summary of the provided document. """ # summarize - return self.get_tldr(document, - context=context, lang=lang) + return _call_with_sanitized_kwargs(self.get_tldr, document, + lang=lang, context=context) + + +def _call_with_sanitized_kwargs(func, *args, context: Optional[Dict] = None, lang: Optional[str] = None): + # Inspect the function signature to ensure it has both 'lang' and 'context' parameters + params = inspect.signature(func).parameters + kwargs = {} + if "context" in params: + kwargs["context"] = context + if "lang" in params: + kwargs["lang"] = lang + return func(*args, **kwargs) class EvidenceSolver(AbstractSolver):