diff --git a/Makefile b/Makefile index 4aa9117cd..bf2e38750 100644 --- a/Makefile +++ b/Makefile @@ -27,3 +27,8 @@ pypi: hf: python $(DIR)/make/hf.py + +build: + $(MAKE) format + $(MAKE) pypi + diff --git a/examples/example9.py b/examples/example9.py new file mode 100644 index 000000000..9879ebc7f --- /dev/null +++ b/examples/example9.py @@ -0,0 +1,14 @@ +from src.unitxt.text_utils import print_dict +from datasets import load_dataset + +dataset = load_dataset('unitxt/data', 'card=wnli_card,template_item=0,num_demos=5,demos_pool_size=100') + +print_dict(dataset['train'][0]) + +import evaluate + +metric = evaluate.load('unitxt/metric') + +results = metric.compute(predictions=['entailment' for t in dataset['test']], references=dataset['test']) + +print_dict(results[0]) \ No newline at end of file diff --git a/prepare/cards/wnli_card.py b/prepare/cards/wnli_card.py index a42584fce..4dd6bdf02 100644 --- a/prepare/cards/wnli_card.py +++ b/prepare/cards/wnli_card.py @@ -15,7 +15,7 @@ card = TaskCard( loader=LoadHF(path='glue', name='wnli'), preprocess_steps=[ - SplitRandomMix({'train': 'train[:95%]', 'validation': 'train[95%:]', 'test': 'test'}), + SplitRandomMix({'train': 'train[95%]', 'validation': 'train[5%]', 'test': 'validation'}), MapInstanceValues(mappers={'label': {"0": 'entailment', "1": 'not_entailment'}}), AddFields( fields={ diff --git a/setup.py b/setup.py index c9fa8117f..3f483353e 100644 --- a/setup.py +++ b/setup.py @@ -14,10 +14,10 @@ description="Load any mixture of text to text data in one line of code", long_description=long_description, long_description_content_type="text/markdown", - url="https://github.ibm.com/IBM-Research-AI/unitext", + url="https://github.com/ibm/unitxt", packages=setuptools.find_packages('src'), package_dir={'': 'src'}, - package_data={'unitext': ['catalog/*.json']}, + package_data={'unitxt': ['catalog/*.json']}, classifiers=[ "Programming Language :: Python :: 3", "Operating System :: OS Independent", diff --git a/src/unitxt/__init__.py 
b/src/unitxt/__init__.py index 4ab43c762..d0c10845d 100644 --- a/src/unitxt/__init__.py +++ b/src/unitxt/__init__.py @@ -2,10 +2,5 @@ register_blocks() -from . import dataset - -dataset_url = dataset.__file__ - -from . import metric - -metric_url = metric.__file__ +dataset_url = 'unitxt/data' +metric_url = 'unitxt/metric' diff --git a/src/unitxt/catalog.py b/src/unitxt/catalog.py index 8bdd45fc6..273517594 100644 --- a/src/unitxt/catalog.py +++ b/src/unitxt/catalog.py @@ -7,7 +7,11 @@ class Catalog(Artifactory): name: str = None location: str = None -catalog_path = os.path.dirname(__file__) + "/catalog" +try: + import unitxt + catalog_path = os.path.dirname(unitxt.__file__) + "/catalog" +except ImportError: + catalog_path = os.path.dirname(__file__) + "/catalog" class LocalCatalog(Catalog): name: str = "local" diff --git a/src/unitxt/catalog/cards/wnli_card.json b/src/unitxt/catalog/cards/wnli_card.json index 296a0fb43..f738d3bb7 100644 --- a/src/unitxt/catalog/cards/wnli_card.json +++ b/src/unitxt/catalog/cards/wnli_card.json @@ -25,9 +25,9 @@ { "type": "split_random_mix", "mix": { - "train": "train[:95%]", - "validation": "train[95%:]", - "test": "test" + "train": "train[95%]", + "validation": "train[5%]", + "test": "validation" } }, { diff --git a/src/unitxt/common.py b/src/unitxt/common.py index a885e9531..5d1d8104d 100644 --- a/src/unitxt/common.py +++ b/src/unitxt/common.py @@ -11,6 +11,7 @@ class CommonRecipe(Recipe, SourceOperator): + card: TaskCard demos_pool_name: str = "demos_pool" demos_pool_size: int = None @@ -37,8 +38,8 @@ def prepare(self): steps.append( SliceSplit( slices={ - self.demos_pool_name: f"train[:{self.demos_pool_size}]", - "train": f"train[{self.demos_pool_size}:]", + self.demos_pool_name: f"train[:{int(self.demos_pool_size)}]", + "train": f"train[{int(self.demos_pool_size)}:]", "validation": "validation", "test": "test", } @@ -47,7 +48,7 @@ def prepare(self): if self.num_demos is not None: if self.sampler_type == "random": - sampler = 
RandomSampler(sample_size=self.num_demos) + sampler = RandomSampler(sample_size=int(self.num_demos)) steps.append( SpreadSplit( diff --git a/src/unitxt/dataset.py b/src/unitxt/dataset.py index 7b2b8c439..85c26f1b3 100644 --- a/src/unitxt/dataset.py +++ b/src/unitxt/dataset.py @@ -65,10 +65,16 @@ ############# from .register import register_blocks -from .artifact import Artifact +from .artifact import Artifact, fetch_artifact, UnitxtArtifactNotFoundError import datasets +def fetch(artifact_name): + try: + artifact, _ = fetch_artifact(artifact_name) + return artifact + except UnitxtArtifactNotFoundError: + return None def parse(query: str): """ @@ -87,7 +93,7 @@ def parse(query: str): return result -class Unitext(datasets.GeneratorBasedBuilder): +class Dataset(datasets.GeneratorBasedBuilder): """TODO: Short description of my dataset.""" VERSION = datasets.Version("1.1.1") @@ -97,10 +103,12 @@ class Unitext(datasets.GeneratorBasedBuilder): def generators(self): register_blocks() if not hasattr(self, "_generators") or self._generators is None: - args = parse(self.config.name) - if "type" not in args: - args["type"] = "common_recipe" - recipe = Artifact.from_dict(args) + recipe = fetch(self.config.name) + if recipe is None: + args = parse(self.config.name) + if "type" not in args: + args["type"] = "common_recipe" + recipe = Artifact.from_dict(args) self._generators = recipe() return self._generators diff --git a/src/unitxt/metric.py b/src/unitxt/metric.py index 67320baa6..7f1ae3041 100644 --- a/src/unitxt/metric.py +++ b/src/unitxt/metric.py @@ -119,7 +119,7 @@ def prepare(self): # @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class UnitextMetric(evaluate.Metric): +class Metric(evaluate.Metric): def _info(self): return evaluate.MetricInfo( description="_DESCRIPTION",