Skip to content

Commit

Permalink
draft from call
Browse files (browse the repository at this point in the history)
  • Loading branch information
jettjaniak committed Feb 21, 2024
1 parent 75e68aa commit 5067032
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions scripts/inference.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
import argparse
import os

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_dataset
Expand Down / Expand Up
@@ -35,27 +36,31 @@ def main(

model = AutoModelForCausalLM.from_pretrained(model_name)

logprobs_list = []
# logprobs_list = []
total_sequences = (
len(val_ds) if not funct_test else 320
) # Use only 320 sequences if funct_test is True

logprobs = np.empty((total_sequences, 513))
logprobs[:, 0] = float("nan")
for i in tqdm(range(0, total_sequences, batch_size)):
batch_end = min(i + batch_size, total_sequences)
batch_sequences = [val_ds[j]["tokens"] for j in range(i, batch_end)]
batch_sequences_tensor = torch.tensor(batch_sequences)

_, next_logprobs = get_all_and_next_logprobs(model, batch_sequences_tensor)
logprobs_list.append(next_logprobs)
logprobs_tensor = get_all_and_next_logprobs(model, batch_sequences_tensor)[1]
logprobs[i:batch_end, 1:] = logprobs_tensor.cpu().numpy()

accumulated_logprobs = torch.cat(logprobs_list, dim=0)
# logprobs_list.append(next_logprobs)

nan_tensor = torch.full((accumulated_logprobs.size(0), 1), float("nan"))
extended_next_logprobs = torch.cat(
[nan_tensor, accumulated_logprobs], dim=1
) # 513 tokens
# accumulated_logprobs = torch.cat(logprobs_list, dim=0)

df_dataset = pd.DataFrame({"logprobs": extended_next_logprobs.tolist()})
# nan_tensor = torch.full((accumulated_logprobs.size(0), 1), float("nan"))
# extended_next_logprobs = torch.cat(
# [nan_tensor, accumulated_logprobs], dim=1
# ) # 513 tokens

df_dataset = pd.DataFrame({"logprobs": logprobs.tolist()})
hf_dataset = Dataset.from_pandas(df_dataset)

# change the repo_id to your hf username in generate_logprobs.sh
Expand Down

0 comments on commit 5067032

Please sign in to comment.