Adding supervised data config #746

Merged · 9 commits · Oct 9, 2024
Changes from 5 commits
40 changes: 40 additions & 0 deletions config/gpt2_small_fast_supervised.yaml
@@ -0,0 +1,40 @@
data:
  configs:
    owt:
      train_urls:
        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
      validation_urls:
        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
    wikitext:
      id: dlwh/wikitext_103_detokenized
  train_weights:
    owt: 0.6
    wikitext: 0.4
  tokenizer: gpt2
  cache_dir: "gs://levanter-data/tokenized/data_mix"
supervised_data:
  validation_urls:
    - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
  cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
model:
  type: gpt2
  hidden_dim: 768
  num_heads: 12
  num_layers: 12
  seq_len: 1024
  gradient_checkpointing: true
  scale_attn_by_inverse_layer_idx: true
trainer:
  tracker:
    project: "levanter"
    tags: [ "openwebtext+wiki", "gpt2", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1

  train_batch_size: 256
  num_train_steps: 20000
optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
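
The new supervised_data block maps onto the LMSupervisedDatasetConfig dataclass added to src/levanter/data/text.py later in this PR. As a rough sketch (the construction path is illustrative; only the field names come from the dataclass in this diff), the same settings could be built directly in code:

from levanter.data.text import LMSupervisedDatasetConfig

supervised_data = LMSupervisedDatasetConfig(
    validation_urls=["gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"],
    cache_dir="gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/",
)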
2 changes: 1 addition & 1 deletion scripts/launch_gpt2_small_fast_tpu.sh
@@ -2,5 +2,5 @@

python infra/launch.py --foreground --tpu_name levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \
python -m levanter.main.train_lm \
--config_path config/gpt2_small_fast.yaml \
--config_path config/gpt2_small_fast_supervised.yaml \
--trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $*
4 changes: 3 additions & 1 deletion src/levanter/data/_preprocessor.py
@@ -79,13 +79,15 @@ class _BatchMapTransform(_DatasetTransform):
num_cpus: int
num_gpus: int
resources: dict
output_exemplar: Any

def __init__(self, fn, batch_size, num_cpus, num_gpus, resources):
def __init__(self, fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=None):
self.fn = fn
self.batch_size = batch_size
self.num_cpus = num_cpus
self.num_gpus = num_gpus
self.resources = resources
self.output_exemplar = output_exemplar


def as_record_batch(doc: BatchResult) -> pa.RecordBatch:
18 changes: 15 additions & 3 deletions src/levanter/data/sharded_datasource.py
@@ -113,7 +113,14 @@ def map(self, fn: Callable[[T_co], U]) -> "ShardedDataSource[U]":
return _MappedShardedDataSource(self, fn)

def map_batches(
self, fn: Callable[[list[T_co]], BatchResult], batch_size, *, num_cpus=1, num_gpus=0, **resources
self,
fn: Callable[[list[T_co]], BatchResult],
batch_size,
*,
num_cpus=1,
num_gpus=0,
output_exemplar=None,
**resources,
) -> "ShardedDataSource[dict]":
"""
**Lazily** map a function over batches of data. This is useful for doing things like batching data for a model,
@@ -131,7 +138,9 @@ def map_batches(
Returns:
A new ShardedDataset.
"""
return _BatchMappedShardedDataSource(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, **resources)
return _BatchMappedShardedDataSource(
self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, output_exemplar=output_exemplar, **resources
)


def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]:
@@ -478,10 +487,13 @@ def __init__(
batch_size,
num_cpus=1,
num_gpus=0,
output_exemplar=None,
**resources,
):
self.source = source
self._transform = _BatchMapTransform(fn, batch_size, num_cpus, num_gpus, resources)
self._transform = _BatchMapTransform(
fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=output_exemplar
)

@property
def shard_names(self) -> Sequence[str]:
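
For orientation, a minimal sketch (not part of the diff) of how the new output_exemplar parameter is meant to be used; it mirrors the call mk_supervised_dataset makes in the next file, and presumably lets the caching layer know the schema of the mapped output up front. Here, source and tokenize_batch are hypothetical placeholders:

import numpy as np

# Exemplar describing one output row: a variable-length int32 token array and a scalar length.
output_exemplar = {
    "input_ids": np.zeros((0,), dtype=np.int32),
    "sources_len": np.zeros((), dtype=np.int32),
}

# `source` would be any ShardedDataSource (e.g. from levanter.data.datasource_from_jsonl);
# `tokenize_batch` stands in for a batch function that returns the two columns above.
mapped = source.map_batches(tokenize_batch, batch_size=128, output_exemplar=output_exemplar)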
127 changes: 114 additions & 13 deletions src/levanter/data/text.py
@@ -390,6 +390,22 @@ def batch_size(self) -> int:
return self._batch_size


def fsspec_expand_glob(url):
expanded_urls = braceexpand.braceexpand(url)
for expanded_url in expanded_urls:
if "*" in expanded_url:
fs = fsspec.core.url_to_fs(expanded_url)[0]
globbed = fs.glob(expanded_url)
# have to append the fs prefix back on
protocol, _ = fsspec.core.split_protocol(expanded_url)
if protocol is None:
yield from globbed
else:
yield from [f"{protocol}://{path}" for path in globbed]
else:
yield expanded_url


def concatenate_and_group_texts(
encoding: BatchEncoding,
seq_len: int,
@@ -520,19 +536,7 @@ def urls_for_split(self, split):
else:
raise ValueError(f"Unknown split {split}")

def fsspec_expand_glob(url):
if "*" in url:
fs = fsspec.core.url_to_fs(url)[0]
globbed = fs.glob(url)
# have to append the fs prefix back on
protocol, _ = fsspec.core.split_protocol(url)
if protocol is None:
return globbed
return [f"{protocol}://{path}" for path in globbed]
else:
return [url]

urls = [globbed for pat in urls for url in braceexpand.braceexpand(pat) for globbed in fsspec_expand_glob(url)]
urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)]
return urls


@@ -584,6 +588,103 @@ def tagged_eval_sets(
return [(eval_sets[name], tags[name]) for name in eval_sets]


@dataclass
class LMSupervisedDatasetConfig(LMDatasetSourceConfig):
"""This class represents a dataset source with URLs or hf name/id."""

cache_dir: str = "cache/"

tags: Optional[List[str]] = None
"""tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well"""
name: Optional[str] = None # name for hf dataset

validation_urls: List[str] = () # type:ignore

# def token_seq_dataset(
# self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
# ) -> Optional[TokenSeqDataset]:
# cache = self.build_or_load_cache(split, monitors=monitors)
# if cache is None:
# return None
# return TokenSeqDataset(cache, seq_len)

# def validation_set(
# self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
# ) -> Optional[TokenSeqDataset]:
# return self.token_seq_dataset("validation", seq_len, monitors)

# def validation_sets(
# self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
# ) -> Mapping[str, AsyncDataset[np.ndarray]]:
# validation_set = self.validation_set(seq_len, monitors)
# if validation_set is not None:
# return {"": validation_set}
# else:
# return {}

# Add tagged eval set with split for auxiliary and validation dataset


def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase):
sources = [example["input"] for example in batch]

targets = [f"{example['output']}" for example in batch]
# TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it.
examples = [s + t for s, t in zip(sources, targets)]
sources_tokenized = tokenizer(sources, padding=False, truncation=True)
examples_tokenized = tokenizer(examples, padding=False, truncation=True)

source_lens = [len(s) for s in sources_tokenized["input_ids"]]

return {
"input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]],
"sources_len": np.array(source_lens, dtype=np.int32),
}


def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample:
"""
Prepare an example for training. This function converts the (cached) batch encoding into an LmExample.

It goes through the following steps:

1. Pad the batch to the maximum length.
2. Mask out the input and prompt if requested.
3. Create an LmExample with the input_ids as the input and the next token as the target.
"""
with local_cpu_mesh():
# annoyingly, pad expects things to be batched so we have to prepend a batch axis
ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length")
ex = {k: v[0] for k, v in ex.items()}
input_ids = hax.named(ex["input_ids"], "position")
# mask out padding and anything before the start of the target
Pos = input_ids.resolve_axis("position")
loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

# don't predict the padding
targets = hax.roll(input_ids, -1, Pos)
loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
return lm_ex


def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
import levanter.data

validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)]
dataset = levanter.data.datasource_from_jsonl(validation_urls)

output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)}

dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore
dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))


@dataclass
class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
"""This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
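
The refactored fsspec_expand_glob is now module-level and brace-aware: it first applies braceexpand to each URL and then resolves any * through fsspec, re-attaching the protocol to the globbed paths. A small sketch of the expected behavior, using illustrative bucket paths:

from levanter.data.text import fsspec_expand_glob

# Brace ranges expand without touching the filesystem:
list(fsspec_expand_glob("gs://bucket/openwebtext_val.{1..3}-of-8.jsonl.gz"))
# -> ["gs://bucket/openwebtext_val.1-of-8.jsonl.gz",
#     "gs://bucket/openwebtext_val.2-of-8.jsonl.gz",
#     "gs://bucket/openwebtext_val.3-of-8.jsonl.gz"]

# A "*" is resolved via fs.glob and the protocol is re-attached, which is what lets the
# mmlu-*-dev-evaluation.jsonl.gz pattern in the new config match every shard (requires GCS access):
# list(fsspec_expand_glob("gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"))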
20 changes: 18 additions & 2 deletions src/levanter/main/train_lm.py
@@ -16,7 +16,7 @@
from levanter import callbacks
from levanter.checkpoint import load_checkpoint
from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig
from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig
from levanter.models.gpt2 import Gpt2Config
from levanter.models.lm_model import LmConfig, compute_next_token_loss
from levanter.optim import AdamConfig, OptimizerConfig
@@ -30,6 +30,7 @@
@dataclass
class TrainLmConfig:
data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig)
supervised_data: Optional[LMSupervisedDatasetConfig] = None
trainer: TrainerConfig = field(default_factory=TrainerConfig)
model: LmConfig = field(default_factory=Gpt2Config)
optimizer: OptimizerConfig = field(default_factory=AdamConfig)
@@ -170,7 +171,6 @@ def main(config: TrainLmConfig):
(CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
for ds, tags in tagged_eval_datasets
]

cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch,
causal_datasets,
@@ -182,6 +182,22 @@
)
trainer.add_hook(cb, every=config.trainer.steps_per_eval)

if config.supervised_data is not None:
logger.info("Using supervised data")
supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")]
# TODO Add tags
cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch,
supervised_eval,
tokenizer,
trainer.device_mesh,
compute_axis_mapping,
max_eval_examples_per_ds,
prefix="internal_eval",
mp=config.trainer.mp,
)
trainer.add_hook(cb, every=config.trainer.steps_per_eval)

flops_per_token = config.model.flops_per_token(vocab_size)
flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
trainer.add_hook(
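
For reference, the supervised eval path expects JSONL records with "input" and "output" string fields (the field names read by preprocess_supervised_example); the MMLU-style content below is illustrative:

# One JSONL record as consumed by preprocess_supervised_example (content illustrative):
record = {"input": "Find all c in Z_3 ...\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer:", "output": "B"}

# preprocess_supervised_example tokenizes the prompt and prompt+answer, storing the prompt
# length as sources_len; _prepare_supervised_example then masks the loss so only the answer
# tokens are scored, which is exactly what the test below asserts.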
82 changes: 82 additions & 0 deletions tests/test_supervised.py
@@ -0,0 +1,82 @@
import numpy as np
from transformers import AutoTokenizer

import haliax

from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example


def test_supervised_eval():
examples = [
{
"input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer:",
"output": "B",
}
]
tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

output = preprocess_supervised_example(examples, tokenizer)
assert len(output["input_ids"][0]) == output["sources_len"][0] + 1

ex = {
"input_ids": np.array(
[
16742,
477,
269,
287,
1168,
62,
18,
884,
326,
1168,
62,
18,
58,
87,
60,
29006,
87,
61,
17,
1343,
269,
8,
318,
257,
2214,
13,
198,
32,
13,
657,
198,
33,
13,
352,
198,
34,
13,
362,
198,
35,
13,
513,
198,
33706,
25,
33,
],
dtype=np.int32,
),
"sources_len": np.array(45, dtype=np.int32),
}

lm_ex = _prepare_supervised_example(ex, tokenizer)

assert lm_ex.loss_mask["position", 44]
assert haliax.sum(lm_ex.loss_mask) == 1