diff --git a/mingpt/model.py b/mingpt/model.py
index 83ee22dc..97470b26 100644
--- a/mingpt/model.py
+++ b/mingpt/model.py
@@ -110,6 +110,8 @@ def get_default_config():
         C.embd_pdrop = 0.1
         C.resid_pdrop = 0.1
         C.attn_pdrop = 0.1
+        # parameter dtype
+        C.dtype = torch.float32
         return C

     def __init__(self, config):
@@ -118,6 +120,15 @@ def __init__(self, config):
         assert config.block_size is not None
         self.block_size = config.block_size

+        if isinstance(config.dtype, str):
+            try:
+                config.dtype = getattr(torch, config.dtype)
+            except AttributeError:
+                raise ValueError(f"Unknown dtype {config.dtype}")
+        # check that the dtype is a floating point type
+        self.dtype = config.dtype
+        assert self.dtype.is_floating_point
+
         type_given = config.model_type is not None
         params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None])
         assert type_given ^ params_given # exactly one of these (XOR)
@@ -170,6 +181,24 @@ def _init_weights(self, module):
         elif isinstance(module, nn.LayerNorm):
             torch.nn.init.zeros_(module.bias)
             torch.nn.init.ones_(module.weight)
+        module.to(self.dtype)  # nn.Module.to converts parameters and buffers in place
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of the model. Returns the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and to design tests. Solution inspired by the
+        PyTorch discussion: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+        Arguments:
+            return_buffers (`bool`, *optional*):
+                Whether to include the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and are not registered as parameters, e.g. the running mean
+                and std in batch norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum(param.nelement() * param.element_size() for param in self.parameters())
+        if return_buffers:
+            mem_bufs = sum(buf.nelement() * buf.element_size() for buf in self.buffers())
+            mem = mem + mem_bufs
+        return mem

     @classmethod
     def from_pretrained(cls, model_type):
diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py
new file mode 100644
index 00000000..b0f60b10
--- /dev/null
+++ b/tests/test_modeling_gpt2.py
@@ -0,0 +1,36 @@
+"""
+Some tests for minGPT
+"""
+
+import unittest
+import torch
+from mingpt.model import GPT
+
+class GPT2Tester(unittest.TestCase):
+
+    def test_dtypes(self):
+        """
+        Dtype tests for the GPT-2 model
+        """
+        config_fp16 = GPT.get_default_config()
+        config_fp16.merge_from_dict({'dtype': 'float16', 'vocab_size': 50257, 'block_size': 1024})
+        config_fp16.model_type = 'gpt2'
+
+        config_fp32 = GPT.get_default_config()
+        config_fp32.merge_from_dict({'vocab_size': 50257, 'block_size': 1024})
+        config_fp32.model_type = 'gpt2'
+
+
+        model_fp16 = GPT(config_fp16)
+        model_fp32 = GPT(config_fp32)
+
+        # Check that the dtype has been set correctly
+        self.assertEqual(model_fp16.dtype, torch.float16)
+        self.assertEqual(model_fp32.dtype, torch.float32)
+
+        # Check that the fp16 memory footprint is half that of the fp32 model
+        self.assertEqual(model_fp16.get_memory_footprint(), model_fp32.get_memory_footprint() // 2)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
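
For reference, a minimal usage sketch of the dtype option and get_memory_footprint() introduced by this patch. It assumes the patched mingpt/model.py is on the import path; the config calls mirror those in tests/test_modeling_gpt2.py, and the printed MiB figures are illustrative, not measured.

import torch
from mingpt.model import GPT

# build a GPT-2 sized model whose parameters are stored in fp16;
# config.dtype accepts either a torch.dtype or its string name ('float16')
config = GPT.get_default_config()
config.merge_from_dict({'dtype': 'float16', 'vocab_size': 50257, 'block_size': 1024})
config.model_type = 'gpt2'
model = GPT(config)

# memory footprint in bytes: parameters, plus buffers unless return_buffers=False
total = model.get_memory_footprint()
params_only = model.get_memory_footprint(return_buffers=False)
print(f"total: {total / 1024**2:.1f} MiB, parameters only: {params_only / 1024**2:.1f} MiB")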