load data from marin sources
ahmeda14960 committed Nov 6, 2024
1 parent ba682ca commit 812accb
Showing 4 changed files with 253 additions and 12 deletions.
45 changes: 33 additions & 12 deletions examples/sft/sft.py
@@ -1,6 +1,8 @@
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

import jax.random as jrandom
import transformers
@@ -13,11 +15,10 @@
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
from levanter.data.text import EpochDataset, mk_supervised_dataset
from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
from levanter.main.train_lm import TrainLmConfig
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.trainer import Trainer
from levanter.utils.py_utils import non_caching_cycle


logger = logging.getLogger(__name__)
@@ -29,12 +30,26 @@
DEFAULT_UNK_TOKEN = "<unk>"


class DatasetType(str, Enum):
"""Type of dataset to use"""

HUGGINGFACE = "huggingface" # Use HF dataset
CHAT_JSONL = "chat_jsonl" # Use JSONL files with chat format


@dataclass
class SFTConfig(TrainLmConfig):
# inherit most of the config from TrainLmConfig
max_tune_length: int = 2048 # maximum length of the input to the model during tuning
max_tune_length: int = 2048
model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
tokenizer: str = "meta-llama/Llama-2-7b-hf" # Tokenizer to use
tokenizer: str = "meta-llama/Llama-2-7b-hf"

# Add dataset type and chat-specific fields
dataset_type: DatasetType = DatasetType.HUGGINGFACE
chat_train_urls: Optional[List[str]] = None
messages_field: str = "messages"
input_role: str = "user"
output_role: str = "assistant"


def train(config: SFTConfig):
@@ -79,19 +94,26 @@ def train(config: SFTConfig):
logger.info(f"Overriding data seed with {config.data_seed}")
data_key = jrandom.PRNGKey(config.data_seed)

# Configure supervised dataset
supervised_config = config.supervised_data

# Create supervised dataset using generic machinery
logger.info("Creating supervised dataset")
train_dataset = mk_supervised_dataset(supervised_config, tokenizer)
if config.dataset_type == DatasetType.CHAT_JSONL:
chat_config = ChatSFTDatasetConfig(
cache_dir=config.supervised_data.cache_dir,
train_urls=config.chat_train_urls, # No validation in this config
messages_field=config.messages_field,
input_role=config.input_role,
output_role=config.output_role,
)
train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
else:
train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
logger.info("Supervised dataset created")
train_dataset = PermutationDataset(train_dataset, data_key)

# Then wrap for epochs
# if config.epoch > 0:
# logger.info(f"Wrapping dataset for {config.epoch} epochs")
# train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)
if config.epoch > 0:
logger.info(f"Wrapping dataset for {config.epoch} epochs")
train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)

logger.info("Creating optimizer")
optimizer = config.optimizer.build(config.trainer.num_train_steps)
@@ -134,7 +156,6 @@ def train(config: SFTConfig):
)

loader = trainer.data_loader(train_dataset, trainer.TrainBatch)
loader = non_caching_cycle(loader)

if int(state.step) != 0:
logger.info(f"Resuming training from step {state.step}")
51 changes: 51 additions & 0 deletions examples/sft/tulu-llama-sft.yaml
@@ -0,0 +1,51 @@
# Model configuration
model:
type: llama
seq_len: 2048
hidden_dim: 4096
intermediate_dim: 11008
num_layers: 32
num_heads: 32
num_kv_heads: 32
use_flash_attention: true
flash_attention_block_size: 512
use_bias: false
use_layer_norm_weight: false

# Training configuration
trainer:
mp: p=f32,c=bfloat16
tracker:
type: wandb
project: "levanter-sft"
tags: ["llama", "sft"]
num_train_steps: 750000
train_batch_size: 64
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
batch_axis: "batch"
steps_per_eval: 1000

# Optimizer settings
optimizer:
learning_rate: 2e-5
weight_decay: 0.0
min_lr_ratio: 0.1
warmup: 100

# Supervised data configuration
dataset_type: chat_jsonl
chat_train_urls:
- "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz"
supervised_data:
cache_dir: "gs://levanter-checkpoints/marin/sft_cache/chat-data"
messages_field: "messages"
input_role: "user"
output_role: "assistant"

# Additional settings
tokenizer: "EleutherAI/gpt-neox-20b"
max_tune_length: 2048
epoch: 0

initialize_from_hf: false
91 changes: 91 additions & 0 deletions src/levanter/data/sharded_datasource.py
@@ -1,6 +1,7 @@
import io
import json
import os
import re
import warnings
from typing import (
TYPE_CHECKING,
@@ -16,11 +17,13 @@
Tuple,
TypeVar,
)
from urllib.parse import urlparse

import datasets
import fsspec
import numpy as np
import pyarrow.parquet as pq
from google.cloud import storage

from levanter.utils import fsspec_utils

@@ -144,6 +147,68 @@ def map_batches(
)


def gcs_glob(pattern: str) -> list[str]:
"""Glob files in Google Cloud Storage.
Args:
pattern: GCS path pattern (gs://bucket/path/*)
Returns:
List of matching GCS URLs
"""
if not pattern.startswith("gs://"):
# Handle local files
import glob

return glob.glob(pattern)

# Parse bucket and prefix from gs:// URL
parsed = urlparse(pattern)
bucket_name = parsed.netloc
prefix = parsed.path.lstrip("/")

# Convert glob pattern to regex
prefix_no_glob = prefix.split("*")[0]
pattern_as_regex = re.compile(re.escape(prefix).replace("\\*", ".*"))

# Initialize GCS client
client = storage.Client()
bucket = client.bucket(bucket_name)

# List matching blobs
matching_urls = []
for blob in bucket.list_blobs(prefix=prefix_no_glob):
if pattern_as_regex.match(blob.name):
matching_urls.append(f"gs://{bucket_name}/{blob.name}")

return matching_urls
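
A quick usage sketch of gcs_glob (the bucket and paths are placeholders, and listing a gs:// pattern assumes Google Cloud credentials are available in the environment):

shards = gcs_glob("gs://my-bucket/sft/chat-data/*.jsonl.gz")  # expanded by listing the bucket
local_files = gcs_glob("/tmp/chat-data/*.jsonl")              # non-gs:// patterns fall back to glob.glob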


def datasource_from_chat_jsonl(
urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant"
) -> "ShardedDataSource[dict]":
"""Creates a ShardedDataSource from JSONL files containing chat messages.
Args:
urls: Sequence of URLs or glob patterns pointing to JSONL files
messages_field: Field name containing the messages in each JSON object
input_role: Role identifier for input messages
output_role: Role identifier for output messages
Returns:
ShardedDataSource configured for chat data
"""
# Expand any glob patterns in the URLs
expanded_urls = []
for url in urls:
if any(c in url for c in "*?[]"):
expanded_urls.extend(gcs_glob(url))
else:
expanded_urls.append(url)

return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role)
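
A minimal usage sketch with placeholder paths; shard_names and open_shard_at_row come from the ShardedDataSource interface, so the access pattern here is only illustrative:

source = datasource_from_chat_jsonl(
    ["gs://my-bucket/chat-data/*.jsonl.gz"],  # placeholder glob, expanded via gcs_glob
    messages_field="messages",
    input_role="user",
    output_role="assistant",
)
example = next(source.open_shard_at_row(source.shard_names[0], 0))
# example looks like {"input": "<user message>", "output": "<assistant message>"}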


def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]:
"""
Create a ShardedDataset from a HuggingFace dataset. Arguments are passed to load_dataset.
@@ -463,6 +528,32 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
return iter(data[row:])


class ChatJsonlDataSource(JsonlDataSource):
"""DataSource that reads JSONL files containing OpenAI chat format messages."""

def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str):
super().__init__(urls)
self.messages_field = messages_field
self.input_role = input_role
self.output_role = output_role

def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
url = self._shard_name_to_url_mapping[shard_name]
i = 0
with fsspec.open(url, "r", compression="infer") as f:
for line in f:
if i >= row:
data = json.loads(line)
messages = data[self.messages_field]

# Extract input/output from messages
input_msg = next(m["content"] for m in messages if m["role"] == self.input_role)
output_msg = next(m["content"] for m in messages if m["role"] == self.output_role)

yield {"input": input_msg, "output": output_msg}
i += 1


class ParquetDataSource(ShardedDataSource[dict]):
def __init__(self, urls):
self.urls = urls
78 changes: 78 additions & 0 deletions src/levanter/data/text.py
@@ -743,6 +743,84 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))


@dataclass
class ChatSFTDatasetConfig(LMSupervisedDatasetConfig):
"""Config for loading JSONL files in OpenAI chat format for supervised fine-tuning."""

# Chat format specific fields
messages_field: str = "messages"
input_role: str = "user"
output_role: str = "assistant"
    train_urls: List[str] = field(default_factory=list)  # URLs or glob patterns for the training JSONL shards

def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]:
        """Gets the ShardedDataSource for either training or validation data."""
        import levanter.data

urls = self.validation_urls if split == "validation" else self.train_urls

if not urls:
return None

# Use the datasource_from_chat_jsonl function from sharded_datasource
return levanter.data.sharded_datasource.datasource_from_chat_jsonl(
urls, messages_field=self.messages_field, input_role=self.input_role, output_role=self.output_role
)
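
For illustration, the config can also be constructed directly in Python; the cache directory and URLs below are placeholders, and cache_dir/validation_urls are inherited from LMSupervisedDatasetConfig:

chat_config = ChatSFTDatasetConfig(
    cache_dir="gs://my-bucket/sft_cache/chat-data",      # placeholder
    train_urls=["gs://my-bucket/chat-data/*.jsonl.gz"],  # placeholder glob
    messages_field="messages",
    input_role="user",
    output_role="assistant",
)
source = chat_config.get_shard_source("train")  # ShardedDataSource[dict], or None if no URLs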


def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict:
"""
Preprocess chat examples to match the format of preprocess_supervised_example.
Returns a dict with input_ids and sources_len like the supervised case.
"""
# Get sources (inputs) and targets (outputs) from the batch
sources = [example["input"] for example in batch]
targets = [example["output"] for example in batch]

# Tokenize sources alone first to get the source lengths
sources_tokenized = tokenizer(sources, padding=False, truncation=True)

# Combine source and target for full examples
full_examples = [f"{s}{t}" for s, t in zip(sources, targets)]
examples_tokenized = tokenizer(full_examples, padding=False, truncation=True)

# Get source lengths to mask loss appropriately
source_lens = [len(s) for s in sources_tokenized["input_ids"]]

return {
"input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]],
"sources_len": np.array(source_lens, dtype=np.int32),
}
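
A toy example of what this produces, assuming a HF tokenizer can be loaded (the checkpoint name is only an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # example tokenizer
batch = [
    {"input": "What is 2 + 2?", "output": " 4"},
    {"input": "Name a prime number.", "output": " 7"},
]
out = preprocess_chat_example(batch, tok)
# out["input_ids"][i] is the tokenized prompt+response for example i (np.int32);
# out["sources_len"][i] is the number of prompt tokens, later used to mask the loss
# so that only the response tokens contribute.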


def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]:
"""Creates a dataset from JSONL files containing chat format data for SFT."""
source = config.get_shard_source("train")
if source is None:
raise ValueError("No training data source found")

# Set up example structure matching supervised case
output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}

# Process the dataset
dataset = source.map_batches(
lambda ex: preprocess_chat_example(ex, tokenizer),
batch_size=128,
num_cpus=num_cpus_used_by_tokenizer(tokenizer),
output_exemplar=output_exemplar,
)

# Cache the processed data
dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)

# Ensure padding token is set (needed by _prepare_supervised_example)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# Reuse the supervised prepare function directly
return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
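
Putting it together, a hedged end-to-end sketch (reusing a config like the one sketched above; the tokenizer checkpoint is only an example, and building the cache writes under cache_dir):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # example checkpoint
train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)  # AsyncDataset[LmExample]
# Tokenization runs through source.map_batches, results are cached in chat_config.cache_dir,
# and each LmExample masks the loss so that only the response tokens are scored.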


@dataclass
class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
"""This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
