From 48341cf8ff40d012b3e597228c83c5b978ae1298 Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Mon, 21 Oct 2024 14:04:11 -0700 Subject: [PATCH] Adding set values for input / output --- config/gpt2_small_fast_supervised.yaml | 2 ++ src/levanter/data/text.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml index 56ce7ea36..d71e1267e 100644 --- a/config/gpt2_small_fast_supervised.yaml +++ b/config/gpt2_small_fast_supervised.yaml @@ -16,6 +16,8 @@ supervised_data: validation_urls: - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz" cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/" + input_field: "input" + output_field: "output" model: type: gpt2 hidden_dim: 768 diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index a1e20384f..c16676410 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -570,13 +570,16 @@ class LMSupervisedDatasetConfig: """tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well""" name: Optional[str] = None # name for hf dataset + input_field: str = "prompt" # name of the input field in the jsonl file + output_field: str = "response" # name of the output field in the jsonl file + validation_urls: List[str] = () # type:ignore -def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase): - sources = [example["input"] for example in batch] +def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str) -> dict: + sources = [example[input_field] for example in batch] - targets = [f"{example['output']}" for example in batch] + targets = [f"{example[output_field]}" for example in batch] # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how alpaca does it. examples = [s + t for s, t in zip(sources, targets)] sources_tokenized = tokenizer(sources, padding=False, truncation=True) @@ -623,9 +626,12 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] dataset = levanter.data.datasource_from_jsonl(validation_urls) + input_field = config.input_field + output_field = config.output_field + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)} - dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore + dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token