From 87a09220c428ce0aa177344f90ed5cc669277533 Mon Sep 17 00:00:00 2001
From: huyiwen <1020030101@qq.com>
Date: Tue, 6 Aug 2024 21:09:12 +0800
Subject: [PATCH] [fix] get vocab_size from logits

---
 utilization/model/huggingface_model.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/utilization/model/huggingface_model.py b/utilization/model/huggingface_model.py
index 6fe748db..b6a97a1c 100644
--- a/utilization/model/huggingface_model.py
+++ b/utilization/model/huggingface_model.py
@@ -439,10 +439,11 @@ def get_ppl_with_cache(
         exact_match: bool = False,
     ) -> List[Tuple[float, int]]:
         logits, labels, input_lengths = self.get_cache(batched_targets, prefix_cache, return_caches=False)
+        vocab_size = logits.shape[-1]
         last_logits = torch.cat(prefix_cache.next_logits, dim=0).to(logits.device)
         shift_logits = torch.cat([last_logits, logits[:, :-1]], dim=-2)
         labels[labels == self.tokenizer.pad_token_id] = -100
-        probs = self.loss_fct(shift_logits.view(-1, self.model.config.vocab_size),
+        probs = self.loss_fct(shift_logits.view(-1, vocab_size),
                               labels.view(-1)).view(labels.size(0), -1)
 
         if exact_match:
@@ -508,10 +509,11 @@ def get_ppl(
         logits = self.model(
             input_ids=batched_encodings["input_ids"], attention_mask=batched_encodings["attention_mask"]
         ).logits
+        vocab_size = logits.shape[-1]
         shift_logits = logits.detach()[:, :-1].contiguous()
         shift_labels = batched_encodings["input_ids"][:, 1:].contiguous()
         shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
-        probs = self.loss_fct(shift_logits.view(-1, self.model.config.vocab_size),
+        probs = self.loss_fct(shift_logits.view(-1, vocab_size),
                               shift_labels.view(-1)).view(shift_labels.size(0), -1).cpu()
 
         tgt_starts = [None] * len(batched_inputs)
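
The rationale behind this patch, illustrated outside the diff: `logits.shape[-1]` always matches the tensor being reshaped for the loss, while `model.config.vocab_size` can be missing or disagree with it (e.g., composite/multimodal configs, or checkpoints whose lm_head was padded). A minimal standalone sketch, assuming torch and transformers are installed; "gpt2" is purely a placeholder checkpoint, and `CrossEntropyLoss(reduction="none")` is an assumption standing in for `self.loss_fct`:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative placeholder, not the repo's default
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

enc = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # shape: (batch, seq_len, vocab)

# Take the vocab size from the logits tensor itself, as the patch does.
# config.vocab_size may differ from (or be absent for) this dimension,
# so the tensor's own last dimension is the safe choice for view()/reshape().
vocab_size = logits.shape[-1]
print(vocab_size, getattr(model.config, "vocab_size", None))

# Mirror the per-token NLL computation from get_ppl (assumed loss_fct):
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
shift_logits = logits[:, :-1].reshape(-1, vocab_size)
shift_labels = enc["input_ids"][:, 1:].reshape(-1)
token_nll = loss_fct(shift_logits, shift_labels).view(enc["input_ids"].size(0), -1)

Because `vocab_size` is derived from the same tensor that is flattened, the `view(-1, vocab_size)` call cannot fail even when the config value is stale or lives on a nested sub-config.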