
Commit e7600d0
more chat
jquesnelle committed Aug 9, 2024
1 parent 632ddaa commit e7600d0
Showing 5 changed files with 157 additions and 7 deletions.
28 changes: 27 additions & 1 deletion run_train.py
@@ -13,7 +13,9 @@
import numpy as np
from nanotron import logging
from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron.data.dataloader_builder import build_nanoset_dataloader
from nanotron.config.config import ChatDatasetsArgs
from nanotron.data.chat_dataset import ChatDataset
from nanotron.data.dataloader_builder import build_chat_dataloader, build_nanoset_dataloader
from nanotron.dataloader import (
clm_process,
dummy_infinite_data_generator,
@@ -172,6 +174,30 @@ def get_dataloader_from_data_stage(
)

return train_dataloader

# Case 4: Chat Datasets
elif isinstance(data.dataset, ChatDatasetsArgs):
with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = ChatDataset(
dataset_path=data.dataset.hf_dataset,
tokenizer_name_or_path=trainer.config.tokenizer.tokenizer_name_or_path,
sequence_length=trainer.sequence_length,
train_on_completions_only=data.dataset.train_on_completions_only,
remove_cross_attention=data.dataset.remove_cross_attention,
pack_samples=data.dataset.pack_samples,
split=data.dataset.hf_dataset_split,
conversation_column_name=data.dataset.conversation_column_name,
dp_rank=trainer.parallel_context.dp_pg.rank(),
dp_ranks_size=trainer.parallel_context.dp_pg.size(),
)

# Prepare dataloader
train_dataloader = build_chat_dataloader(
dataset=train_dataset,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
)
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")

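A hedged Python equivalent (an assumption for illustration, not part of this commit) of a data-stage configuration that routes training into the new "Case 4: Chat Datasets" branch above. The DataArgs field names (dataset, seed, num_loading_workers) follow upstream nanotron and are assumed here; the dataset id and column name are placeholders.

from nanotron.config import DataArgs
from nanotron.config.config import ChatDatasetsArgs

data = DataArgs(
    dataset=ChatDatasetsArgs(
        hf_dataset="Open-Orca/SlimOrca",           # hypothetical HF dataset id
        hf_dataset_split="train",
        conversation_column_name="conversations",
        train_on_completions_only=True,
        remove_cross_attention=True,
        pack_samples=True,
    ),
    seed=42,
    num_loading_workers=1,
)
# isinstance(data.dataset, ChatDatasetsArgs) is True, so get_dataloader_from_data_stage
# builds a ChatDataset and wraps it with build_chat_dataloader.
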
16 changes: 16 additions & 0 deletions src/nanotron/config/config.py
@@ -107,6 +107,22 @@ def __post_init__(self):
self.dataset_weights = list(tmp_dataset_folder.values())


@dataclass
class ChatDatasetsArgs:
hf_dataset: str
hf_dataset_split: str
conversation_column_name: str
# Debug
train_on_completions_only: bool = True
remove_cross_attention: bool = True
pack_samples: bool = True

def __post_init__(self):
if self.hf_dataset_split is None:
self.hf_dataset_split = "train"
if self.conversation_column_name is None:
self.conversation_column_name = "conversations"

@dataclass
class DataArgs:
"""Arguments related to the data and data files processing"""
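A small sketch (not part of the commit) of the __post_init__ fallbacks above: passing None for the split or the conversation column falls back to "train" and "conversations". The dataset id is a hypothetical placeholder.

from nanotron.config.config import ChatDatasetsArgs

args = ChatDatasetsArgs(
    hf_dataset="Open-Orca/SlimOrca",   # hypothetical dataset id
    hf_dataset_split=None,
    conversation_column_name=None,
)
assert args.hf_dataset_split == "train"
assert args.conversation_column_name == "conversations"
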
13 changes: 8 additions & 5 deletions src/nanotron/data/chat_dataset.py
@@ -40,6 +40,7 @@ def __init__(
conversation_column_name: str,
train_on_completions_only: bool = True,
remove_cross_attention: bool = True,
pack_samples: bool = True,
split: str = "train",
dp_rank: int = 0,
dp_ranks_size: int = 1,
@@ -58,6 +59,7 @@ def __init__(
self.conversation_column_name = conversation_column_name
self.skip_num_samples = skip_num_samples
self.seed = seed
self.pack_samples = pack_samples

# Load, split and shuffle dataset
self.dataset = load_dataset(dataset_path, split=split, streaming=True)
@@ -100,11 +102,12 @@ def __iter__(self):
buffer_is_completition.extend(is_completition)
buffer_lengths.append(len(tokens))

if len(buffer_tokens) > max_buffer_token_len: # Can't pack more samples, yield
# Pop last sample from buffers
sample_tokens = buffer_tokens[: -len(tokens)]
sample_completitions = buffer_is_completition[: -len(tokens)]
sample_lengths = buffer_lengths[:-1]
if len(buffer_tokens) > max_buffer_token_len or not self.pack_samples: # Can't pack more samples, yield
if self.pack_samples:
# Pop last sample from buffers
sample_tokens = buffer_tokens[: -len(tokens)]
sample_completitions = buffer_is_completition[: -len(tokens)]
sample_lengths = buffer_lengths[:-1]

# TODO(tj.solergibert) Delete (debug)
assert len(sample_tokens) == len(sample_completitions) == sum(sample_lengths)
75 changes: 75 additions & 0 deletions src/nanotron/data/collator.py
@@ -78,3 +78,78 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni
)

return result

# TODO(tj.solergibert) After "Beta", delete all the functs except `build_position_ids` and move `build_position_ids` to chat_dataset.py
def build_position_ids(lengths, sequence_length) -> np.array:
position_ids = [list(range(length)) for length in lengths] # Create position ids list
return np.array([x for xs in position_ids for x in xs], dtype=np.int32) # Flatten list of position ids


# TODO(tj.solergibert) Delete (debug), just 4 switching the remove cross-attention setting
def build_position_ids_dummy(lengths, sequence_length) -> np.array:
return np.array(list(range(sum(lengths))), dtype=np.int32) # TODO numpy arange
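A small illustration (an assumption for exposition, not part of the commit) of the two helpers above: with packing-aware position ids, positions restart at 0 for every packed sample, while the dummy variant numbers the whole packed sequence continuously.

lengths = [3, 2]
build_position_ids(lengths, sequence_length=5)        # array([0, 1, 2, 0, 1], dtype=int32)
build_position_ids_dummy(lengths, sequence_length=5)  # array([0, 1, 2, 3, 4], dtype=int32)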


# TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting.
def build_labels_completions_only(input_ids, is_completitions):
return is_completitions


# TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting
def build_labels(input_ids, is_completitions):
return [True for _ in range(len(is_completitions))]


@dataclasses.dataclass
class DataCollatorForSFT:
"""
Data collator used with Chat Dataset.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""

input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext

def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.

current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"position_ids": TensorPointer(group_rank=self.input_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}

input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
is_completitions = np.vstack([examples[i]["is_completitions"] for i in range(len(examples))]) # (b, s)
position_ids = np.vstack([examples[i]["position_ids"] for i in range(len(examples))]) # (b, s)

result: Dict[str, Union[np.ndarray, TensorPointer]] = {}

result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["position_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)

# Process inputs
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["position_ids"] = position_ids[:, :-1]

# Process labels: shift them to the left.
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = is_completitions[:, 1:]

# Cast np.array to torch.Tensor
result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()}
return result
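
A minimal numpy illustration (an assumption for exposition, not using the class itself, which requires a ParallelContext) of the shift performed by DataCollatorForSFT: the input rank drops the last token, the output rank drops the first token, so label_ids[t] is the token that follows input_ids[t], and the mask restricts the loss to completion tokens.

import numpy as np

input_ids = np.array([[10, 11, 12, 13]])                # (b=1, s=4) packed token ids
is_completitions = np.array([[False, True, True, True]])  # True where the assistant speaks

inputs = input_ids[:, :-1]            # [[10, 11, 12]]  fed to the model
label_ids = input_ids[:, 1:]          # [[11, 12, 13]]  next-token targets
label_mask = is_completitions[:, 1:]  # [[True, True, True]]  loss only on completion tokens
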
32 changes: 31 additions & 1 deletion src/nanotron/data/dataloader_builder.py
@@ -1,6 +1,6 @@
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.data.collator import DataCollatorForSFT, NanosetDataCollatorForCLM
from nanotron.dataloader import (
EmptyInfiniteDataset,
get_dataloader_worker_init,
@@ -62,3 +62,33 @@ def build_nanoset_dataloader(
pin_memory=dataloader_pin_memory,
worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank),
)

def build_chat_dataloader(
dataset,
parallel_context: ParallelContext,
input_pp_rank: int,
output_pp_rank: int,
dataloader_pin_memory: bool = True,
) -> DataLoader:

# Case of ranks not requiring data. We give them a dummy dataset, then the collator will do his job
if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]:
dataset_length = 1_000_000 # len(dataset) TODO find a more elegant way to specify this dummy dataset
dataset = EmptyInfiniteDataset(length=dataset_length)

data_collator = DataCollatorForSFT(
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
)

dp_rank = parallel_context.dp_pg.rank()

return DataLoader(
dataset,
batch_size=1,
collate_fn=data_collator,
num_workers=0,
pin_memory=dataloader_pin_memory,
worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank),
)
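
A hedged usage sketch showing how the pieces added in this commit might be wired together in a training script. It assumes a ParallelContext has already been initialised (as run_train.py does); the dataset id and tokenizer name are hypothetical placeholders, and only the module paths and function names shown in this commit are taken from it.

from nanotron.data.chat_dataset import ChatDataset
from nanotron.data.dataloader_builder import build_chat_dataloader

train_dataset = ChatDataset(
    dataset_path="Open-Orca/SlimOrca",                     # hypothetical HF dataset id
    tokenizer_name_or_path="meta-llama/Meta-Llama-3-8B",   # hypothetical tokenizer
    sequence_length=2048,
    conversation_column_name="conversations",
    dp_rank=parallel_context.dp_pg.rank(),
    dp_ranks_size=parallel_context.dp_pg.size(),
)

dataloader = build_chat_dataloader(
    dataset=train_dataset,
    parallel_context=parallel_context,
    input_pp_rank=0,
    output_pp_rank=0,
)

batch = next(iter(dataloader))
# On the input/output pipeline ranks, `batch` is a dict with "input_ids", "position_ids",
# "label_ids" and "label_mask" tensors (batch_size=1); other ranks receive TensorPointers.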
