From 7c22554ef65980147419e8a9393a228cf633db8a Mon Sep 17 00:00:00 2001
From: Equim
Date: Thu, 3 Nov 2022 01:31:47 +0800
Subject: [PATCH] `named_parameters` does not have to be recursive

---
 mingpt/model.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/mingpt/model.py b/mingpt/model.py
index 83ee22dc..e29f9844 100644
--- a/mingpt/model.py
+++ b/mingpt/model.py
@@ -226,20 +226,16 @@ def configure_optimizers(self, train_config):
         whitelist_weight_modules = (torch.nn.Linear, )
         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
-            for pn, p in m.named_parameters():
-                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
-                # random note: because named_modules and named_parameters are recursive
-                # we will see the same tensors p many many times. but doing it this way
-                # allows us to know which parent module any tensor p belongs to...
+            for pn, _ in m.named_parameters(prefix=mn, recurse=False):
                 if pn.endswith('bias'):
                     # all biases will not be decayed
-                    no_decay.add(fpn)
+                    no_decay.add(pn)
                 elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                     # weights of whitelist modules will be weight decayed
-                    decay.add(fpn)
+                    decay.add(pn)
                 elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                     # weights of blacklist modules will NOT be weight decayed
-                    no_decay.add(fpn)
+                    no_decay.add(pn)
 
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters()}
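
Note (illustration, not part of the patch): the change relies on torch.nn.Module.named_parameters(prefix=..., recurse=False) yielding the same fully qualified names that the old code assembled by hand via fpn. A minimal standalone sketch of that equivalence, using a throwaway Sequential model purely as an assumption for demonstration:

import torch

# Throwaway model for demonstration only (not from the patch).
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.LayerNorm(8))

# Old spelling: recursive named_parameters(), full names rebuilt by hand.
old_names = set()
for mn, m in model.named_modules():
    for pn, _ in m.named_parameters():
        old_names.add('%s.%s' % (mn, pn) if mn else pn)

# New spelling: non-recursive, prefix supplied by named_parameters itself.
new_names = set()
for mn, m in model.named_modules():
    for pn, _ in m.named_parameters(prefix=mn, recurse=False):
        new_names.add(pn)

# Both produce e.g. {'0.weight', '0.bias', '1.weight', '1.bias'}.
assert old_names == new_names

Because both spellings yield identical name sets, decay and no_decay end up with the same contents, while the non-recursive loop avoids visiting each tensor once per ancestor module.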