Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try new dumping/loading of jsons that can allow artifacts in task data #1442

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ def from_dict(cls, d, overwrite_args=None):

@classmethod
def load(cls, path, artifact_identifier=None, overwrite_args=None):
with open(path) as f:
d = json_loads_with_artifacts(f.read())
d = artifacts_json_cache(path)
if "artifact_linked_to" in d and d["artifact_linked_to"] is not None:
# d stands for an ArtifactLink
Expand Down Expand Up @@ -379,7 +381,9 @@ def save(self, path):
raise UnitxtError(
f"Cannot save catalog artifacts that have changed since initialization. Detected differences in the following fields:\n{diffs}"
)
save_to_file(path, self.to_json())
save_to_file(
path, json_dumps_with_artifacts(source=self, dump_source_as_dict=True)
)

def verify_instance(
self, instance: Dict[str, Any], name: Optional[str] = None
Expand Down Expand Up @@ -460,6 +464,29 @@ def verify_instance(
return instance


def json_dumps_with_artifacts(source, dump_source_as_dict=False):
    """Serialize *source* to a JSON string, encoding embedded Artifacts.

    An Artifact is normally represented by its catalog id (``__id__``).
    It is expanded to a full dict instead when it has no id, or when it
    is the top-level *source* itself and ``dump_source_as_dict`` is set.
    Non-Artifact values serialize through ``json.dumps`` unchanged.
    """

    def _encode(obj):
        # json.dumps invokes this hook only for objects it cannot
        # serialize natively, so plain dicts/lists/strings never reach it.
        if not isinstance(obj, Artifact):
            return obj
        expand_fully = obj.__id__ is None or (
            dump_source_as_dict and obj.__id__ == source.__id__
        )
        return obj.to_dict() if expand_fully else obj.__id__

    return json.dumps(source, default=_encode, indent=4)


def maybe_artifact_dict_to_object(d):
    """Instantiate *d* as an Artifact when it is an artifact dict.

    Dicts that do not carry the artifact markers (per
    ``Artifact.is_artifact_dict``) pass through untouched.
    """
    if not Artifact.is_artifact_dict(d):
        return d
    return Artifact.from_dict(d)


def json_loads_with_artifacts(s):
    """Parse JSON string *s*, turning embedded artifact dicts into Artifacts.

    The ``object_hook`` runs on every decoded JSON object, so artifact
    dicts nested anywhere in the structure are instantiated in place.
    """
    return json.loads(s, object_hook=maybe_artifact_dict_to_object)


class ArtifactLink(Artifact):
# the artifact linked to, expressed by its catalog id
artifact_linked_to: str = Field(default=None, required=True)
Expand Down Expand Up @@ -671,7 +698,7 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
)

try:
data_classification = json.loads(data_classification)
data_classification = json_loads_with_artifacts(data_classification)
except json.decoder.JSONDecodeError as e:
raise RuntimeError(error_msg) from e

Expand Down
7 changes: 3 additions & 4 deletions src/unitxt/inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
import asyncio
import dataclasses
import json
import logging
import os
import re
Expand All @@ -27,7 +26,7 @@
from tqdm import tqdm, trange
from tqdm.asyncio import tqdm_asyncio

from .artifact import Artifact
from .artifact import Artifact, json_loads_with_artifacts
from .dataclass import InternalField, NonPositionalField
from .deprecation_utils import deprecation
from .error_utils import UnitxtError
Expand Down Expand Up @@ -2216,7 +2215,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]:
task_data = instance["task_data"]
if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_loads_with_artifacts(task_data)
question = task_data.get("question")

images = [None]
Expand Down Expand Up @@ -2545,7 +2544,7 @@ def _infer(
task_data = instance["task_data"]

if isinstance(task_data, str):
task_data = json.loads(task_data)
task_data = json_loads_with_artifacts(task_data)

for option in task_data["options"]:
requests.append(
Expand Down
9 changes: 6 additions & 3 deletions src/unitxt/llm_as_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List, Literal, Optional

from .api import infer
from .artifact import json_loads_with_artifacts
from .dataclass import Field
from .formats import ChatAPIFormat, Format, SystemFormat
from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
Expand All @@ -17,11 +18,13 @@


def get_task_data_dict(task_data):
import json

# seems like the task data sometimes comes as a string, not a dict
# this fixes it
return json.loads(task_data) if isinstance(task_data, str) else task_data
return (
json_loads_with_artifacts(task_data)
if isinstance(task_data, str)
else task_data
)


class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin):
Expand Down
28 changes: 14 additions & 14 deletions src/unitxt/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
from typing import Any, Dict, List, Optional

from datasets import Audio, Features, Sequence, Value
from datasets import Image as DatasetImage

from .artifact import Artifact
from .artifact import json_dumps_with_artifacts, json_loads_with_artifacts
from .dict_utils import dict_get
from .operator import InstanceOperatorValidator
from .settings_utils import get_constants, get_settings
Expand All @@ -27,7 +26,7 @@
"audios": Sequence(Audio()),
},
"postprocessors": Sequence(Value("string")),
"task_data": Value(dtype="string"),
"task_data": Value("string"),
"data_classification_policy": Sequence(Value("string")),
}
)
Expand All @@ -39,7 +38,7 @@
"groups": Sequence(Value("string")),
"subset": Sequence(Value("string")),
"postprocessors": Sequence(Value("string")),
"task_data": Value(dtype="string"),
"task_data": Value("string"),
"data_classification_policy": Sequence(Value("string")),
"media": {
"images": Sequence(Image()),
Expand All @@ -64,13 +63,13 @@ def loads_instance(batch):
or batch["source"][0].startswith('[{"content":')
)
):
batch["source"] = [json.loads(d) for d in batch["source"]]
batch["source"] = [json_loads_with_artifacts(d) for d in batch["source"]]
if (
not settings.task_data_as_text
and "task_data" in batch
and isinstance(batch["task_data"][0], str)
):
batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
batch["task_data"] = [json_loads_with_artifacts(d) for d in batch["task_data"]]
return batch


