diff --git a/large_language_model/megatron-lm/.gitignore b/large_language_model/megatron-lm/.gitignore
new file mode 100644
index 000000000..3394e3c48
--- /dev/null
+++ b/large_language_model/megatron-lm/.gitignore
@@ -0,0 +1,8 @@
+megatron/__pycache__/
+megatron/data/__pycache__/
+megatron/model/__pycache__/
+megatron/mpu/__pycache__/
+megatron/optimizer/__pycache__/
+megatron/tokenizer/__pycache__/
+megatron/fused_kernels/__pycache__/
+megatron/fused_kernels/build/
diff --git a/large_language_model/megatron-lm/Dockerfile b/large_language_model/megatron-lm/Dockerfile
index 27c528d9f..b19cb1226 100644
--- a/large_language_model/megatron-lm/Dockerfile
+++ b/large_language_model/megatron-lm/Dockerfile
@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.04-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:24.04-py3
 FROM ${FROM_IMAGE_NAME}
 
 # Copy code
diff --git a/large_language_model/megatron-lm/LICENSE b/large_language_model/megatron-lm/LICENSE
index 4a60b9752..21e25d32b 100755
--- a/large_language_model/megatron-lm/LICENSE
+++ b/large_language_model/megatron-lm/LICENSE
@@ -1,6 +1,6 @@
 The following applies to all files unless otherwise noted:
 
-# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
diff --git a/large_language_model/megatron-lm/README.md b/large_language_model/megatron-lm/README.md
index 3fb89ef47..f5c0253f4 100755
--- a/large_language_model/megatron-lm/README.md
+++ b/large_language_model/megatron-lm/README.md
@@ -7,7 +7,7 @@ Our codebase is capable of training large language models with both model and da
 
 ### Steps to configure machine
 
-To use this repository, please install a supported version of PyTorch with GPU support (python 3.8, pytorch 1.12, cuda 11.6.2, and nccl 2.12.10 and above) and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start). We recommend using one of [NGC's PyTorch containers](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch). The latest tested compatible version is `nvcr.io/nvidia/pytorch:22.04-py3`).
+To use this repository, please install a supported version of PyTorch with GPU support (python 3.8, pytorch 1.12, cuda 11.6.2, and nccl 2.12.10 and above) and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start). We recommend using one of [NGC's PyTorch containers](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch). The latest tested compatible version is `nvcr.io/nvidia/pytorch:24.04-py3`.
 
 ### Steps to run and time
 
@@ -256,3 +256,7 @@ cd scripts
 sbatch preprocess.sh
 sbatch preprocess_val.sh
 ```
+
+# 4. Model
+### Publication/Attribution
+Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf) and [2](https://arxiv.org/pdf/2104.04473.pdf)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA.
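For reference, an image built from the updated Dockerfile (24.04-py3 base) can be built and entered roughly as follows. This is a minimal sketch: the image tag and the mount path are illustrative placeholders, not values taken from this patch.

```bash
# Build the benchmark image from the updated Dockerfile; the tag name is illustrative.
docker build -t megatron-lm:mlperf .

# Open an interactive shell with all GPUs visible; /path/to/c4 is a placeholder for the dataset mount.
docker run --gpus all -it --rm \
    -v /path/to/c4:/data \
    megatron-lm:mlperf bash
```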
diff --git a/large_language_model/megatron-lm/gpt3_blend.sh b/large_language_model/megatron-lm/gpt3_blend.sh deleted file mode 100755 index 01312f607..000000000 --- a/large_language_model/megatron-lm/gpt3_blend.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -COM_DIR="/c4/preprocessed_c4_googlespm" -C4_0="${COM_DIR}/c4_en_6_c4_spm_text_document" -C4_1="${COM_DIR}/c4_en_7_c4_spm_text_document" -DATA_BLEND="0.5 ${C4_0} 0.5 ${C4_1}" -VALID_C4="${COM_DIR}/c4_en_validation_subset_c4_spm_text_document" -VALID_DATA_BLEND="1.00 ${VALID_C4}" diff --git a/large_language_model/megatron-lm/megatron/arguments.py b/large_language_model/megatron-lm/megatron/arguments.py index 7d3a042ab..884dfe178 100755 --- a/large_language_model/megatron-lm/megatron/arguments.py +++ b/large_language_model/megatron-lm/megatron/arguments.py @@ -206,6 +206,7 @@ def validate_args(args, defaults={}): # Consumed tokens. args.consumed_train_samples = 0 args.consumed_valid_samples = 0 + args.tokens_per_batch = args.eval_interval * args.global_batch_size * args.seq_length # Iteration-based training. if args.train_iters: diff --git a/large_language_model/megatron-lm/megatron/data/bert_dataset.py b/large_language_model/megatron-lm/megatron/data/bert_dataset.py deleted file mode 100755 index 916a3be06..000000000 --- a/large_language_model/megatron-lm/megatron/data/bert_dataset.py +++ /dev/null @@ -1,195 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""BERT Style dataset.""" - -import numpy as np -import torch - -from megatron import ( - get_args, - get_tokenizer, - mpu, - print_rank_0 -) -from megatron.data.dataset_utils import ( - get_samples_mapping, - get_a_and_b_segments, - truncate_segments, - create_tokens_and_tokentypes, - create_masked_lm_predictions -) - - -class BertDataset(torch.utils.data.Dataset): - - def __init__(self, name, indexed_dataset, data_prefix, - num_epochs, max_num_samples, masked_lm_prob, - max_seq_length, short_seq_prob, seed, binary_head): - - # Params to store. - self.name = name - self.seed = seed - self.masked_lm_prob = masked_lm_prob - self.max_seq_length = max_seq_length - self.binary_head = binary_head - - # Dataset. - self.indexed_dataset = indexed_dataset - - # Build the samples mapping. - self.samples_mapping = get_samples_mapping(self.indexed_dataset, - data_prefix, - num_epochs, - max_num_samples, - self.max_seq_length - 3, # account for added tokens - short_seq_prob, - self.seed, - self.name, - self.binary_head) - - # Vocab stuff. 
- tokenizer = get_tokenizer() - self.vocab_id_list = list(tokenizer.inv_vocab.keys()) - self.vocab_id_to_token_dict = tokenizer.inv_vocab - self.cls_id = tokenizer.cls - self.sep_id = tokenizer.sep - self.mask_id = tokenizer.mask - self.pad_id = tokenizer.pad - - def __len__(self): - return self.samples_mapping.shape[0] - - def __getitem__(self, idx): - start_idx, end_idx, seq_length = self.samples_mapping[idx] - sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] - # Note that this rng state should be numpy and not python since - # python randint is inclusive whereas the numpy one is exclusive. - # We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1 - np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) - return build_training_sample(sample, seq_length, - self.max_seq_length, # needed for padding - self.vocab_id_list, - self.vocab_id_to_token_dict, - self.cls_id, self.sep_id, - self.mask_id, self.pad_id, - self.masked_lm_prob, np_rng, - self.binary_head) - - - - -def build_training_sample(sample, - target_seq_length, max_seq_length, - vocab_id_list, vocab_id_to_token_dict, - cls_id, sep_id, mask_id, pad_id, - masked_lm_prob, np_rng, binary_head): - """Biuld training sample. - - Arguments: - sample: A list of sentences in which each sentence is a list token ids. - target_seq_length: Desired sequence length. - max_seq_length: Maximum length of the sequence. All values are padded to - this length. - vocab_id_list: List of vocabulary ids. Used to pick a random id. - vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. - cls_id: Start of example id. - sep_id: Separator id. - mask_id: Mask token id. - pad_id: Padding token id. - masked_lm_prob: Probability to mask tokens. - np_rng: Random number genenrator. Note that this rng state should be - numpy and not python since python randint is inclusive for - the opper bound whereas the numpy one is exclusive. - """ - - if binary_head: - # We assume that we have at least two sentences in the sample - assert len(sample) > 1 - assert target_seq_length <= max_seq_length - - # Divide sample into two segments (A and B). - if binary_head: - tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, - np_rng) - else: - tokens_a = [] - for j in range(len(sample)): - tokens_a.extend(sample[j]) - tokens_b = [] - is_next_random = False - - # Truncate to `target_sequence_length`. - max_num_tokens = target_seq_length - truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), - len(tokens_b), max_num_tokens, np_rng) - - # Build tokens and toketypes. - tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, - cls_id, sep_id) - - # Masking. - max_predictions_per_seq = masked_lm_prob * max_num_tokens - (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions( - tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, - cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng) - - # Padding. 
- tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ - = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length) - - train_sample = { - 'text': tokens_np, - 'types': tokentypes_np, - 'labels': labels_np, - 'is_random': int(is_next_random), - 'loss_mask': loss_mask_np, - 'padding_mask': padding_mask_np, - 'truncated': int(truncated)} - return train_sample - - -def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length): - """Pad sequences and convert them to numpy.""" - - # Some checks. - num_tokens = len(tokens) - padding_length = max_seq_length - num_tokens - assert padding_length >= 0 - assert len(tokentypes) == num_tokens - assert len(masked_positions) == len(masked_labels) - - # Tokens and token types. - filler = [pad_id] * padding_length - tokens_np = np.array(tokens + filler, dtype=np.int64) - tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) - - # Padding mask. - padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, - dtype=np.int64) - - # Lables and loss mask. - labels = [-1] * max_seq_length - loss_mask = [0] * max_seq_length - for i in range(len(masked_positions)): - assert masked_positions[i] < num_tokens - labels[masked_positions[i]] = masked_labels[i] - loss_mask[masked_positions[i]] = 1 - labels_np = np.array(labels, dtype=np.int64) - loss_mask_np = np.array(loss_mask, dtype=np.int64) - - return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np diff --git a/large_language_model/megatron-lm/megatron/data/indexed_dataset.py b/large_language_model/megatron-lm/megatron/data/indexed_dataset.py index 2f6e1b845..203ca9c3d 100755 --- a/large_language_model/megatron-lm/megatron/data/indexed_dataset.py +++ b/large_language_model/megatron-lm/megatron/data/indexed_dataset.py @@ -95,7 +95,7 @@ def write_longs(f, a): 3: np.int16, 4: np.int32, 5: np.int64, - 6: np.float, + 6: float, 7: np.double, 8: np.uint16 } @@ -268,7 +268,7 @@ class IndexedDatasetBuilder(object): np.int16: 2, np.int32: 4, np.int64: 8, - np.float: 4, + float: 4, np.double: 8 } diff --git a/large_language_model/megatron-lm/megatron/model/vit_model.py b/large_language_model/megatron-lm/megatron/model/vit_model.py deleted file mode 100755 index 80f351175..000000000 --- a/large_language_model/megatron-lm/megatron/model/vit_model.py +++ /dev/null @@ -1,234 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Vision Transformer(VIT) model.""" - -import math -import einops -import torch -import torch.nn.functional as F -from megatron import get_args -from megatron.model.transformer import ParallelTransformer -from megatron.model.utils import ( - get_linear_layer, - init_method_normal, - scaled_init_method_normal, -) -from .module import MegatronModule - - -class VitMlpHead(MegatronModule): - """Pooler layer. 
- - Pool hidden states of a specific token (for example start of the - sequence) and add a linear transformation followed by a tanh. - - Arguments: - hidden_size: hidden size - init_method: weight initialization method for the linear layer. - bias is set to zero. - """ - - def __init__(self, hidden_size, num_classes): - super(VitMlpHead, self).__init__() - self.dense_in = torch.nn.Linear(hidden_size, hidden_size) - self.dense_out = torch.nn.Linear(hidden_size, num_classes) - torch.nn.init.constant_(self.dense_out.bias, -10) - - def forward(self, hidden_states, sequence_index=0): - # hidden_states: [b, s, h] - # sequence_index: index of the token to pool. - hidden_state = hidden_states[:, sequence_index, :] - dense_in_result = self.dense_in(hidden_state) - tanh_result = torch.tanh(dense_in_result) - dense_out_result = self.dense_out(tanh_result) - return dense_out_result - - -def twod_interpolate_position_embeddings_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, -): - - args = get_args() - num_patches_per_dim = args.img_dim // args.patch_dim - num_patches = num_patches_per_dim ** 2 - seq_length = num_patches + 1 - hidden_size = args.hidden_size - - key = prefix + "weight" - # import pdb - # pdb.set_trace() - assert key in state_dict - if key in state_dict: - input_param = state_dict[key] - - assert input_param.shape[1] == hidden_size - if input_param.shape[0] != seq_length: - # update input_param and load it to state_dict[key] - - num_tok_input = input_param.shape[0] - 1 - num_tok_new = seq_length - 1 - input_param_tok, input_param_grid = ( - input_param[:1, :], - input_param[1:, :], - ) - - gs_input = int(math.sqrt(num_tok_input)) - gs_new = int(math.sqrt(num_tok_new)) - - input_param_grid = input_param_grid.transpose(0, 1).contiguous() - input_param_grid = input_param_grid.reshape( - (1, -1, gs_input, gs_input) - ) - input_param_grid = input_param_grid.float() - scale_factor = gs_new / gs_input - - input_param_grid = F.interpolate( - input_param_grid, scale_factor=scale_factor, mode="bilinear" - ) - - input_param_grid = input_param_grid.half() - input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new)) - input_param_grid = input_param_grid.transpose(0, 1).contiguous() - - assert input_param_grid.shape[1] == hidden_size - input_param = torch.cat((input_param_tok, input_param_grid), dim=0) - assert ( - input_param.shape[0] == seq_length - and input_param.shape[1] == hidden_size - ) - - state_dict[key] = input_param - - -class VitModel(MegatronModule): - """Vision Transformer Model.""" - - def __init__(self, - num_classes, - finetune=False, - pre_process=True, - post_process=True): - super(VitModel, self).__init__(share_word_embeddings=False) - args = get_args() - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - if args.init_method_xavier_uniform: - self.init_method = torch.nn.init.xavier_uniform_ - self.scaled_init_method = torch.nn.init.xavier_uniform_ - else: - self.init_method = init_method_normal(args.init_method_std) - self.scaled_init_method = scaled_init_method_normal( - args.init_method_std, args.num_layers - ) - - self.pre_process = pre_process - self.post_process = post_process - self.hidden_size = args.hidden_size - self.num_classes = num_classes - self.patch_dim = args.patch_dim - self.img_dim = args.img_dim - self.finetune = finetune - - assert self.img_dim % self.patch_dim == 0 - self.num_patches_per_dim = self.img_dim // self.patch_dim - self.num_patches = self.num_patches_per_dim ** 2 - 
self.seq_length = self.num_patches + 1 - self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels - - if self.pre_process: - # cls_token - self.cls_token = torch.nn.Parameter( - torch.randn(1, 1, self.hidden_size) - ) - torch.nn.init.zeros_(self.cls_token) - - # Linear encoder - self.linear_encoder = torch.nn.Linear( - self.flatten_dim, self.hidden_size - ) - - # embedding - self.position_embeddings = torch.nn.Embedding( - self.seq_length, self.hidden_size - ) - init_method_normal(args.init_method_std)( - self.position_embeddings.weight - ) - self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda() - - self.position_embeddings._register_load_state_dict_pre_hook( - twod_interpolate_position_embeddings_hook - ) - - self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout) - - # Transformer - self.transformer = ParallelTransformer( - self.init_method, - self.scaled_init_method, - pre_process=self.pre_process, - post_process=self.post_process - ) - - if self.post_process: - # MLP head - if not self.finetune: - self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes) - else: - self.class_head = get_linear_layer( - self.hidden_size, num_classes, torch.nn.init.zeros_ - ) - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.transformer.set_input_tensor(input_tensor) - - def forward(self, input): - - if self.pre_process: - rearranged_input = einops.rearrange( - input, - "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", - p1=self.patch_dim, - p2=self.patch_dim, - ) - - assert rearranged_input.dtype == torch.half - encoder_output = self.linear_encoder(rearranged_input) - cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1) - concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1) - - token_embeddings = concatenated_tokens + \ - self.position_embeddings(self.position_ids) - hidden_states = self.embedding_dropout(token_embeddings) - else: - hidden_states = input - - hidden_states = self.transformer(hidden_states, None) - - if self.post_process: - if not self.finetune: - hidden_states = self.mlp_head(hidden_states) - else: - hidden_states = self.class_head(hidden_states[:, 0, :]) - - return hidden_states diff --git a/large_language_model/megatron-lm/megatron/optimizer/clip_grads.py b/large_language_model/megatron-lm/megatron/optimizer/clip_grads.py index ad249bd5d..22fe8275d 100755 --- a/large_language_model/megatron-lm/megatron/optimizer/clip_grads.py +++ b/large_language_model/megatron-lm/megatron/optimizer/clip_grads.py @@ -16,7 +16,7 @@ """Gradient clipping.""" import torch -from torch._six import inf +from torch import inf from apex.multi_tensor_apply import multi_tensor_applier import amp_C diff --git a/large_language_model/megatron-lm/megatron/training.py b/large_language_model/megatron-lm/megatron/training.py index b4f5e1956..9145b1dc7 100755 --- a/large_language_model/megatron-lm/megatron/training.py +++ b/large_language_model/megatron-lm/megatron/training.py @@ -73,7 +73,7 @@ def pretrain(train_valid_test_dataset_provider, 1) initialize Megatron. 2) setup model, optimizer and lr schedule using the model_provider. 3) call train_val_test_data_provider to get train/val/test datasets. - 4) train the modle using the forward_step_func. + 4) train the model using the forward_step_func. 
Arguments: train_valid_test_dataset_provider: a function that takes the size of @@ -95,9 +95,6 @@ def pretrain(train_valid_test_dataset_provider, args_defaults: a dictionary from argument-name to argument-value. It to set already parse arguments. """ - # The reference implementation does not clear the cache currently - # but the submissions are required to do so - mllogger.event(key=mllogger.constants.CACHE_CLEAR, value=True) mllogger.start(key=mllogger.constants.INIT_START, sync=False) # Initalize and get arguments, timers, and Tensorboard writer. @@ -200,10 +197,6 @@ def pretrain(train_valid_test_dataset_provider, iteration = 0 mllogger.start(key=mllogger.constants.EPOCH_START, metadata={'epoch_num': 0}, sync=False) - mllogger.start(key=mllogger.constants.BLOCK_START, - metadata={'first_epoch_num': 0, - 'epoch_count': args.eval_interval * args.global_batch_size * args.seq_length}, - sync=False) if args.do_train and args.train_iters > 0: iteration = train(forward_step_func, @@ -230,19 +223,13 @@ def pretrain(train_valid_test_dataset_provider, 0, process_non_loss_data_func, True) - status = 'aborted' - mllogger.log_run_stop(status) mllogger.event(key="trained_samples", value=(args.consumed_train_samples - args.ext_lr_steps) * args.seq_length, sync=False) - mllogger.event(key="train_samples", - value=(args.consumed_train_samples - args.ext_lr_steps) * args.seq_length, - sync=False) - mllogger.end(key=mllogger.constants.BLOCK_STOP, - metadata={'first_epoch_num': 0}, - sync=False) mllogger.end(key=mllogger.constants.EPOCH_STOP, metadata={'epoch_num': (args.consumed_train_samples - args.ext_lr_steps) * args.seq_length}, sync=False) + status = 'aborted' + mllogger.log_run_stop(status) def update_train_iters(args): @@ -732,6 +719,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, args = get_args() timers = get_timers() + # log block start + mllogger.start(key=mllogger.constants.BLOCK_START, + metadata={'first_epoch_num': (args.consumed_train_samples - args.ext_lr_steps) * args.seq_length, 'epoch_count': args.tokens_per_batch}, + sync=False) + # Write args to tensorboard write_args_to_tensorboard() @@ -788,6 +780,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, valid_data_iterator, model, iteration, process_non_loss_data_func, False) + # log block start + mllogger.start(key=mllogger.constants.BLOCK_START, + metadata={'first_epoch_num': (args.consumed_train_samples - args.ext_lr_steps) * args.seq_length, 'epoch_count': args.tokens_per_batch}, + sync=False) # Checkpointing saved_checkpoint = False @@ -819,16 +815,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, mllogger.event(key="trained_samples", value=(args.consumed_train_samples - args.ext_lr_steps) * args.seq_length, sync=False) - mllogger.event(key="train_samples", - value=(args.consumed_train_samples - args.ext_lr_steps) * args.seq_length, - sync=False) if not saved_checkpoint: save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler) print_datetime('exiting program after {} minutes'.format(train_time)) - mllogger.end(key=mllogger.constants.BLOCK_STOP, - metadata={'first_epoch_num': 0}, - sync=False) mllogger.end(key=mllogger.constants.EPOCH_STOP, metadata={'epoch_num': (args.consumed_train_samples - args.ext_lr_steps) * args.seq_length}, sync=False) sys.exit() @@ -925,8 +915,12 @@ def evaluate_and_print_results(prefix, forward_step_func, verbose=False): """Helper function to evaluate and dump results on screen.""" args = get_args() - 
mllogger.start(key=mllogger.constants.EVAL_START, - metadata={'epoch_num': (args.consumed_train_samples - args.ext_lr_steps) * args.seq_length}, sync=False) + if iteration > 0: + mllogger.end(key=mllogger.constants.BLOCK_STOP, + metadata={'first_epoch_num': ((args.consumed_train_samples - args.ext_lr_steps) * args.seq_length) - args.tokens_per_batch, 'epoch_count': args.tokens_per_batch}, + sync=False) + mllogger.start(key=mllogger.constants.EVAL_START, + metadata={'epoch_num': (args.consumed_train_samples - args.ext_lr_steps) * args.seq_length}, sync=False) writer = get_tensorboard_writer() @@ -962,8 +956,14 @@ def evaluate_and_print_results(prefix, forward_step_func, print_rank_last('-' * length) print_rank_last(string) print_rank_last('-' * length) - mllogger.end(key=mllogger.constants.EVAL_STOP, - metadata=dict(epoch_num=(args.consumed_train_samples - args.ext_lr_steps) * args.seq_length), sync=False) + if iteration > 0: + mllogger.end(key=mllogger.constants.EVAL_STOP, + metadata=dict(epoch_num=(args.consumed_train_samples - args.ext_lr_steps) * args.seq_length), sync=False) + + if total_loss_dict and 'lm loss' in total_loss_dict.keys() and total_loss_dict['lm loss'].item() < 2.69: + print_rank_0('Target accuracy reached') + status = 'success' + mllogger.log_run_stop(status) def cyclic_iter(iter): diff --git a/large_language_model/megatron-lm/run_gpt3.sh b/large_language_model/megatron-lm/run_gpt3.sh index ce35f2b26..afef70199 100755 --- a/large_language_model/megatron-lm/run_gpt3.sh +++ b/large_language_model/megatron-lm/run_gpt3.sh @@ -1,11 +1,10 @@ #!/bin/bash -#SBATCH -p luna -A mlperf -t 00:20:00 --nodes=8 --exclusive --mem=0 --overcommit --ntasks-per-node=8 --job-name=mlperf-megatron:megatron - # Vars without defaults LOG_DIR=${1:?LOG_DIR not set} BPE_DIR=${2:?BPE_DIR not set} -CONT="${3:?CONT not set}" +COM_DIR=${3:?COM_DIR not set} +CONT="${4:?CONT not set}" # Vars with defaults : "${MEGATRON_DIR:=$PWD}" @@ -27,7 +26,11 @@ mkdir -p ${CHECKPOINT_DIR} mkdir -p ${TENSORBOARD_DIR} # Get the data blend -. $PWD/gpt3_blend.sh +C4_6="${COM_DIR}/c4_en_6_c4_spm_text_document" +C4_7="${COM_DIR}/c4_en_7_c4_spm_text_document" +DATA_BLEND="0.5 ${C4_6} 0.5 ${C4_7}" +VALID_C4="${COM_DIR}/c4_en_validation_subset_c4_spm_text_document" +VALID_DATA_BLEND="1.00 ${VALID_C4}" ################################################################################ ### Set exit duration based on variable time allocated for this specific job ### @@ -91,14 +94,17 @@ options=" \ --no-seq-len-plus-one-tokens \ --seed ${RANDOM} " +EXTERNAL_CHECKPOINT_MOUNT="" [ ${USE_BF16} = true ] && options+=" --bf16" if [ -n "${EXTERNAL_MODEL_CHECKPOINT_DIR}" ]; then options+=" \ --no-load-rng \ --use-ext-ckpt \ + --use-distributed-checkpointing \ --ext-iterations $(( $EXTERNAL_TRAINING_ITERATIONS * $EXTERNAL_GBS / $GBS)) \ --ext-lr-steps $(( $EXTERNAL_TRAINING_ITERATIONS * $EXTERNAL_GBS)) \ --load ${EXTERNAL_MODEL_CHECKPOINT_DIR}" + EXTERNAL_CHECKPOINT_MOUNT=",${EXTERNAL_MODEL_CHECKPOINT_DIR}:${EXTERNAL_MODEL_CHECKPOINT_DIR}" else options+=" --load ${CHECKPOINT_DIR}" fi @@ -109,7 +115,7 @@ DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` srun -l \ --container-image $CONT \ - --container-mounts "$PWD:$PWD,${COM_DIR}:${COM_DIR},${LOG_DIR}:${LOG_DIR},${BPE_DIR}:${BPE_DIR}" \ + --container-mounts "$PWD:$PWD,${COM_DIR}:${COM_DIR},${LOG_DIR}:${LOG_DIR},${BPE_DIR}:${BPE_DIR}${EXTERNAL_CHECKPOINT_MOUNT}" \ --output=$LOG_DIR/GPT3-175B-runlog-$DATETIME.log sh -c "${run_cmd}" set +x
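With `gpt3_blend.sh` removed, the preprocessed C4 paths are now assembled inside `run_gpt3.sh` from a `COM_DIR` argument, which becomes the third positional parameter ahead of the container image. A sketch of the resulting submission, assuming the job is still launched through Slurm; the paths and resource flags below are placeholders, not values from this patch:

```bash
# Positional arguments after this change: LOG_DIR BPE_DIR COM_DIR CONT (all values below are placeholders).
sbatch --nodes=8 --ntasks-per-node=8 --exclusive \
    run_gpt3.sh \
    /path/to/logs \
    /path/to/bpe \
    /path/to/preprocessed_c4_googlespm \
    nvcr.io/nvidia/pytorch:24.04-py3
```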