Merge pull request #18 from swiss-ai/update_upstream
Update upstream
ischlag authored Sep 2, 2024
2 parents ddc8fa1 + ca9b096 commit 67332b4
Showing 12 changed files with 237 additions and 67 deletions.
2 changes: 1 addition & 1 deletion docs/nanoset.md
@@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument:

Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
```shell
-torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
+torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
```

## Under the hood
7 changes: 4 additions & 3 deletions src/nanotron/config/config.py
@@ -96,10 +96,10 @@ class NanosetDatasetsArgs:
dataset_folder: Union[str, dict, List[str]]

def __post_init__(self):
-if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file
+if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder
self.dataset_folder = [self.dataset_folder]
self.dataset_weights = [1]
-elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file
+elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights
tmp_dataset_folder = self.dataset_folder.copy()
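The `__post_init__` above accepts three shapes for `dataset_folder`. As a minimal, standalone sketch (a hypothetical helper, not nanotron code), the three forms could normalize to a folder list plus sampling weights like this; the dict branch is cut off in the hunk, so its weight handling here is an assumption:

```python
from typing import Dict, List, Optional, Tuple, Union


def normalize_dataset_folder(
    dataset_folder: Union[str, List[str], Dict[str, float]]
) -> Tuple[List[str], Optional[List[float]]]:
    """Mirror of the three cases handled in NanosetDatasetsArgs.__post_init__ (sketch)."""
    if isinstance(dataset_folder, str):  # Case 1: a single dataset folder
        return [dataset_folder], [1.0]
    if isinstance(dataset_folder, list):  # Case 2: several folders; weights stay None, so samples are drawn across all of them
        return dataset_folder, None
    if isinstance(dataset_folder, dict):  # Case 3: folder -> sampling weight (assumed layout)
        return list(dataset_folder.keys()), list(dataset_folder.values())
    raise TypeError(f"Unsupported dataset_folder type: {type(dataset_folder)}")


# Example: normalize_dataset_folder({"data/web": 0.8, "data/code": 0.2})
# -> (["data/web", "data/code"], [0.8, 0.2])
```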
@@ -111,7 +111,7 @@ def __post_init__(self):
class DataArgs:
"""Arguments related to the data and data files processing"""

-dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
+dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
seed: Optional[int]
num_loading_workers: Optional[int] = 1

@@ -145,6 +145,7 @@ class CheckpointsArgs:
checkpoints_path: Path
checkpoint_interval: int
save_initial_state: Optional[bool] = False
+save_final_state: Optional[bool] = False
resume_checkpoint_path: Optional[Path] = None
checkpoints_path_is_shared_file_system: Optional[bool] = False
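For illustration, a hedged example of enabling the new `save_final_state` flag when building the config programmatically; the import path, the set of required fields, and the "checkpoint at end of training" reading of the flag are assumptions based on the dataclass shown above:

```python
from pathlib import Path

from nanotron.config.config import CheckpointsArgs  # import path assumed from this file's location

ckpt_args = CheckpointsArgs(
    checkpoints_path=Path("checkpoints/my-run"),  # placeholder path
    checkpoint_interval=1000,
    save_initial_state=False,
    save_final_state=True,  # new in this commit: presumably also writes a checkpoint when training ends
)
```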

2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -34,6 +34,8 @@ class ParallelismArgs:
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

+tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

def __post_init__(self):
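As a rough sketch (not taken from the repo), the new flag would be toggled alongside the other tensor-parallel options. The `dp`/`pp`/`tp` fields and the import path are assumptions, and the memory-versus-communication reading of the flag is inferred from its name and the llama.py changes below:

```python
from nanotron.config.parallelism_config import ParallelismArgs  # import path assumed

parallel_config = ParallelismArgs(
    dp=1,  # data-parallel degree (field names assumed)
    pp=1,  # pipeline-parallel degree
    tp=2,  # tensor-parallel degree
    tp_recompute_allgather=True,  # presumably: re-run the all-gather in backward instead of caching it
)
```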
6 changes: 4 additions & 2 deletions src/nanotron/models/llama.py
@@ -155,6 +155,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
+tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
@@ -164,8 +165,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
-# TODO @nouamane: why can't we torch.jit.script GLUActivation?
-self.split_silu_mul = GLUActivation(config.hidden_act)
+self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))

def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
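For context, a self-contained sketch of the split-SiLU-multiply pattern that `torch.compile` now wraps above; this is illustrative only and may differ from nanotron's `GLUActivation` in details such as the activation lookup:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitSiluMul(nn.Module):
    """Split the fused gate_up projection in half and apply silu(gate) * up."""

    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        gate, up = torch.chunk(merged_states, 2, dim=-1)
        return F.silu(gate) * up


# torch.compile (PyTorch 2.x) can fuse the chunk/silu/mul into fewer kernels where supported.
split_silu_mul = torch.compile(SplitSiluMul())

x = torch.randn(16, 2 * 128)  # merged gate/up states
y = split_silu_mul(x)         # -> shape (16, 128)
```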
@@ -302,6 +302,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
+tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
@@ -738,6 +739,7 @@ def __init__(
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
+"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
@@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
-return DifferentiableReduceScatterSum.apply(grad_output, group), None
+out = DifferentiableReduceScatterSum.apply(grad_output, group)
+return out, None


class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
-requires_grad=tensor.requires_grad,
+requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
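To show the pattern these two hunks touch, here is a simplified, hedged version of a reduce-scatter `autograd.Function` whose backward is an all-gather: the output buffer is plain storage (`requires_grad=False`) because gradient tracking is handled by the Function itself. This sketch assumes an initialized process group, an even split on dim 0, and SUM reduction only; it is not the repository's implementation.

```python
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class ReduceScatterSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: Optional[ProcessGroup]):
        ctx.group = group
        world_size = dist.get_world_size(group)
        assert tensor.shape[0] % world_size == 0, "dim 0 must be divisible by world size"
        # Pre-allocated output is just a buffer; autograd wires gradients through this Function.
        out = torch.empty(
            tensor.shape[0] // world_size,
            *tensor.shape[1:],
            device=tensor.device,
            dtype=tensor.dtype,
            requires_grad=False,
        )
        dist.reduce_scatter_tensor(out, tensor, group=group, op=dist.ReduceOp.SUM)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        group = ctx.group
        world_size = dist.get_world_size(group)
        # Gradient of reduce-scatter(SUM) is an all-gather of the incoming gradient.
        grad_input = torch.empty(
            grad_output.shape[0] * world_size,
            *grad_output.shape[1:],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        dist.all_gather_into_tensor(grad_input, grad_output.contiguous(), group=group)
        return grad_input, None
```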