
Commit

Enable packed dataset for validation; add a2a_experimental argument (#11378)

* Enable packed dataset for validation; add a2a_experimental argument

* Apply isort and black reformatting

Signed-off-by: michal2409 <[email protected]>

---------

Signed-off-by: michal2409 <[email protected]>
Co-authored-by: michal2409 <[email protected]>
michal2409 and michal2409 authored Nov 23, 2024
1 parent 9d80f84 commit e83d3ea
Showing 3 changed files with 63 additions and 19 deletions.
49 changes: 37 additions & 12 deletions nemo/collections/llm/gpt/data/fine_tuning.py
@@ -117,17 +117,28 @@ def prepare_data(self) -> None:
         """
         Prepare packed sequence data
         """
-        if self.packed_sequence_size > 0 and not self.train_path_packed.is_file():
+        if self.packed_sequence_size > 0:
             from nemo.collections.llm.gpt.data.packed_sequence import prepare_packed_sequence_data
 
-            prepare_packed_sequence_data(
-                input_path=self.train_path,
-                output_path=self.train_path_packed,
-                packed_sequence_size=self.packed_sequence_size,
-                tokenizer=self.tokenizer,
-                max_seq_length=self.seq_length,
-                seed=self.seed,
-            )
+            if not self.train_path_packed.is_file():
+                prepare_packed_sequence_data(
+                    input_path=self.train_path,
+                    output_path=self.train_path_packed,
+                    packed_sequence_size=self.packed_sequence_size,
+                    tokenizer=self.tokenizer,
+                    max_seq_length=self.seq_length,
+                    seed=self.seed,
+                )
+
+            if not self.validation_path_packed.is_file():
+                prepare_packed_sequence_data(
+                    input_path=self.validation_path,
+                    output_path=self.validation_path_packed,
+                    packed_sequence_size=self.packed_sequence_size,
+                    tokenizer=self.tokenizer,
+                    max_seq_length=self.seq_length,
+                    seed=self.seed,
+                )
 
     def setup(self, stage: str):
         """Called by pytorch lightning in datamodule setup"""
@@ -195,7 +206,7 @@ def val_dataloader(self) -> DataLoader:
         # pylint: disable=C0115,C0116
         return self._create_dataloader(
             self._create_dataset(
-                self.validation_path,
+                self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed,
                 is_test=True,
                 **self.dataset_kwargs,
             ),
@@ -249,15 +260,29 @@ def train_path_packed(self) -> Path:
         """Path to training dataset file for packed sequence. The file path contains a reference to the
         tokenizer/model name since packed sequence dataset consists of tokenized indices."""
         if self.packed_sequence_size > 0:
-            if self.packed_sequence_specs.packed_data_path is not None:
-                return self.packed_sequence_specs.packed_data_path
+            if self.packed_sequence_specs.packed_train_data_path is not None:
+                return self.packed_sequence_specs.packed_train_data_path
             tokenizer_model_name = self._extract_tokenizer_model_name()
             folder_name = self.dataset_root / "packed" / tokenizer_model_name
             folder_name.mkdir(parents=True, exist_ok=True)
             return folder_name / f"training_{self.packed_sequence_size}.npy"
         else:
             raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.")
 
+    @property
+    def validation_path_packed(self) -> Path:
+        """Path to validation dataset file for packed sequence. The file path contains a reference to the
+        tokenizer/model name since packed sequence dataset consists of tokenized indices."""
+        if self.packed_sequence_size > 0:
+            if self.packed_sequence_specs.packed_val_data_path is not None:
+                return self.packed_sequence_specs.packed_val_data_path
+            tokenizer_model_name = self._extract_tokenizer_model_name()
+            folder_name = self.dataset_root / "packed" / tokenizer_model_name
+            folder_name.mkdir(parents=True, exist_ok=True)
+            return folder_name / f"validation_{self.packed_sequence_size}.npy"
+        else:
+            raise ValueError("`validation_path_packed` invalid since packed sequence size is not specified.")
+
     @property
     def validation_path(self) -> Path:
         """Path to validation dataset file"""
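
For illustration, the default on-disk layout implied by the two properties above; the concrete values here are hypothetical, and the tokenizer model name comes from _extract_tokenizer_model_name():

from pathlib import Path

dataset_root = Path("/data/squad")        # hypothetical dataset root
tokenizer_model_name = "Llama-2-7b-hf"    # hypothetical, from _extract_tokenizer_model_name()
packed_sequence_size = 2048

packed_dir = dataset_root / "packed" / tokenizer_model_name
train_path_packed = packed_dir / f"training_{packed_sequence_size}.npy"         # default train_path_packed
validation_path_packed = packed_dir / f"validation_{packed_sequence_size}.npy"  # new default validation_path_packed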
30 changes: 23 additions & 7 deletions nemo/collections/llm/gpt/data/packed_sequence.py
@@ -101,15 +101,31 @@ class PackedSequenceSpecs:
     This field is set by llm.finetune api.
     """
 
-    packed_data_path: str = None
+    packed_train_data_path: str = None
     """
-    If specified, use the packed dataset from this file instead of the default path.
+    If specified, use this file for the packed training dataset instead of the default path.
     """
 
+    packed_val_data_path: str = None
+    """
+    If specified, use this file for the packed validation dataset instead of the default path.
+    """
+
     def __post_init__(self):
-        if self.packed_data_path is not None:
-            self.packed_data_path = Path(self.packed_data_path)
+        if self.packed_train_data_path is not None:
+            self.packed_train_data_path = Path(self.packed_train_data_path)
             assert (
-                self.packed_data_path.suffix == ".npy"
-            ), f"packed data file must be a .npy file: {self.packed_data_path}"
-            assert self.packed_data_path.exists(), f"packed data file does not exist: {self.packed_data_path}"
+                self.packed_train_data_path.suffix == ".npy"
+            ), f"packed training data file must be a .npy file: {self.packed_train_data_path}"
+            assert (
+                self.packed_train_data_path.exists()
+            ), f"packed training data file does not exist: {self.packed_train_data_path}"
+
+        if self.packed_val_data_path is not None:
+            self.packed_val_data_path = Path(self.packed_val_data_path)
+            assert (
+                self.packed_val_data_path.suffix == ".npy"
+            ), f"packed validation data file must be a .npy file: {self.packed_val_data_path}"
+            assert (
+                self.packed_val_data_path.exists()
+            ), f"packed validation data file does not exist: {self.packed_val_data_path}"
3 changes: 3 additions & 0 deletions nemo/collections/llm/peft/lora.py
@@ -124,6 +124,7 @@ class LoRA(PEFT):
         dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0.
         dropout_position (Literal['pre', 'post'], optional): Position for applying dropout.
             Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'post'.
+        a2a_experimental (bool): Enables the experimental All-to-All (A2A) communication strategy. Defaults to False.
 
     Example:
     --------
@@ -151,6 +152,7 @@ class LoRA(PEFT):
     dropout_position: Literal['pre', 'post'] = 'post'
     lora_A_init_method: str = "xavier"
     lora_B_init_method: str = "zero"
+    a2a_experimental: bool = False
 
     def transform(self, m: nn.Module, name=None, prefix=None):
         """
@@ -224,6 +226,7 @@ def wildcard_match(pattern, key):
                 model_parallel_config=getattr(m, "config", None),
                 alpha=self.alpha,
                 is_expert=is_expert_linear(full_name),
+                a2a_experimental=self.a2a_experimental,
             )
             return AdapterParallelAdd(m, adapter)
         return m
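
Finally, a hedged sketch of opting into the new flag when building the PEFT config; dim and alpha are the usual LoRA hyperparameters, with example values, and their defaults are unchanged by this commit:

from nemo.collections.llm.peft.lora import LoRA

peft = LoRA(
    dim=16,                  # low-rank dimension (example value)
    alpha=32,                # scaling factor (example value)
    a2a_experimental=True,   # enable the experimental All-to-All (A2A) communication strategy
)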
