Support multiple supervised evals, some cleanup around that #803

Merged: 4 commits, Nov 13, 2024

Changes from all commits
@@ -13,12 +13,17 @@ data:
tokenizer: gpt2
cache_dir: "gs://levanter-data/tokenized/data_mix"
supervised_data:
validation_urls:
- "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
- "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz"
cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
input_field: "input"
output_field: "output"
mmlu:
validation_urls:
- "gs://marin-us-central2/evaluation/mmlu-eval-subject-2eb39e/cais/*-validation-evaluation.jsonl.gz"
cache_dir: "gs://levanter-data/tokenized-gpt2/mmlu/"
tags: [ "e"]
arc_easy:
validation_urls:
- "gs://marin-us-central2/evaluation/arc-easy-b39e70/allenai/ai2_arc-ARC-Easy-validation-evaluation.jsonl.gz"
cache_dir: "gs://levanter-data/tokenized-gpt2/arc_easy/"
tags: [ "arc", "e"]

model:
type: gpt2
hidden_dim: 768
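
This config hunk is the core of the PR: supervised_data changes from a single block of URLs and fields to a mapping from eval name (mmlu, arc_easy) to its own source config, each carrying an optional tags list, presumably used to group the logged eval metrics. Below is a rough sketch of the shape this implies; the class and field names are illustrative, not necessarily the exact ones defined in levanter.

from dataclasses import dataclass, field
from typing import Mapping, Sequence

# Hypothetical sketch of the per-eval config implied by the YAML above;
# the real levanter config class may differ in name and defaults.
@dataclass
class SupervisedEvalConfig:
    validation_urls: Sequence[str] = ()                 # glob patterns allowed, e.g. "*-validation-evaluation.jsonl.gz"
    cache_dir: str = ""                                 # tokenized cache location for this eval set
    tags: Sequence[str] = field(default_factory=list)   # tags attached to this eval's metrics

# supervised_data is now a name -> config mapping rather than a single config:
supervised_data: Mapping[str, SupervisedEvalConfig] = {
    "mmlu": SupervisedEvalConfig(
        validation_urls=["gs://marin-us-central2/evaluation/mmlu-eval-subject-2eb39e/cais/*-validation-evaluation.jsonl.gz"],
        cache_dir="gs://levanter-data/tokenized-gpt2/mmlu/",
        tags=["e"],
    ),
    "arc_easy": SupervisedEvalConfig(
        validation_urls=["gs://marin-us-central2/evaluation/arc-easy-b39e70/allenai/ai2_arc-ARC-Easy-validation-evaluation.jsonl.gz"],
        cache_dir="gs://levanter-data/tokenized-gpt2/arc_easy/",
        tags=["arc", "e"],
    ),
}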
6 changes: 3 additions & 3 deletions examples/sft/sft.py
@@ -15,7 +15,7 @@
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
from levanter.data.text import ChatUrlDataSourceConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
from levanter.main.train_lm import TrainLmConfig
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.trainer import Trainer
@@ -100,7 +100,7 @@ def train(config: SFTConfig):
if config.dataset_type == DatasetType.CHAT_JSONL:
assert config.chat_train_urls is not None
assert config.supervised_data is not None
chat_config = ChatSFTDatasetConfig(
chat_config = ChatUrlDataSourceConfig(
cache_dir=config.supervised_data.cache_dir,
train_urls=config.chat_train_urls, # No validation in this config
messages_field=config.messages_field,
@@ -110,7 +110,7 @@
train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos)
else:
assert config.supervised_data is not None
train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos)
train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos)
logger.info("Supervised dataset created")
train_dataset = PermutationDataset(train_dataset, data_key)

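Two things change in sft.py: the chat config class is now ChatUrlDataSourceConfig, imported from levanter.data.text, and mk_supervised_dataset takes an explicit split name, since a supervised source config can now describe more than one dataset. A hedged sketch of the call site after this change follows; the "train" call is taken from the diff above, while the "validation" variant is only an assumption, shown commented out.

# Supervised path in sft.py after this PR: the split name is passed explicitly.
train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos)
# If the same source config also lists validation_urls, a validation build would
# presumably pass a different split name, e.g.:
# val_dataset = mk_supervised_dataset(config.supervised_data, "validation", tokenizer, model_config.Pos)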
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
"transformers>=4.41.2",
"optax>=0.1.9",
"wandb>=0.17.8",
"draccus>=0.8.0",
"draccus>=0.9.3",
"pyarrow>=11.0.0",
"zstandard>=0.20.0",
"datasets>=3.1.0,<4.0",
7 changes: 3 additions & 4 deletions src/levanter/compat/hf_checkpoints.py
@@ -39,6 +39,7 @@
from levanter.trainer import StepInfo
from levanter.utils import jax_utils
from levanter.utils.cloud_utils import temp_dir_before_upload
from levanter.utils.hf_utils import HfTokenizer
from levanter.utils.jax_utils import best_effort_sharding, local_cpu_mesh, use_cpu_device
from levanter.utils.py_utils import dataclass_with_default_init, logical_cpu_memory_size

@@ -872,7 +873,7 @@ def cb(step: StepInfo):

def arbitrary_load_from_hf(
model_name_or_path, from_pretrained_lambda, revision=None, local_cache_dir=None, trust_remote_code=True
) -> Union[PreTrainedTokenizerBase | ProcessorMixin]:
) -> Union[HfTokenizer | ProcessorMixin]:
is_url_like = urlparse(model_name_or_path).scheme != ""
if is_url_like:
if revision is not None:
@@ -889,9 +890,7 @@ def arbitrary_load_from_hf(
return from_pretrained_lambda(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code)


def load_tokenizer(
model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True
) -> PreTrainedTokenizerBase:
def load_tokenizer(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> HfTokenizer:
"""Like AutoTokenizer.from_pretrained, but works with gs:// paths or anything on fsspec"""
return arbitrary_load_from_hf(
model_name_or_path,
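The change in this file is to type annotations only: arbitrary_load_from_hf and load_tokenizer now advertise the HfTokenizer alias from levanter.utils.hf_utils instead of PreTrainedTokenizerBase. Behavior is unchanged; a minimal usage sketch follows (the gs:// path below is made up for illustration).

from levanter.compat.hf_checkpoints import load_tokenizer

# Plain HF hub names work as with AutoTokenizer.from_pretrained...
tokenizer = load_tokenizer("gpt2")

# ...and so do fsspec paths such as gs:// (this bucket path is hypothetical).
tokenizer = load_tokenizer("gs://my-bucket/checkpoints/my-model", trust_remote_code=True)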
2 changes: 1 addition & 1 deletion src/levanter/data/_preprocessor.py
@@ -197,7 +197,7 @@ def __call__(self, batch):

match transform:
case _MapTransform(fn=fn):
batch = map(fn, batch)
batch = [fn(x) for x in batch]
case _BatchMapTransform(fn=fn):
batch = fn(batch)
is_soa_form = isinstance(batch, dict) or isinstance(batch, pa.RecordBatch)
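
The one-line change in _preprocessor.py swaps a lazy map for a list comprehension. map(fn, batch) returns a lazy, single-use iterator rather than a sequence, which can break downstream code that checks the batch's type or needs to iterate it more than once; the list comprehension materializes the mapped rows up front. A tiny standalone illustration:

def double(x: int) -> int:
    return 2 * x

batch = [1, 2, 3]

lazy = map(double, batch)
print(list(lazy))          # [2, 4, 6]
print(list(lazy))          # [] : the map iterator is exhausted after one pass

materialized = [double(x) for x in batch]
print(list(materialized))  # [2, 4, 6]
print(list(materialized))  # [2, 4, 6] : a list can be re-iterated and has a len()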
112 changes: 48 additions & 64 deletions src/levanter/data/sharded_datasource.py
@@ -184,31 +184,6 @@ def gcs_glob(pattern: str) -> list[str]:
return matching_urls


def datasource_from_chat_jsonl(
(Member Author review comment on this removal: "this just moved to text")
urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant"
) -> "ShardedDataSource[dict]":
"""Creates a ShardedDataSource from JSONL files containing chat messages.

Args:
urls: Sequence of URLs or glob patterns pointing to JSONL files
messages_field: Field name containing the messages in each JSON object
input_role: Role identifier for input messages
output_role: Role identifier for output messages

Returns:
ShardedDataSource configured for chat data
"""
# Expand any glob patterns in the URLs
expanded_urls = []
for url in urls:
if any(c in url for c in "*?[]"):
expanded_urls.extend(gcs_glob(url))
else:
expanded_urls.append(url)

return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role)


def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]:
"""
Create a ShardedDataset from a HuggingFace dataset. Arguments are passed to load_dataset.
@@ -288,14 +263,49 @@ class TextUrlDataSource(ShardedDataSource[str]):

def __init__(self, urls, text_key="text"):
self.urls = urls
self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls)
self.text_key = text_key
self.base_ds = UrlDataSource(urls, columns=[text_key])

@property
def shard_names(self) -> Sequence[str]:
return list(self._shard_name_to_url_mapping.keys())
return self.base_ds.shard_names

def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
url = self.base_ds._shard_name_to_url_mapping[shard_name]
i = 0
compression = "infer"
if url.endswith(".zstd"): # hacky way to detect zstd
compression = "zstd"

format = _sniff_format_for_dataset(url)

# special case for txt files
if format == ".txt":
with fsspec.open(url, "r", compression=compression) as f:
for line in f:
if i >= row:
yield line
i += 1
else:
for doc in self.base_ds.open_shard_at_row(shard_name, row):
yield doc[self.text_key]


class UrlDataSource(ShardedDataSource[dict]):
"""
Dataset for various dict-like formats.
"""

def __init__(self, urls, columns=None):
self.urls = urls
self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls)
self.columns = columns

@property
def shard_names(self) -> Sequence[str]:
return list(self._shard_name_to_url_mapping.keys())

def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
url = self._shard_name_to_url_mapping[shard_name]
i = 0
compression = "infer"
@@ -310,19 +320,18 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
# which is not nothing, but not ideal.
for line in f:
if i >= row:
yield json.loads(line)[self.text_key]
i += 1
case ".txt":
with fsspec.open(url, "r", compression=compression) as f:
for line in f:
if i >= row:
yield line
obj = json.loads(line)
if self.columns:
yield {col: obj[col] for col in self.columns}
i += 1
case ".json":
with fsspec.open(url, "r", compression=compression) as f:
data = json.load(f)
for doc in data[row:]:
yield doc[self.text_key]
if self.columns:
yield {col: doc[col] for col in self.columns}
else:
yield doc
case ".parquet":
with fsspec.open(url, "rb", compression=compression) as f:
parquet_file = pq.ParquetFile(f)
@@ -347,11 +356,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:

# Read from the starting row group onwards
for rg_idx in range(row_group_index, parquet_file.num_row_groups):
table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
table = parquet_file.read_row_group(rg_idx, columns=self.columns)
if rg_idx == row_group_index:
table = table.slice(start_row_in_group)
for record in table.to_pylist():
yield record[self.text_key]
yield record
case _:
raise ValueError(f"Unknown format {format}")

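The refactor above pulls the format sniffing (jsonl, json, parquet) and column projection out of TextUrlDataSource into a generic UrlDataSource that yields dicts; TextUrlDataSource becomes a thin wrapper that delegates to it and extracts text_key, keeping only the .txt special case locally. An illustrative sketch of how the two classes relate after this change (the URLs are made up; the constructors are the ones shown in the diff):

# A dict-level source over structured files, projecting just the columns we need.
docs = UrlDataSource(
    ["gs://my-bucket/data/part-00000.jsonl.gz"],
    columns=["input", "output"],
)
for row in docs.open_shard_at_row(docs.shard_names[0], row=0):
    print(row["input"], row["output"])  # each row is a dict with the requested columns

# TextUrlDataSource now wraps UrlDataSource and pulls out a single text column.
texts = TextUrlDataSource(["gs://my-bucket/data/part-00000.jsonl.gz"], text_key="text")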
@@ -531,32 +540,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
return iter(data[row:])


class ChatJsonlDataSource(JsonlDataSource):
"""DataSource that reads JSONL files containing OpenAI chat format messages."""

def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str):
super().__init__(urls)
self.messages_field = messages_field
self.input_role = input_role
self.output_role = output_role

def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
url = self._shard_name_to_url_mapping[shard_name]
i = 0
with fsspec.open(url, "r", compression="infer") as f:
for line in f:
if i >= row:
data = json.loads(line)
messages = data[self.messages_field]

# Extract input/output from messages
input_msg = next(m["content"] for m in messages if m["role"] == self.input_role)
output_msg = next(m["content"] for m in messages if m["role"] == self.output_role)

yield {"input": input_msg, "output": output_msg}
i += 1

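For reference, the removed ChatJsonlDataSource (which presumably moved to levanter.data.text along with datasource_from_chat_jsonl, per the author's comment above) reads OpenAI-style chat records and flattens each one to an input/output pair. With the default field names, a single JSONL record and its result look like this (the example content is made up):

record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}
# ChatJsonlDataSource(urls, messages_field="messages", input_role="user", output_role="assistant")
# would yield: {"input": "What is 2 + 2?", "output": "4"}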

class ParquetDataSource(ShardedDataSource[dict]):
def __init__(self, urls):
self.urls = urls
@@ -650,7 +633,8 @@ def shard_names(self) -> Sequence[str]:
return self.source.shard_names

def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[T]:
return map(self.fn, self.source.open_shard_at_row(shard_name, row))
for doc in self.source.open_shard_at_row(shard_name, row):
yield self.fn(doc)


class _BatchMappedShardedDataSource(ShardedDataSource[T], _TransformedDataset):