diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 0b36588b..d4b5568e 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -12,8 +12,7 @@ def filter_user_message_formatter( if isinstance(multimodal_data, str): text = multimodal_data image_inputs: list[dict[str, str]] = [] - - if isinstance(multimodal_data, dict): + elif isinstance(multimodal_data, dict): image_data: dict[str, str] = multimodal_data.get("image", {}) _image_inputs: list[tuple[dict, dict]] = [ ( @@ -30,7 +29,17 @@ def filter_user_message_formatter( ] image_inputs = [m for image_input in _image_inputs for m in image_input] text = multimodal_data["text"] or "" + else: + raise ValueError("multimodal_data must be a dictionary or a string") + if not image_inputs or len(image_inputs) == 0: + return { + "role": "user", + "content": { + "type": "text", + "text": f"Claim: {user_instruction}\n\nContext:\n{text}", + }, + } return { "role": "user", "content": [