From 812accb4f5b29f6f0bda84bcf87a4d4f7c538091 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Wed, 6 Nov 2024 10:08:51 -0800
Subject: [PATCH] load data from marin sources

---
 examples/sft/sft.py                     | 45 ++++++++----
 examples/sft/tulu-llama-sft.yaml        | 51 ++++++++++++++
 src/levanter/data/sharded_datasource.py | 91 +++++++++++++++++++++++++
 src/levanter/data/text.py               | 78 +++++++++++++++++++++
 4 files changed, 253 insertions(+), 12 deletions(-)
 create mode 100644 examples/sft/tulu-llama-sft.yaml

diff --git a/examples/sft/sft.py b/examples/sft/sft.py
index 9813184b9..2ced8591c 100644
--- a/examples/sft/sft.py
+++ b/examples/sft/sft.py
@@ -1,6 +1,8 @@
 import logging
 import os
 from dataclasses import dataclass
+from enum import Enum
+from typing import List, Optional
 
 import jax.random as jrandom
 import transformers
@@ -13,11 +15,10 @@
 from levanter import callbacks
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
 from levanter.data import PermutationDataset
-from levanter.data.text import EpochDataset, mk_supervised_dataset
+from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
 from levanter.main.train_lm import TrainLmConfig
 from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
 from levanter.trainer import Trainer
-from levanter.utils.py_utils import non_caching_cycle
 
 logger = logging.getLogger(__name__)
 
@@ -29,12 +30,26 @@ DEFAULT_UNK_TOKEN = "<unk>"
 
 
 
+class DatasetType(str, Enum):
+    """Type of dataset to use"""
+
+    HUGGINGFACE = "huggingface"  # Use HF dataset
+    CHAT_JSONL = "chat_jsonl"  # Use JSONL files with chat format
+
+
 @dataclass
 class SFTConfig(TrainLmConfig):
     # inherit most of the config from TrainLmConfig
-    max_tune_length: int = 2048  # maximum length of the input to the model during tuning
+    max_tune_length: int = 2048
     model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
-    tokenizer: str = "meta-llama/Llama-2-7b-hf"  # Tokenizer to use
+    tokenizer: str = "meta-llama/Llama-2-7b-hf"
+
+    # Add dataset type and chat-specific fields
+    dataset_type: DatasetType = DatasetType.HUGGINGFACE
+    chat_train_urls: Optional[List[str]] = None
+    messages_field: str = "messages"
+    input_role: str = "user"
+    output_role: str = "assistant"
 
 
 def train(config: SFTConfig):
@@ -79,19 +94,26 @@ def train(config: SFTConfig):
         logger.info(f"Overriding data seed with {config.data_seed}")
         data_key = jrandom.PRNGKey(config.data_seed)
 
-    # Configure supervised dataset
-    supervised_config = config.supervised_data
-
     # Create supervised dataset using generic machinery
     logger.info("Creating supervised dataset")
-    train_dataset = mk_supervised_dataset(supervised_config, tokenizer)
+    if config.dataset_type == DatasetType.CHAT_JSONL:
+        chat_config = ChatSFTDatasetConfig(
+            cache_dir=config.supervised_data.cache_dir,
+            train_urls=config.chat_train_urls,  # No validation in this config
+            messages_field=config.messages_field,
+            input_role=config.input_role,
+            output_role=config.output_role,
+        )
+        train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
+    else:
+        train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
     logger.info("Supervised dataset created")
     train_dataset = PermutationDataset(train_dataset, data_key)
 
     # Then wrap for epochs
-    # if config.epoch > 0:
-    #     logger.info(f"Wrapping dataset for {config.epoch} epochs")
-    #     train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)
+    if config.epoch > 0:
+        logger.info(f"Wrapping dataset for {config.epoch} epochs")
+        train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)
 
     logger.info("Creating optimizer")
     optimizer = config.optimizer.build(config.trainer.num_train_steps)
@@ -134,7 +156,6 @@ def train(config: SFTConfig):
         )
 
         loader = trainer.data_loader(train_dataset, trainer.TrainBatch)
-        loader = non_caching_cycle(loader)
 
         if int(state.step) != 0:
             logger.info(f"Resuming training from step {state.step}")
diff --git a/examples/sft/tulu-llama-sft.yaml b/examples/sft/tulu-llama-sft.yaml
new file mode 100644
index 000000000..6086e624d
--- /dev/null
+++ b/examples/sft/tulu-llama-sft.yaml
@@ -0,0 +1,51 @@
+# Model configuration
+model:
+  type: llama
+  seq_len: 2048
+  hidden_dim: 4096
+  intermediate_dim: 11008
+  num_layers: 32
+  num_heads: 32
+  num_kv_heads: 32
+  use_flash_attention: true
+  flash_attention_block_size: 512
+  use_bias: false
+  use_layer_norm_weight: false
+
+# Training configuration
+trainer:
+  mp: p=f32,c=bfloat16
+  tracker:
+    type: wandb
+    project: "levanter-sft"
+    tags: ["llama", "sft"]
+  num_train_steps: 750000
+  train_batch_size: 64
+  tensor_parallel_axes: ["mlp", "heads"]
+  fsdp_axis: "embed"
+  batch_axis: "batch"
+  steps_per_eval: 1000
+
+# Optimizer settings
+optimizer:
+  learning_rate: 2e-5
+  weight_decay: 0.0
+  min_lr_ratio: 0.1
+  warmup: 100
+
+# Supervised data configuration
+dataset_type: chat_jsonl
+chat_train_urls:
+  - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz"
+supervised_data:
+  cache_dir: "gs://levanter-checkpoints/marin/sft_cache/chat-data"
+messages_field: "messages"
+input_role: "user"
+output_role: "assistant"
+
+# Additional settings
+tokenizer: "EleutherAI/gpt-neox-20b"
+max_tune_length: 2048
+epoch: 0
+
+initialize_from_hf: false
diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 186a0d9dd..333ddf768 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -1,6 +1,7 @@
 import io
 import json
 import os
+import re
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -16,11 +17,13 @@
     Tuple,
     TypeVar,
 )
+from urllib.parse import urlparse
 
 import datasets
 import fsspec
 import numpy as np
 import pyarrow.parquet as pq
+from google.cloud import storage
 
 from levanter.utils import fsspec_utils
 
@@ -144,6 +147,68 @@ def map_batches(
         )
 
 
+def gcs_glob(pattern: str) -> list[str]:
+    """Glob files in Google Cloud Storage.
+
+    Args:
+        pattern: GCS path pattern (gs://bucket/path/*)
+
+    Returns:
+        List of matching GCS URLs
+    """
+    if not pattern.startswith("gs://"):
+        # Handle local files
+        import glob
+
+        return glob.glob(pattern)
+
+    # Parse bucket and prefix from gs:// URL
+    parsed = urlparse(pattern)
+    bucket_name = parsed.netloc
+    prefix = parsed.path.lstrip("/")
+
+    # Convert glob pattern to regex
+    prefix_no_glob = prefix.split("*")[0]
+    pattern_as_regex = re.compile(re.escape(prefix).replace("\\*", ".*"))
+
+    # Initialize GCS client
+    client = storage.Client()
+    bucket = client.bucket(bucket_name)
+
+    # List matching blobs
+    matching_urls = []
+    for blob in bucket.list_blobs(prefix=prefix_no_glob):
+        if pattern_as_regex.match(blob.name):
+            matching_urls.append(f"gs://{bucket_name}/{blob.name}")
+
+    return matching_urls
+
+
+def datasource_from_chat_jsonl(
+    urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant"
+) -> "ShardedDataSource[dict]":
+    """Creates a ShardedDataSource from JSONL files containing chat messages.
+
+    Args:
+        urls: Sequence of URLs or glob patterns pointing to JSONL files
+        messages_field: Field name containing the messages in each JSON object
+        input_role: Role identifier for input messages
+        output_role: Role identifier for output messages
+
+    Returns:
+        ShardedDataSource configured for chat data
+    """
+    # Expand any glob patterns in the URLs
+    expanded_urls = []
+    for url in urls:
+        if any(c in url for c in "*?[]"):
+            expanded_urls.extend(gcs_glob(url))
+        else:
+            expanded_urls.append(url)
+
+    return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role)
+
+
 def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]:
     """
     Create a ShardedDataset from a HuggingFace dataset. Arguments are passed to load_dataset.
@@ -463,6 +528,32 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         return iter(data[row:])
 
 
+class ChatJsonlDataSource(JsonlDataSource):
+    """DataSource that reads JSONL files containing OpenAI chat format messages."""
+
+    def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str):
+        super().__init__(urls)
+        self.messages_field = messages_field
+        self.input_role = input_role
+        self.output_role = output_role
+
+    def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
+        url = self._shard_name_to_url_mapping[shard_name]
+        i = 0
+        with fsspec.open(url, "r", compression="infer") as f:
+            for line in f:
+                if i >= row:
+                    data = json.loads(line)
+                    messages = data[self.messages_field]
+
+                    # Extract input/output from messages
+                    input_msg = next(m["content"] for m in messages if m["role"] == self.input_role)
+                    output_msg = next(m["content"] for m in messages if m["role"] == self.output_role)
+
+                    yield {"input": input_msg, "output": output_msg}
+                i += 1
+
+
 class ParquetDataSource(ShardedDataSource[dict]):
     def __init__(self, urls):
         self.urls = urls
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 0181889d9..b42bcf5f6 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -743,6 +743,84 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
     return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
 
 
+@dataclass
+class ChatSFTDatasetConfig(LMSupervisedDatasetConfig):
+    """Config for loading JSONL files in OpenAI chat format for supervised fine-tuning."""
+
+    # Chat format specific fields
+    messages_field: str = "messages"
+    input_role: str = "user"
+    output_role: str = "assistant"
+    train_urls: List[str] = field(default_factory=list)
+
+    def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]:
+        """Gets ShardedDataSource for either training or validation data."""
+        import levanter.data
+
+        urls = self.validation_urls if split == "validation" else self.train_urls
+
+        if not urls:
+            return None
+
+        # Use the datasource_from_chat_jsonl function from sharded_datasource
+        return levanter.data.sharded_datasource.datasource_from_chat_jsonl(
+            urls, messages_field=self.messages_field, input_role=self.input_role, output_role=self.output_role
+        )
+
+
+def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict:
+    """
+    Preprocess chat examples to match the format of preprocess_supervised_example.
+    Returns a dict with input_ids and sources_len like the supervised case.
+    """
+    # Get sources (inputs) and targets (outputs) from the batch
+    sources = [example["input"] for example in batch]
+    targets = [example["output"] for example in batch]
+
+    # Tokenize sources alone first to get the source lengths
+    sources_tokenized = tokenizer(sources, padding=False, truncation=True)
+
+    # Combine source and target for full examples
+    full_examples = [f"{s}{t}" for s, t in zip(sources, targets)]
+    examples_tokenized = tokenizer(full_examples, padding=False, truncation=True)
+
+    # Get source lengths to mask loss appropriately
+    source_lens = [len(s) for s in sources_tokenized["input_ids"]]
+
+    return {
+        "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]],
+        "sources_len": np.array(source_lens, dtype=np.int32),
+    }
+
+
+def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]:
+    """Creates a dataset from JSONL files containing chat format data for SFT."""
+    source = config.get_shard_source("train")
+    if source is None:
+        raise ValueError("No training data source found")
+
+    # Set up example structure matching supervised case
+    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}
+
+    # Process the dataset
+    dataset = source.map_batches(
+        lambda ex: preprocess_chat_example(ex, tokenizer),
+        batch_size=128,
+        num_cpus=num_cpus_used_by_tokenizer(tokenizer),
+        output_exemplar=output_exemplar,
+    )
+
+    # Cache the processed data
+    dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
+
+    # Ensure padding token is set (needed by _prepare_supervised_example)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Reuse the supervised prepare function directly
+    return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+
+
 @dataclass
 class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
     """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
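
Usage sketch for reviewers. Each line of a chat JSONL shard is expected to be one JSON object with a "messages" list in OpenAI chat format; ChatJsonlDataSource reduces each record to an input/output pair, and preprocess_chat_example tokenizes the concatenation while recording the prompt length in sources_len so the loss can be masked over the prompt. The snippet below is illustrative only: the local path is made up, and shard_names comes from the existing ShardedDataSource base class rather than from this patch.

    # Each line of a chat JSONL shard is one JSON object, e.g.:
    # {"messages": [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "4"}]}
    from levanter.data.sharded_datasource import datasource_from_chat_jsonl

    # Hypothetical local file; gs:// URLs and glob patterns are expanded via gcs_glob.
    source = datasource_from_chat_jsonl(["/tmp/chat.jsonl"])
    for example in source.open_shard_at_row(source.shard_names[0], 0):
        print(example)  # {'input': '2+2?', 'output': '4'}

With the new YAML, training would presumably be launched like the other SFT examples, e.g. python examples/sft/sft.py --config_path examples/sft/tulu-llama-sft.yaml, assuming the script's entry point is wrapped with levanter.config.main as in the existing examples.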