make demos_pool a variable, rather than a separate stream. Needs to restart train repeatedly; fix needed in Loaders. Allow a given demos_pool, or input stream instances already loaded with demos.

Signed-off-by: dafnapension <[email protected]>
dafnapension committed Dec 20, 2024
1 parent b0286d7 commit ccebc83
Showing 10 changed files with 454 additions and 186 deletions.
148 changes: 72 additions & 76 deletions src/unitxt/loaders.py
@@ -41,6 +41,7 @@
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union

import pandas as pd
from datasets import IterableDatasetDict
from datasets import load_dataset as hf_load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm
@@ -51,7 +52,7 @@
from .operator import SourceOperator
from .operators import Set
from .settings_utils import get_settings
from .stream import DynamicStream, MultiStream
from .stream import MultiStream
from .type_utils import isoftype
from .utils import LRUCache

@@ -122,7 +123,7 @@ def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
)
return operator(multi_stream)

def sef_default_data_classification(
def set_default_data_classification(
self, default_data_classification_policy, additional_info
):
if self.data_classification_policy is None:
@@ -161,23 +162,24 @@ class LoadHF(Loader):
and it can filter datasets upon loading.
Args:
path: The path or identifier of the dataset on the HuggingFace Hub.
name: An optional dataset name.
data_dir: Optional directory to store downloaded data.
split: Optional specification of which split to load.
data_files: Optional specification of particular data files to load.
revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
streaming (bool): indicating if streaming should be used.
filtering_lambda: A lambda function for filtering the data after loading.
num_proc (int): Optional integer to specify the number of processes to use for parallel dataset loading.
path:
The path or identifier of the dataset on the HuggingFace Hub.
name:
An optional dataset name.
data_dir:
Optional directory to store downloaded data.
split:
Optional specification of which split to load.
data_files:
Optional specification of particular data files to load.
revision:
Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
streaming (bool):
indicating if streaming should be used.
filtering_lambda (str, optional):
A lambda function for filtering the data after loading.
num_proc (int, optional):
Specifies the number of processes to use for parallel dataset loading.
Example:
Loading glue's mrpc dataset
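
The docstring's own mrpc example is collapsed by the diff view. A minimal sketch of such a usage, assuming only the constructor keywords listed in the arguments above (the filter shown is hypothetical):

    from unitxt.loaders import LoadHF

    # Load the MRPC subset of GLUE.
    load_hf = LoadHF(path="glue", name="mrpc")

    # Stream the data and keep only part of it after loading; per the args above,
    # filtering_lambda is passed as a string (the predicate here is made up).
    load_hf_streamed = LoadHF(
        path="glue",
        name="mrpc",
        streaming=True,
        filtering_lambda='lambda instance: instance["label"] == 1',
    )
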
@@ -277,48 +279,37 @@ def load_dataset(self):
for split in dataset.keys():
dataset[split] = dataset[split].to_iterable_dataset()
else:
dataset = {self.split: dataset}

if self.filtering_lambda is not None:
dataset = self.filter_load(dataset)
dataset = {self.split: dataset.to_iterable_dataset()}

return dataset

def split_limited_load(self, dataset, split_name):
yield from itertools.islice(dataset[split_name], self.get_limit())

def limited_load(self, dataset):
self.log_limited_loading()
return MultiStream(
{
name: DynamicStream(
generator=self.split_limited_load,
gen_kwargs={"dataset": dataset, "split_name": name},
)
for name in dataset.keys()
}
)

def _maybe_set_classification_policy(self):
if os.path.exists(self.path):
self.sef_default_data_classification(
self.set_default_data_classification(
["proprietary"], "when loading from local files"
)
else:
self.sef_default_data_classification(
self.set_default_data_classification(
["public"], "when loading from Huggingface hub"
)

def load_iterables(self):
def load_iterables(self) -> IterableDatasetDict:
try:
dataset = self.stream_dataset()
except (
NotImplementedError
): # streaming is not supported for zipped files so we load without streaming
dataset = self.load_dataset()

if self.filtering_lambda is not None:
dataset = self.filter_load(dataset)

if self.get_limit() is not None:
return self.limited_load(dataset=dataset)
self.log_limited_loading()
return {
split_name: dataset[split_name].take(self.get_limit())
for split_name in dataset
}

return dataset
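
The rewritten load_iterables replaces the removed DynamicStream-based limited_load with IterableDataset.take(). A small, self-contained sketch of that pattern, using glue/mrpc from the docstring example and a placeholder limit:

    from datasets import load_dataset

    # Stream the dataset so each split is an IterableDataset.
    dataset = load_dataset("glue", "mrpc", streaming=True)

    limit = 100  # placeholder for self.get_limit()
    limited = {
        split_name: dataset[split_name].take(limit)
        for split_name in dataset
    }

    # Each split stays lazy; iteration simply stops after `limit` instances.
    first_train_rows = list(limited["train"])
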

@@ -350,7 +341,7 @@ class LoadCSV(Loader):
sep: str = ","

def _maybe_set_classification_policy(self):
self.sef_default_data_classification(
self.set_default_data_classification(
["proprietary"], "when loading from local files"
)

@@ -363,9 +354,7 @@ def load_iterables(self):
file_path, nrows=self.get_limit(), sep=self.sep
).to_dict("records")
else:
iterables[split_name] = pd.read_csv(file_path, sep=self.sep).to_dict(
"records"
)
iterables[split_name] = pd.read_csv(file_path).to_dict("records")
return iterables
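
The limit handling above relies on pandas parsing only the first rows of each file. A one-line sketch of that pattern with a placeholder path, separator, and limit:

    import pandas as pd

    # Parse at most the first 100 rows of a split file and convert to a list of dicts.
    records = pd.read_csv("data/train.csv", sep=",", nrows=100).to_dict("records")
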


@@ -473,14 +462,22 @@ class LoadFromIBMCloud(Loader):
3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
Args:
endpoint_url_env: Environment variable name for the IBM Cloud endpoint URL.
aws_access_key_id_env: Environment variable name for the AWS access key ID.
aws_secret_access_key_env: Environment variable name for the AWS secret access key.
bucket_name: Name of the S3 bucket from which to load data.
data_dir: Optional directory path within the bucket.
data_files: Union type allowing either a list of file names or a mapping of splits to file names.
data_field: The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
caching: Bool indicating if caching is enabled to avoid re-downloading data.
endpoint_url_env:
Environment variable name for the IBM Cloud endpoint URL.
aws_access_key_id_env:
Environment variable name for the AWS access key ID.
aws_secret_access_key_env:
Environment variable name for the AWS secret access key.
bucket_name:
Name of the S3 bucket from which to load data.
data_dir:
Optional directory path within the bucket.
data_files:
Union type allowing either a list of file names or a mapping of splits to file names.
data_field:
The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
caching (bool):
indicating if caching is enabled to avoid re-downloading data.
Example:
Loading from IBM Cloud
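
The IBM Cloud example itself is collapsed by the diff view. A hedged sketch of a configuration built only from the arguments listed above; the environment variable names, bucket, directory, and file names are placeholders:

    from unitxt.loaders import LoadFromIBMCloud

    loader = LoadFromIBMCloud(
        endpoint_url_env="IBM_CLOUD_ENDPOINT",
        aws_access_key_id_env="IBM_CLOUD_ACCESS_KEY_ID",
        aws_secret_access_key_env="IBM_CLOUD_SECRET_ACCESS_KEY",
        bucket_name="my-bucket",
        data_dir="my-dataset",
        # Format 3 from the list above: a mapping of split -> file names.
        data_files={"train": ["train.json"], "test": ["test1.json", "test2.json"]},
    )
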
@@ -576,7 +573,7 @@ def lazy_verify(self):
raise NotImplementedError("LoadFromKaggle cannot load with streaming.")

def _maybe_set_classification_policy(self):
self.sef_default_data_classification(
self.set_default_data_classification(
["proprietary"], "when loading from IBM COS"
)

@@ -727,7 +724,7 @@ def verify(self):
)

def _maybe_set_classification_policy(self):
self.sef_default_data_classification(
self.set_default_data_classification(
["proprietary"], "when loading from python dictionary"
)

@@ -742,25 +739,24 @@ class LoadFromHFSpace(LoadHF):
from the given space and then reads them as a HuggingFace Dataset.
Args:
space_name (str): Name of the HuggingFace Space to be accessed.
data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
paths to files within a given repository. If given as a mapping, paths should
be values, while keys should represent the type of respective files
(training, testing etc.).
path (str, optional): Absolute path to a directory where data should be downloaded.
revision (str, optional): ID of a Git branch or commit to be used. By default, it is
set to None, thus data is downloaded from the main branch of the accessed
repository.
use_token (bool, optional): Whether a token is used for authentication when accessing
the HuggingFace Space. If necessary, the token is read from the HuggingFace
config folder.
token_env (str, optional): Key of an env variable which value will be used for
authentication when accessing the HuggingFace Space - if necessary.
space_name (str):
Name of the HuggingFace Space to be accessed.
data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]):
Relative paths to files within a given repository. If given as a mapping,
paths should be values, while keys should represent the type of respective files
(training, testing etc.).
path (str, optional):
Absolute path to a directory where data should be downloaded.
revision (str, optional):
ID of a Git branch or commit to be used. By default, it is set to None,
thus data is downloaded from the main branch of the accessed repository.
use_token (bool, optional):
Whether a token is used for authentication when accessing
the HuggingFace Space. If necessary, the token is read from the HuggingFace
config folder.
token_env (str, optional):
Key of an env variable whose value will be used for
authentication when accessing the HuggingFace Space - if necessary.
Example:
Loading from a HuggingFace Space
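
As with the other loaders, the Space example is collapsed here. A hedged sketch built from the arguments above; the Space name and file paths are placeholders:

    from unitxt.loaders import LoadFromHFSpace

    loader = LoadFromHFSpace(
        space_name="some-org/some-space",
        # Relative paths inside the Space repository, keyed by split.
        data_files={"train": ["data/train.json"], "test": "data/test.json"},
        use_token=True,  # read the token from the local HuggingFace config if needed
    )
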
@@ -908,7 +904,7 @@ def _map_wildcard_path_to_full_paths(self):
)

def _maybe_set_classification_policy(self):
self.sef_default_data_classification(
self.set_default_data_classification(
["public"], "when loading from Huggingface spaces"
)

3 changes: 3 additions & 0 deletions src/unitxt/schema.py
@@ -130,6 +130,9 @@ def process(
)

task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
task_data["metadata"]["demos_pool_size"] = instance["recipe_metadata"][
"demos_pool_size"
]
task_data["metadata"]["template"] = self.artifact_to_jsonable(
instance["recipe_metadata"]["template"]
)
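
For orientation, a sketch of the metadata block an instance carries after this change; only the keys follow the diff, the values are invented placeholders:

    # Illustrative shape of instance["task_data"]["metadata"] after this commit.
    metadata = {
        "num_demos": 3,            # demos attached to the instance
        "demos_pool_size": 20,     # size of the pool they were drawn from (the new field)
        "template": "templates.qa.open",  # jsonable reference to the template artifact (placeholder)
    }
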
61 changes: 29 additions & 32 deletions src/unitxt/splitters.py
@@ -1,11 +1,11 @@
import itertools
from abc import abstractmethod
from difflib import get_close_matches
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from .artifact import Artifact
from .dict_utils import dict_get
from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
from .operator import InstanceOperator, MultiStreamOperator
from .random_utils import new_random_generator
from .split_utils import (
parse_random_mix_string,
@@ -14,7 +14,7 @@
rename_split,
slice_streams,
)
from .stream import EmptyStreamError, FaultyStreamError, MultiStream
from .stream import MultiStream
from .type_utils import isoftype
from .utils import recursive_copy

@@ -118,14 +118,14 @@ class Sampler(Artifact):
def sample(
self,
sample_size: int,
instances_pool: List[Dict[str, object]],
instance: Dict[str, object],
) -> List[Dict[str, object]]:
instances_pool: List[Dict[str, Any]],
instance: Dict[str, Any],
) -> List[Dict[str, Any]]:
pass

def filter_source_by_instance(
self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
) -> List[Dict[str, object]]:
self, instances_pool: List[Dict[str, Any]], instance: Dict[str, Any]
) -> List[Dict[str, Any]]:
if "input_fields" not in instance:
raise ValueError(f"'input_fields' field is missing from '{instance}'.")
try:
@@ -336,10 +336,11 @@ def sample(
return result


class Sample(InstanceOperatorWithMultiStreamAccess):
from_stream: str
class Sample(InstanceOperator):
from_field: str
to_field: str
sampler: Sampler
skip_demoed_instances: bool = False

def prepare(self):
self.local_cache = None
@@ -350,30 +351,26 @@ def get_sample_size(self, instance) -> int:
pass

def process(
self, instance: Dict[str, object], multi_stream: MultiStream
) -> Dict[str, object]:
sample_size = self.get_sample_size(instance)
try:
if self.local_cache is None:
self.local_cache = recursive_copy(list(multi_stream[self.from_stream]))
self, instance: Dict[str, Any], multi_stream: MultiStream
) -> Dict[str, Any]:
if self.skip_demoed_instances and self.to_field in instance:
if self.from_field in instance:
instance.pop(self.from_field)
return instance

source_stream = self.local_cache
source_stream = self.sampler.filter_source_by_instance(
source_stream, instance
)
if len(source_stream) < sample_size:
raise ValueError(
f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {self.sampler.sample_size}."
)
sampled_instances = self.sampler.sample(
sample_size=sample_size, instances_pool=source_stream, instance=instance
demos_pool = instance[self.from_field]
sample_size = self.get_sample_size(instance)
source_stream = self.sampler.filter_source_by_instance(demos_pool, instance)
if len(source_stream) < sample_size:
raise ValueError(
f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}. Please consider increasing increasing the demos pool, for which you may need to increase loader_limit or employ a less strict stream filtering."
)
instance[self.to_field] = sampled_instances
return instance
except FaultyStreamError as e:
raise EmptyStreamError(
f"Unable to fetch instances from '{self.from_stream}' to '{self.to_field}', due to {e.__class__.__name__}: {e}"
) from e
sampled_instances = self.sampler.sample(
sample_size=sample_size, instances_pool=source_stream, instance=instance
)
instance[self.to_field] = recursive_copy(sampled_instances)
instance.pop(self.from_field) # pop the field pointing to the demos_pool
return instance


class ConstantSizeSample(Sample):
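
Taken together, the rewritten Sample turns demo selection into a per-instance operation: the pool arrives on the instance itself and is dropped once demos are drawn. A minimal, self-contained sketch of that flow; the field names, pool contents, and random sampler are invented for illustration and are not the unitxt implementation:

    import random
    from typing import Any, Dict, List

    def sample_demos(
        instance: Dict[str, Any],
        from_field: str = "_demos_pool_",  # hypothetical field carrying the pool
        to_field: str = "demos",
        sample_size: int = 2,
    ) -> Dict[str, Any]:
        pool: List[Dict[str, Any]] = instance[from_field]
        if len(pool) < sample_size:
            raise ValueError(
                f"Pool of size {len(pool)} is smaller than sample_size {sample_size}."
            )
        # Draw this instance's demos and drop the pool field, mirroring the diff above.
        instance[to_field] = random.sample(pool, sample_size)
        instance.pop(from_field)
        return instance

    pool = [{"input_fields": {"text": f"demo {i}"}} for i in range(5)]
    instance = {"input_fields": {"text": "query"}, "_demos_pool_": pool}
    print(len(sample_demos(instance)["demos"]))  # -> 2
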