Skip to content

Commit

Permalink
sanitize kwargs from plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Jul 23, 2024
1 parent bb70748 commit 74cc917
Showing 1 changed file with 48 additions and 23 deletions.
71 changes: 48 additions & 23 deletions ovos_plugin_manager/templates/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# QuestionSolver Improvements and other solver classes are OVOS originals licensed under Apache 2.0

import abc
import inspect
from functools import wraps, lru_cache
from typing import Optional, List, Iterable, Tuple, Dict, Union

Expand All @@ -23,13 +24,24 @@ def func_decorator(func):

@wraps(func)
def func_wrapper(*args, **kwargs):

# Inspect the function signature to ensure it has both 'lang' and 'context' parameters
signature = inspect.signature(func)
params = signature.parameters

if "context" in kwargs:
# NOTE: deprecate this at same time we
# standardize plugin namespaces to opm.XXX
log_deprecation("'context' kwarg has been deprecated, "
"please pass 'lang' as it's own kwarg instead", "0.1.0")
if "lang" in kwargs["context"] and "lang" not in kwargs:
kwargs["lang"] = kwargs["context"]["lang"]

# ensure valid kwargs
if "lang" not in params and "lang" in kwargs:
kwargs.pop("lang")
if "context" not in params and "context" in kwargs:
kwargs.pop("context")
return func(*args, **kwargs)

return func_wrapper
Expand Down Expand Up @@ -275,13 +287,14 @@ def stream_utterances(self, query: str,
:param lang: Optional language code.
:return: An iterable of utterances.
"""
ans = self.get_spoken_answer(query, context=context, lang=lang)
ans = _call_with_sanitized_kwargs(self.get_spoken_answer, query,
lang=lang, context=context)
for utt in self.sentence_split(ans):
yield utt

@_deprecate_context2lang()
def get_data(self, query: str, context: Optional[Dict] = None,
lang: Optional[str] = None) -> dict:
lang: Optional[str] = None) -> Optional[dict]:
"""
Retrieve data for the given query.
Expand All @@ -290,11 +303,12 @@ def get_data(self, query: str, context: Optional[Dict] = None,
:param lang: Optional language code.
:return: A dictionary containing the answer.
"""
return {"answer": self.get_spoken_answer(query, context)}
return {"answer": _call_with_sanitized_kwargs(self.get_spoken_answer, query,
lang=lang, context=context)(query, **kwargs)}

@_deprecate_context2lang()
def get_image(self, query: str, context: Optional[Dict] = None,
lang: Optional[str] = None) -> str:
lang: Optional[str] = None) -> Optional[str]:
"""
Get the path or URL to an image associated with the query.
Expand All @@ -317,8 +331,10 @@ def get_expanded_answer(self, query: str, context: Optional[Dict] = None,
:return: A list of dictionaries with each step containing a title, summary, and optional image.
"""
return [{"title": query,
"summary": self.get_spoken_answer(query, context),
"img": self.get_image(query, context)}]
"summary": _call_with_sanitized_kwargs(self.get_spoken_answer, query,
lang=lang, context=context),
"img": _call_with_sanitized_kwargs(self.get_image, query,
lang=lang, context=context)}]

# user facing methods
@_deprecate_context2lang()
Expand All @@ -343,8 +359,8 @@ def search(self, query: str, context: Optional[Dict] = None, lang: Optional[str]
else:
# search data
try:
data = self.get_data(query,
context=context, lang=lang)
data = _call_with_sanitized_kwargs(self.get_data, query,
lang=lang, context=context)
except:
return {}

Expand All @@ -370,8 +386,8 @@ def visual_answer(self, query: str, context: Optional[Dict] = None, lang: Option
:param lang: Optional language code.
:return: The path or URL to the image.
"""
return self.get_image(query,
context=context, lang=lang)
return _call_with_sanitized_kwargs(self.get_image, query,
lang=lang, context=context)

@_deprecate_context2lang()
@auto_translate_inputs(translate_keys=["query"])
Expand All @@ -394,13 +410,13 @@ def spoken_answer(self, query: str, context: Optional[Dict] = None, lang: Option
# read from cache
summary = self.spoken_cache[query]
else:
summary = self.get_spoken_answer(query,
context=context, lang=lang)

summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query,
lang=lang, context=context)
# save to cache
if self.enable_cache:
self.spoken_cache[query] = summary
self.spoken_cache.store()

return summary

@_deprecate_context2lang()
Expand All @@ -420,19 +436,17 @@ def long_answer(self, query: str, context: Optional[Dict] = None,
:param lang: Optional language code.
:return: A list of steps to elaborate on the answer, with each step containing a title, summary, and optional image.
"""
steps = self.get_expanded_answer(query,
context=context, lang=lang)

steps = _call_with_sanitized_kwargs(self.get_expanded_answer, query,
lang=lang, context=context)
# use spoken_answer as last resort
if not steps:
summary = self.get_spoken_answer(query,
context=context, lang=lang)
summary = _call_with_sanitized_kwargs(self.get_spoken_answer, query,
lang=lang, context=context)
if summary:
img = self.get_image(query,
context=context, lang=lang)
img = _call_with_sanitized_kwargs(self.get_image, query,
lang=lang, context=context)
steps = [{"title": query, "summary": step0, "img": img}
for step0 in self.sentence_split(summary, -1)]

return steps


Expand Down Expand Up @@ -477,8 +491,19 @@ def tldr(self, document: str,
:return: A summary of the provided document.
"""
# summarize
return self.get_tldr(document,
context=context, lang=lang)
return _call_with_sanitized_kwargs(self.get_tldr, document,
lang=lang, context=context)


def _call_with_sanitized_kwargs(func, *args, context: Optional[Dict] = None, lang: Optional[str] = None):
# Inspect the function signature to ensure it has both 'lang' and 'context' parameters
params = inspect.signature(func).parameters
kwargs = {}
if "context" in params:
kwargs["context"] = context
if "lang" in params:
kwargs["lang"] = lang
return func(*args, **kwargs)


class EvidenceSolver(AbstractSolver):
Expand Down

0 comments on commit 74cc917

Please sign in to comment.