Expand Down Expand Up @@ -115,10 +114,10 @@ def _get_instance_task_data(

def serialize_instance_fields(self, instance, task_data):
if settings.task_data_as_text:
instance["task_data"] = json.dumps(task_data)
instance["task_data"] = json_dumps_with_artifacts(task_data)

if not isinstance(instance["source"], str):
instance["source"] = json.dumps(instance["source"])
instance["source"] = json_dumps_with_artifacts(instance["source"])
return instance

def process(
Expand All @@ -130,9 +129,8 @@ def process(
)

task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
task_data["metadata"]["template"] = self.artifact_to_jsonable(
instance["recipe_metadata"]["template"]
)
task_data["metadata"]["template"] = instance["recipe_metadata"]["template"]

if "demos" in instance:
task_data["demos"] = [
self._get_instance_task_data(instance)
Expand All @@ -159,19 +157,21 @@ def process(
group_attributes = [group_attributes]
for attribute in group_attributes:
group[attribute] = dict_get(data, attribute)
groups.append(json.dumps(group))
groups.append(json_dumps_with_artifacts(group))

instance["groups"] = groups
instance["subset"] = []

instance = self._prepare_media(instance)

instance["metrics"] = [
metric.to_json() if isinstance(metric, Artifact) else metric
json_dumps_with_artifacts(metric) if not isinstance(metric, str) else metric
for metric in instance["metrics"]
]
instance["postprocessors"] = [
processor.to_json() if isinstance(processor, Artifact) else processor
json_dumps_with_artifacts(processor)
if not isinstance(processor, str)
else processor
for processor in instance["postprocessors"]
]

Expand Down
46 changes: 45 additions & 1 deletion tests/library/test_artifact_recovery.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import json

from unitxt.artifact import (
Artifact,
MissingArtifactTypeError,
UnrecognizedArtifactTypeError,
json_dumps_with_artifacts,
json_loads_with_artifacts,
)
from unitxt.card import TaskCard
from unitxt.logging_utils import get_logger
from unitxt.templates import InputOutputTemplate

from tests.utils import UnitxtTestCase

Expand All @@ -15,12 +21,50 @@ def test_correct_artifact_recovery(self):
args = {
"__type__": "standard_recipe",
"card": "cards.sst2",
"template_card_index": 0,
"template": {
"__type__": "input_output_template",
"input_format": "Given the following {type_of_input}, generate the corresponding {type_of_output}. {type_of_input}: {input}",
"output_format": "{output}",
"postprocessors": [
"processors.take_first_non_empty_line",
"processors.lower_case_till_punc",
],
},
"demos_pool_size": 100,
"num_demos": 0,
}
a = Artifact.from_dict(args)
self.assertEqual(a.num_demos, 0)
self.assertIsInstance(a.template, InputOutputTemplate)

def test_correct_artifact_loading_with_json_loads(self):
    # Round-trip check for json_loads_with_artifacts /
    # json_dumps_with_artifacts: a plain-dict recipe description should
    # decode into live artifact objects and encode back to the same dict.
    args = {
        "__type__": "standard_recipe",
        "card": "cards.sst2",
        "template": {
            "__type__": "input_output_template",
            "input_format": "Given the following {type_of_input}, generate the corresponding {type_of_output}. {type_of_input}: {input}",
            "output_format": "{output}",
            "postprocessors": [
                "processors.take_first_non_empty_line",
                "processors.lower_case_till_punc",
            ],
        },
        "demos_pool_size": 100,
        "num_demos": 0,
    }

    # Decoding a bare artifact dict yields the artifact object itself.
    a = json_loads_with_artifacts(json.dumps(args))
    self.assertEqual(a.num_demos, 0)

    # Decoding is recursive: artifact dicts nested inside plain
    # containers are instantiated in place.
    a = json_loads_with_artifacts(json.dumps({"x": args}))
    self.assertEqual(a["x"].num_demos, 0)

    # Catalog-id strings ("cards.sst2") and nested artifact dicts both
    # end up as concrete artifact instances — presumably resolved inside
    # Artifact.from_dict; TODO confirm against artifact.py.
    self.assertIsInstance(a["x"].card, TaskCard)
    self.assertIsInstance(a["x"].template, InputOutputTemplate)

    # Dumping the decoded structure reproduces the original plain dict.
    d = json.loads(json_dumps_with_artifacts(a))
    self.assertDictEqual(d, {"x": args})

def test_correct_artifact_recovery_with_overwrite(self):
args = {
Expand Down
6 changes: 3 additions & 3 deletions utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@
"filename": "src/unitxt/inference.py",
"hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729",
"is_verified": false,
"line_number": 1235,
"line_number": 1234,
"is_secret": false
},
{
"type": "Secret Keyword",
"filename": "src/unitxt/inference.py",
"hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79",
"is_verified": false,
"line_number": 1635,
"line_number": 1634,
"is_secret": false
}
],
Expand Down Expand Up @@ -184,5 +184,5 @@
}
]
},
"generated_at": "2024-12-09T15:45:50Z"
"generated_at": "2024-12-16T19:15:29Z"
}
Loading