diff --git a/src/unitxt/llm_as_judge.py b/src/unitxt/llm_as_judge.py index fc8fff6d1..5e3f95d36 100644 --- a/src/unitxt/llm_as_judge.py +++ b/src/unitxt/llm_as_judge.py @@ -4,6 +4,7 @@ from .api import infer from .artifact import fetch_artifact +from .dict_utils import dict_get from .error_utils import UnitxtError from .inference import ( InferenceEngine, @@ -59,7 +60,7 @@ class LLMJudge(BulkInstanceMetric): # ) evaluator_name: EvaluatorNameEnum = None check_positional_bias: bool = True - context_fields: str = ["context"] + context_fields: Union[str, List[str], Dict[str, str]] = ["context"] generate_summaries: bool = True format = "formats.chat_api" include_prompts_in_result: bool = False @@ -71,6 +72,10 @@ def prepare(self): super().prepare() if isinstance(self.context_fields, str): self.context_fields = [self.context_fields] + if isinstance(self.context_fields, List): + self.context_fields = { + context_field: context_field for context_field in self.context_fields + } # if not isinstance(self.option_selection_strategy, OptionSelectionStrategyEnum): # self.option_selection_strategy = OptionSelectionStrategyEnum[ @@ -149,8 +154,8 @@ def get_contexts(self, task_data: List[Dict[str, Any]]) -> List[Dict[str, str]]: return [ get_parsed_context( { - context_field: td[context_field] - for context_field in self.context_fields + context_field_name: dict_get(td, context_field) + for context_field_name, context_field in self.context_fields.items() } ) for td in task_data