fix: remove torch tensor in gemini wrapper
ledong0110 committed Jul 31, 2024
1 parent 897fbac commit 6b52e19
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/vieval/tools/wrapper/GeminiWrapper.py
@@ -1,4 +1,3 @@
-import torch
 import json
 import os
 import openai
@@ -49,7 +48,7 @@ def __init__(self, model_name=None, generation_config=None):

     def __call__(self, prompts, return_probs=False):
         generations = []
-        generations_probs = [torch.tensor([])] * len(prompts)
+        generations_probs = [[]] * len(prompts)
         num_generated_tokens = []
         for prompt in prompts:
             processed_prompt = [list(p.values())[1] for p in prompt]
@@ -74,7 +73,7 @@ def __call__(self, prompts, return_probs=False):

     def compute_logprob_and_length(self, prompts, completions):
         completions_num_tokens = [0] * len(prompts)
-        completions_logprobs = [torch.tensor([])] * len(prompts)
+        completions_logprobs = [[]] * len(prompts)
         # Not Implement
         return completions_logprobs, completions_num_tokens
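
For context, the change swaps the torch-based placeholders for plain Python lists. A minimal standalone sketch follows (illustrative only, not the full GeminiWrapper class; the prompts list is made up for the example). Because the wrapper never fills in per-token log probabilities (note the "Not Implement" comment in compute_logprob_and_length), empty lists serve the same placeholder role and the torch import becomes unnecessary.

# Minimal sketch of the placeholder change; "prompts" is a hypothetical input.
prompts = ["first prompt", "second prompt"]

# Before the commit:
#   generations_probs = [torch.tensor([])] * len(prompts)   # required torch
# After the commit, plain empty lists act as the placeholders:
generations_probs = [[]] * len(prompts)
completions_num_tokens = [0] * len(prompts)

print(generations_probs)       # [[], []]
print(completions_num_tokens)  # [0, 0]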
