From c03dd932f68c0dd595e3e69dcfc3774c4ebebb1f Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 16 Dec 2024 21:15:52 +0200 Subject: [PATCH 1/2] Try new dumping/loading of jsons that can allow artifacts in task data Signed-off-by: elronbandel --- src/unitxt/artifact.py | 31 +++++++++++++++-- src/unitxt/inference.py | 7 ++-- src/unitxt/llm_as_judge.py | 9 +++-- src/unitxt/schema.py | 28 +++++++-------- tests/library/test_artifact_recovery.py | 46 ++++++++++++++++++++++++- utils/.secrets.baseline | 6 ++-- 6 files changed, 100 insertions(+), 27 deletions(-) diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index b684dbe621..39b6013e76 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -281,6 +281,8 @@ def from_dict(cls, d, overwrite_args=None): @classmethod def load(cls, path, artifact_identifier=None, overwrite_args=None): + with open(path) as f: + d = json_loads_with_artifacts(f.read()) d = artifacts_json_cache(path) if "artifact_linked_to" in d and d["artifact_linked_to"] is not None: # d stands for an ArtifactLink @@ -379,7 +381,9 @@ def save(self, path): raise UnitxtError( f"Cannot save catalog artifacts that have changed since initialization. 
Detected differences in the following fields:\n{diffs}" ) - save_to_file(path, self.to_json()) + save_to_file( + path, json_dumps_with_artifacts(source=self, dump_source_as_dict=True) + ) def verify_instance( self, instance: Dict[str, Any], name: Optional[str] = None @@ -460,6 +464,29 @@ def verify_instance( return instance +def json_dumps_with_artifacts(source, dump_source_as_dict=False): + def maybe_artifact_object_to_dict(obj): + if isinstance(obj, Artifact): + if ( + dump_source_as_dict and obj.__id__ == source.__id__ + ) or obj.__id__ is None: + return obj.to_dict() + return obj.__id__ + return obj + + return json.dumps(source, default=maybe_artifact_object_to_dict) + + +def maybe_artifact_dict_to_object(d): + if Artifact.is_artifact_dict(d): + return Artifact.from_dict(d) + return d + + +def json_loads_with_artifacts(s): + return json.loads(s, object_hook=maybe_artifact_dict_to_object) + + class ArtifactLink(Artifact): # the artifact linked to, expressed by its catalog id artifact_linked_to: str = Field(default=None, required=True) @@ -671,7 +698,7 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]: ) try: - data_classification = json.loads(data_classification) + data_classification = json_loads_with_artifacts(data_classification) except json.decoder.JSONDecodeError as e: raise RuntimeError(error_msg) from e diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 444596573b..41eb128991 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1,7 +1,6 @@ import abc import asyncio import dataclasses -import json import logging import os import re @@ -27,7 +26,7 @@ from tqdm import tqdm, trange from tqdm.asyncio import tqdm_asyncio -from .artifact import Artifact +from .artifact import Artifact, json_loads_with_artifacts from .dataclass import InternalField, NonPositionalField from .deprecation_utils import deprecation from .error_utils import UnitxtError @@ -2216,7 +2215,7 @@ class 
WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin): def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]: task_data = instance["task_data"] if isinstance(task_data, str): - task_data = json.loads(task_data) + task_data = json_loads_with_artifacts(task_data) question = task_data.get("question") images = [None] @@ -2545,7 +2544,7 @@ def _infer( task_data = instance["task_data"] if isinstance(task_data, str): - task_data = json.loads(task_data) + task_data = json_loads_with_artifacts(task_data) for option in task_data["options"]: requests.append( diff --git a/src/unitxt/llm_as_judge.py b/src/unitxt/llm_as_judge.py index c76713d4f0..d7797f5e1a 100644 --- a/src/unitxt/llm_as_judge.py +++ b/src/unitxt/llm_as_judge.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Literal, Optional from .api import infer +from .artifact import json_loads_with_artifacts from .dataclass import Field from .formats import ChatAPIFormat, Format, SystemFormat from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine @@ -17,11 +18,13 @@ def get_task_data_dict(task_data): - import json - # seems like the task data sometimes comes as a string, not a dict # this fixes it - return json.loads(task_data) if isinstance(task_data, str) else task_data + return ( + json_loads_with_artifacts(task_data) + if isinstance(task_data, str) + else task_data + ) class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin): diff --git a/src/unitxt/schema.py b/src/unitxt/schema.py index cde289124d..81138f7b71 100644 --- a/src/unitxt/schema.py +++ b/src/unitxt/schema.py @@ -1,10 +1,9 @@ -import json from typing import Any, Dict, List, Optional from datasets import Audio, Features, Sequence, Value from datasets import Image as DatasetImage -from .artifact import Artifact +from .artifact import json_dumps_with_artifacts, json_loads_with_artifacts from .dict_utils import dict_get from .operator import InstanceOperatorValidator from 
.settings_utils import get_constants, get_settings @@ -27,7 +26,7 @@ "audios": Sequence(Audio()), }, "postprocessors": Sequence(Value("string")), - "task_data": Value(dtype="string"), + "task_data": Value("string"), "data_classification_policy": Sequence(Value("string")), } ) @@ -39,7 +38,7 @@ "groups": Sequence(Value("string")), "subset": Sequence(Value("string")), "postprocessors": Sequence(Value("string")), - "task_data": Value(dtype="string"), + "task_data": Value("string"), "data_classification_policy": Sequence(Value("string")), "media": { "images": Sequence(Image()), @@ -64,13 +63,13 @@ def loads_instance(batch): or batch["source"][0].startswith('[{"content":') ) ): - batch["source"] = [json.loads(d) for d in batch["source"]] + batch["source"] = [json_loads_with_artifacts(d) for d in batch["source"]] if ( not settings.task_data_as_text and "task_data" in batch and isinstance(batch["task_data"][0], str) ): - batch["task_data"] = [json.loads(d) for d in batch["task_data"]] + batch["task_data"] = [json_loads_with_artifacts(d) for d in batch["task_data"]] return batch @@ -115,10 +114,10 @@ def _get_instance_task_data( def serialize_instance_fields(self, instance, task_data): if settings.task_data_as_text: - instance["task_data"] = json.dumps(task_data) + instance["task_data"] = json_dumps_with_artifacts(task_data) if not isinstance(instance["source"], str): - instance["source"] = json.dumps(instance["source"]) + instance["source"] = json_dumps_with_artifacts(instance["source"]) return instance def process( @@ -130,9 +129,8 @@ def process( ) task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"] - task_data["metadata"]["template"] = self.artifact_to_jsonable( - instance["recipe_metadata"]["template"] - ) + task_data["metadata"]["template"] = instance["recipe_metadata"]["template"] + if "demos" in instance: task_data["demos"] = [ self._get_instance_task_data(instance) @@ -159,7 +157,7 @@ def process( group_attributes = [group_attributes] for 
attribute in group_attributes: group[attribute] = dict_get(data, attribute) - groups.append(json.dumps(group)) + groups.append(json_dumps_with_artifacts(group)) instance["groups"] = groups instance["subset"] = [] @@ -167,11 +165,13 @@ def process( instance = self._prepare_media(instance) instance["metrics"] = [ - metric.to_json() if isinstance(metric, Artifact) else metric + json_dumps_with_artifacts(metric) if not isinstance(metric, str) else metric for metric in instance["metrics"] ] instance["postprocessors"] = [ - processor.to_json() if isinstance(processor, Artifact) else processor + json_dumps_with_artifacts(processor) + if not isinstance(processor, str) + else processor for processor in instance["postprocessors"] ] diff --git a/tests/library/test_artifact_recovery.py b/tests/library/test_artifact_recovery.py index 376baeb9f6..cc9dc8c6b3 100644 --- a/tests/library/test_artifact_recovery.py +++ b/tests/library/test_artifact_recovery.py @@ -1,9 +1,15 @@ +import json + from unitxt.artifact import ( Artifact, MissingArtifactTypeError, UnrecognizedArtifactTypeError, + json_dumps_with_artifacts, + json_loads_with_artifacts, ) +from unitxt.card import TaskCard from unitxt.logging_utils import get_logger +from unitxt.templates import InputOutputTemplate from tests.utils import UnitxtTestCase @@ -15,12 +21,50 @@ def test_correct_artifact_recovery(self): args = { "__type__": "standard_recipe", "card": "cards.sst2", - "template_card_index": 0, + "template": { + "__type__": "input_output_template", + "input_format": "Given the following {type_of_input}, generate the corresponding {type_of_output}. 
{type_of_input}: {input}", + "output_format": "{output}", + "postprocessors": [ + "processors.take_first_non_empty_line", + "processors.lower_case_till_punc", + ], + }, "demos_pool_size": 100, "num_demos": 0, } a = Artifact.from_dict(args) self.assertEqual(a.num_demos, 0) + self.assertIsInstance(a.template, InputOutputTemplate) + + def test_correct_artifact_loading_with_json_loads(self): + args = { + "__type__": "standard_recipe", + "card": "cards.sst2", + "template": { + "__type__": "input_output_template", + "input_format": "Given the following {type_of_input}, generate the corresponding {type_of_output}. {type_of_input}: {input}", + "output_format": "{output}", + "postprocessors": [ + "processors.take_first_non_empty_line", + "processors.lower_case_till_punc", + ], + }, + "demos_pool_size": 100, + "num_demos": 0, + } + + a = json_loads_with_artifacts(json.dumps(args)) + self.assertEqual(a.num_demos, 0) + + a = json_loads_with_artifacts(json.dumps({"x": args})) + self.assertEqual(a["x"].num_demos, 0) + + self.assertIsInstance(a["x"].card, TaskCard) + self.assertIsInstance(a["x"].template, InputOutputTemplate) + + d = json.loads(json_dumps_with_artifacts(a)) + self.assertDictEqual(d, {"x": args}) def test_correct_artifact_recovery_with_overwrite(self): args = { diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index ec9d704a7a..b41bb7be37 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "src/unitxt/inference.py", "hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729", "is_verified": false, - "line_number": 1235, + "line_number": 1234, "is_secret": false }, { @@ -141,7 +141,7 @@ "filename": "src/unitxt/inference.py", "hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79", "is_verified": false, - "line_number": 1635, + "line_number": 1634, "is_secret": false } ], @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2024-12-09T15:45:50Z" + "generated_at": "2024-12-16T19:15:29Z" } From 
7334f82f4723b2b545572f6e61a2244529061468 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 16 Dec 2024 21:34:46 +0200 Subject: [PATCH 2/2] Fix Signed-off-by: elronbandel --- src/unitxt/artifact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index 39b6013e76..70486d6100 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -474,7 +474,7 @@ def maybe_artifact_object_to_dict(obj): return obj.__id__ return obj - return json.dumps(source, default=maybe_artifact_object_to_dict) + return json.dumps(source, default=maybe_artifact_object_to_dict, indent=4) def maybe_artifact_dict_to_object(d):