Implement multi-token prediction option for models #479
Comments
Implementing multi-token prediction in nanoGPT could be amazing: it would increase sample efficiency, so we can get more accurate results, and the models could perform speculative decoding, where they generate multiple candidate token sequences and we then choose the most likely one. |
Was anybody able to do this? I think we can add this functionality. |
I want to start working on this |
I would also like to start contributing to this multi-token prediction work. |
I'd love to start contributing too! I tried to think a bit, but I did not manage to arrive at a finalized solution.
In the forward pass, after lines 180-181, we need to perform the sequential pass described in Figure 2 of the paper. I was thinking something like this:
# Old code, stays the same
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
    x = block(x)
# New code. I am unclear about whether LN should go before or after the multi-token heads,
# and unsure about how we actually implement Figure 2 here
if targets is not None:
    for block in self.transformer.hmt:
        logits = self.lm_head(self.transformer.ln_f(block(x)))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
    logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
    loss = None
I want to understand this (would love to hear people's opinions):
However, most importantly, in Figure 2 they do the backward pass directly. How do we adapt this to our code? I think the key idea is that we want to compute each loss independently (each MT head is independent) and therefore we want to accumulate the gradients and feed them back to the rest of the network. My code as is does not work because I am not accumulating gradients. |
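One way to adapt the sequential forward/backward from Figure 2 would be the following rough sketch (not tested; it assumes the `self.transformer.hmt` heads and shared `lm_head` from the snippet above, and targets of shape (b, t, n_future_tokens)): detach the trunk output, run the forward and backward of each head one at a time, and then push the accumulated gradient back through the trunk.
# Sketch of the Figure 2 trick: per-head forward+backward against a detached trunk output.
# Note: this calls backward() inside forward(), so the training loop must NOT call
# loss.backward() again (and it would need adjusting for AMP / gradient accumulation).
if targets is not None:
    trunk_out = x                      # output of the shared transformer blocks, (b, t, n_embd)
    x_det = trunk_out.detach()
    x_det.requires_grad_(True)
    total_loss = 0.0
    for i, head in enumerate(self.transformer.hmt):
        logits = self.lm_head(self.transformer.ln_f(head(x_det)))
        loss_i = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                 targets[:, :, i].reshape(-1), ignore_index=-1)
        loss_i.backward()              # frees this head's activations, accumulates into x_det.grad
        total_loss += loss_i.detach()
    # one backward through the trunk with the summed gradient from all heads
    trunk_out.backward(x_det.grad)
    loss = total_loss
This keeps only one head's activations alive at a time, which is the point of Figure 2. Summing the per-head losses and calling backward once (as in the next snippet) is mathematically equivalent and simpler, at the cost of holding all heads' activations in memory at once.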
Actually, perhaps a better way to do this is as follows; this should accumulate the gradients:
# New code
logits_list = []
if targets is not None:
    loss = 0
    for i, head in enumerate(self.transformer.hmt):
        logits = self.lm_head(self.transformer.ln_f(head(x)))
        logits_list.append(logits)
        # assuming targets of shape (b, t, F), where slice i along the last dim is the (i+1)-th future token
        loss += F.cross_entropy(logits.view(-1, logits.size(-1)), targets[:, :, i].reshape(-1), ignore_index=-1)
else:
    # inference-time optimization: only forward the lm_head on the very last position
    for head in self.transformer.hmt:
        logits = self.lm_head(self.transformer.ln_f(head(x[:, [-1], :]))) # note: using list [-1] to preserve the time dim
        logits_list.append(logits)
    loss = None
return logits_list, loss |
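For completeness, all of these snippets assume an `hmt` container of extra per-head blocks in the model. A minimal way to add it in `GPT.__init__` could look like the sketch below; `hmt` and the `n_future_tokens` config field are made-up names, not part of nanoGPT today.
# in GPT.__init__, extending the existing ModuleDict; n_future_tokens would be a new
# (hypothetical) GPTConfig field giving the number of prediction heads
self.transformer = nn.ModuleDict(dict(
    wte = nn.Embedding(config.vocab_size, config.n_embd),
    wpe = nn.Embedding(config.block_size, config.n_embd),
    drop = nn.Dropout(config.dropout),
    h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
    hmt = nn.ModuleList([Block(config) for _ in range(config.n_future_tokens)]), # one extra block per future token
    ln_f = LayerNorm(config.n_embd, bias=config.bias),
))
The `lm_head` (and its weight tying with `wte`) stays shared across all heads, matching the shared unembedding matrix in the paper.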
An important change is to change the batch construction for training. The context should have the same shape as before, but now we want targets of shape (batch_size, block_size, F), where F is the number of future tokens predicted at each position:
def get_batch(split, multi_token=False, F=1):
    assert F >= 1, "number of future tokens must be at least 1."
    assert not multi_token or F > 1, "when multi_token is True, F must be larger than 1."
    assert multi_token or F == 1, "when next-token prediction is being used, F must be 1."
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size - F, (batch_size,)) # leave room for the F future targets
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.stack([torch.from_numpy((data[i+j+1:i+j+1+block_size]).astype(np.int64)) for j in range(F)], dim=-1) for i in ix])
    if not multi_token:
        y = y.squeeze(-1) # plain next-token prediction: targets of shape (batch_size, block_size)
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
The get_batch change above handles the data side; in the model's forward pass, we can then feed the output of the transformer blocks through each head separately:
# feed through each head separately
head_outputs = []
for head in self.transformer.hmt:
    # each head sees the same trunk output x, so the heads stay independent of each other
    head_outputs.append(head(x))
# Stack them together and use LayerNorm
x = torch.stack(head_outputs, dim=-2) # (b, t, n_multi_token, n_embd)
x = self.transformer.ln_f(x) # (b, t, n_multi_token, n_embd), works because LayerNorm acts on the final dimension
# Final linear layer mapping (b, t, n_multi_token, n_embd) -> (b, t, n_multi_token, vocab_size)
logits = self.lm_head(x) # (b, t, n_multi_token, vocab_size)
if targets is None:
    loss = None
else:
    n_multi_token = len(self.transformer.hmt)
    # Compute log-probabilities
    log_probs = F.log_softmax(logits, dim=-1).view(b*t*n_multi_token, self.config.vocab_size)
    expanded_targets = targets.reshape(b*t*n_multi_token, 1)
    # Pick out the log-probability of each true token (equivalent to cross-entropy,
    # except that this does not skip positions with target -1 like ignore_index=-1 would)
    log_probs_true_tokens = torch.gather(
        input=log_probs, dim=-1, index=expanded_targets).squeeze(-1) # (b*t*n_multi_token, )
    loss = - log_probs_true_tokens.mean() # scalar
return logits, loss
A few edits might be necessary to make this work smoothly with the rest of the package. First of all, one should be able to choose whether to use only the next-token head or all of the future-token heads, especially at inference time. This also does not implement self-speculative decoding. |
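As a quick sanity check of the shapes end to end (an untested sketch; it assumes the hypothetical multi_token/F arguments to get_batch above, and a model built with the same number of heads whose forward returns the stacked (b, t, F, vocab_size) logits as in the last snippet):
# rough smoke test for the multi-token data/loss path (names as assumed above)
F_future = 4  # number of future tokens per position
X, Y = get_batch('train', multi_token=True, F=F_future)
print(X.shape)  # (batch_size, block_size)
print(Y.shape)  # (batch_size, block_size, F_future)
logits, loss = model(X, Y)
print(logits.shape)  # (batch_size, block_size, F_future, vocab_size)
loss.backward()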
Per the recent paper from Meta, it appears that models that predict multiple future tokens can exhibit significantly greater sample efficiency than models trained only on next-token prediction, plus the extra token heads can be used to implement speculative decoding to speed up inference (up to 3X in their experiments), without the need for a draft model.
It would be amazing to see multi-token prediction implemented in nanoGPT, as it would allow the community to easily experiment with this promising technique.
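For anyone who wants to experiment with the inference side, here is a rough sketch of how the extra heads could drive greedy self-speculative decoding. Nothing like this exists in nanoGPT today: forward_heads is a hypothetical helper that would return the stacked per-head logits at every position, shape (1, t, n_heads, vocab_size), with slot 0 being the ordinary next-token head and slot j predicting the token j+1 positions ahead.
import torch

@torch.no_grad()
def self_speculative_generate(model, idx, max_new_tokens, n_heads):
    # idx: (1, t) prompt tokens; greedy decoding only, batch size 1, and we ignore
    # block_size cropping and KV caching to keep the sketch short
    target_len = idx.size(1) + max_new_tokens
    while idx.size(1) < target_len:
        # 1) draft: one forward pass, take the argmax of every head at the last position
        logits = model.forward_heads(idx)                # (1, t, n_heads, vocab_size), hypothetical API
        draft = logits[:, -1].argmax(dim=-1)             # (1, n_heads): head j drafts the token at position t+j
        # 2) verify: run the model once on the context extended with all drafts and
        #    compare them against the ordinary next-token head
        ext = torch.cat([idx, draft], dim=1)             # (1, t + n_heads)
        next_logits = model.forward_heads(ext)[:, :, 0]  # next-token head at every position
        t = idx.size(1)
        n_accept = 0
        for j in range(n_heads):
            # if all earlier drafts matched, the next-token prediction at position t-1+j
            # is the exact greedy token for position t+j
            if next_logits[0, t - 1 + j].argmax() == draft[0, j]:
                n_accept += 1
            else:
                break
        # 3) keep the verified drafts, plus the model's own token at the first position
        #    that disagreed (or one bonus token if every draft was accepted)
        bonus = next_logits[:, t - 1 + n_accept].argmax(dim=-1, keepdim=True)  # (1, 1)
        idx = torch.cat([idx, draft[:, :n_accept], bonus], dim=1)[:, :target_len]
    return idx
Each iteration emits between one and n_heads + 1 tokens while reproducing the plain greedy output, which is where the reported inference speedups come from.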