Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
elronbandel committed Jul 11, 2023
1 parent d4b7bcb commit aafb7c4
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 24 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@ pypi:

hf:
python $(DIR)/make/hf.py

build:
format
pypi

14 changes: 14 additions & 0 deletions examples/example9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from src.unitxt.text_utils import print_dict
from datasets import load_dataset

dataset = load_dataset('unitxt/data', 'card=wnli_card,template_item=0,num_demos=5,demos_pool_size=100')

print_dict(dataset['train'][0])

import evaluate

metric = evaluate.load('unitxt/metric')

results = metric.compute(predictions=['entailment' for t in dataset['test']], references=dataset['test'])

print_dict(results[0])
2 changes: 1 addition & 1 deletion prepare/cards/wnli_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
card = TaskCard(
loader=LoadHF(path='glue', name='wnli'),
preprocess_steps=[
SplitRandomMix({'train': 'train[:95%]', 'validation': 'train[95%:]', 'test': 'test'}),
SplitRandomMix({'train': 'train[95%]', 'validation': 'train[5%]', 'test': 'validation'}),
MapInstanceValues(mappers={'label': {"0": 'entailment', "1": 'not_entailment'}}),
AddFields(
fields={
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
description="Load any mixture of text to text data in one line of code",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.ibm.com/IBM-Research-AI/unitext",
url="https://github.com/ibm/unitxt",
packages=setuptools.find_packages('src'),
package_dir={'': 'src'},
package_data={'unitext': ['catalog/*.json']},
package_data={'unitxt': ['catalog/*.json']},
classifiers=[
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
Expand Down
9 changes: 2 additions & 7 deletions src/unitxt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,5 @@

register_blocks()

from . import dataset

dataset_url = dataset.__file__

from . import metric

metric_url = metric.__file__
dataset_url = 'unitxt/data'
metric_url = 'unitxt/metric'
6 changes: 5 additions & 1 deletion src/unitxt/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ class Catalog(Artifactory):
name: str = None
location: str = None

catalog_path = os.path.dirname(__file__) + "/catalog"
try:
import unitxt
catalog_path = os.path.dirname(unitxt.__file__) + "/catalog"
except ImportError:
catalog_path = os.path.dirname(__file__) + "/catalog"

class LocalCatalog(Catalog):
name: str = "local"
Expand Down
6 changes: 3 additions & 3 deletions src/unitxt/catalog/cards/wnli_card.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
{
"type": "split_random_mix",
"mix": {
"train": "train[:95%]",
"validation": "train[95%:]",
"test": "test"
"train": "train[95%]",
"validation": "train[5%]",
"test": "validation"
}
},
{
Expand Down
7 changes: 4 additions & 3 deletions src/unitxt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


class CommonRecipe(Recipe, SourceOperator):

card: TaskCard
demos_pool_name: str = "demos_pool"
demos_pool_size: int = None
Expand All @@ -37,8 +38,8 @@ def prepare(self):
steps.append(
SliceSplit(
slices={
self.demos_pool_name: f"train[:{self.demos_pool_size}]",
"train": f"train[{self.demos_pool_size}:]",
self.demos_pool_name: f"train[:{int(self.demos_pool_size)}]",
"train": f"train[{int(self.demos_pool_size)}:]",
"validation": "validation",
"test": "test",
}
Expand All @@ -47,7 +48,7 @@ def prepare(self):

if self.num_demos is not None:
if self.sampler_type == "random":
sampler = RandomSampler(sample_size=self.num_demos)
sampler = RandomSampler(sample_size=int(self.num_demos))

steps.append(
SpreadSplit(
Expand Down
20 changes: 14 additions & 6 deletions src/unitxt/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,16 @@
#############

from .register import register_blocks
from .artifact import Artifact
from .artifact import Artifact, fetch_artifact, UnitxtArtifactNotFoundError

import datasets

def fetch(artifact_name):
try:
artifact, _ = fetch_artifact(artifact_name)
return artifact
except UnitxtArtifactNotFoundError:
return None

def parse(query: str):
"""
Expand All @@ -87,7 +93,7 @@ def parse(query: str):
return result


class Unitext(datasets.GeneratorBasedBuilder):
class Dataset(datasets.GeneratorBasedBuilder):
"""TODO: Short description of my dataset."""

VERSION = datasets.Version("1.1.1")
Expand All @@ -97,10 +103,12 @@ class Unitext(datasets.GeneratorBasedBuilder):
def generators(self):
register_blocks()
if not hasattr(self, "_generators") or self._generators is None:
args = parse(self.config.name)
if "type" not in args:
args["type"] = "common_recipe"
recipe = Artifact.from_dict(args)
recipe = fetch(self.config.name)
if recipe is None:
args = parse(self.config.name)
if "type" not in args:
args["type"] = "common_recipe"
recipe = Artifact.from_dict(args)
self._generators = recipe()
return self._generators

Expand Down
2 changes: 1 addition & 1 deletion src/unitxt/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def prepare(self):


# @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class UnitextMetric(evaluate.Metric):
class Metric(evaluate.Metric):
def _info(self):
return evaluate.MetricInfo(
description="_DESCRIPTION",
Expand Down

0 comments on commit aafb7c4

Please sign in to comment.