
inference script #16

Merged

Conversation

transcendingvictor
Collaborator

Two new files in the Scripts folder. A Python module that, for a given model and a given split (usually "validation"), and optionally a given batch size, creates a dataset (.parquet file) with the log probabilities of the correct next token. The bash script includes the names of all the llama models (from 100k to 25.6m) and iterates over them, calling the Python module. The datasets are created in a new folder called "Correct logprobs".
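A minimal sketch of what the described bash loop could look like; the model names below are illustrative stand-ins, not the actual list from the PR, and the `python` invocation is shown only as a comment since the script's real arguments aren't spelled out here:

```shell
#!/usr/bin/env bash
# Hypothetical sketch of the iteration over llama models; names are made up.
MODELS=("llama-100k" "llama-1m" "llama-25.6m")

for model in "${MODELS[@]}"; do
    # The real script would invoke the Python module here, e.g.:
    #   python scripts/inference.py "$model" --batch_size 80
    echo "would run inference for $model"
done
```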

@transcendingvictor transcendingvictor linked an issue Jan 31, 2024 that may be closed by this pull request
@transcendingvictor transcendingvictor changed the title 10 script to run inference on whole validation dataset [DRAFT] 10 script to run inference on whole validation dataset Jan 31, 2024
Collaborator

@joshuawe joshuawe left a comment


Left some comments with either simple questions or mentions of a best practice that could (optionally) be improved.

  • What I find necessary are type hints and docstrings. That is good for you and others to understand your code :)
  • Some of your files seem to be redundant. There are files in `src/delphi/eval/` and in `scripts/`, and I do not think we need both of them. It probably makes sense to have `inference.py` in the `eval/` folder and the bash script in the `scripts/` folder.
  • Is the file `inference_on_validation.py` necessary anymore?

@jettjaniak
Contributor

please see #24, rebase on top of main and use the function introduced there

@jettjaniak
Contributor

same for #25

@transcendingvictor transcendingvictor force-pushed the 10-script-to-run-inference-on-whole-validation-dataset branch from e5a29da to d36a135 on February 13, 2024 21:27
@transcendingvictor transcendingvictor marked this pull request as ready for review February 15, 2024 20:52
Contributor

@jettjaniak jettjaniak left a comment


just a few small comments - please implement them and merge

output_file = os.path.join(output_folder, f'{model_name.replace("/", "-")}.parquet')
accumulated_df.to_parquet(output_file)
_, next_logprobs = get_all_and_next_logprobs(model, val_sequences)
accumulated_logprobs = torch.cat((accumulated_logprobs, next_logprobs), dim=0)
Contributor


This approach is pretty inefficient: you copy all the data accumulated so far in each step. Instead you could append `next_logprobs` to `logprobs_list` and then at the end have `all_logprobs = torch.cat(logprobs_list, dim=0)`.
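The suggested pattern can be sketched like this, with random tensors standing in for the per-batch `next_logprobs` produced inside the real loop:

```python
# Sketch of the reviewer's suggestion: append each batch tensor to a Python
# list and concatenate once at the end, instead of calling torch.cat every
# iteration (which re-copies everything accumulated so far).
import torch

# Stand-in for the next_logprobs tensors produced per batch.
batches = [torch.randn(4, 10) for _ in range(5)]

logprobs_list = []
for next_logprobs in batches:
    logprobs_list.append(next_logprobs)  # O(1); no tensor data copied here
all_logprobs = torch.cat(logprobs_list, dim=0)  # single copy at the very end
```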

Contributor


But actually, you're converting it all to a list anyway to make a DataFrame. But then pandas stores its data as numpy arrays 🙈 It's fine for now, but please remind me in our 1-1 to chat about this.

df_dataset = pd.DataFrame({"logprobs": extended_next_logprobs.tolist()})
hf_dataset = Dataset.from_pandas(df_dataset)

# change the repo_id to your hf username
Contributor


that should be an argument
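One possible shape for that, exposing the Hugging Face username as a CLI argument instead of a hard-coded comment; the argument name and repo naming scheme here are assumptions, not the PR's final API:

```python
# Hypothetical sketch: take the HF username from the command line rather than
# asking users to edit repo_id in the source.
import argparse

parser = argparse.ArgumentParser(description="Upload correct-token logprobs")
parser.add_argument(
    "--username",
    type=str,
    required=True,
    help="Hugging Face username that owns the destination repo",
)
args = parser.parse_args(["--username", "some-user"])  # example invocation
repo_id = f"{args.username}/v0-validation-logprobs"  # hypothetical repo name
```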

parser.add_argument(
"--batch_size",
type=int,
default=80,
help="Batch size for processing (default: 80)",
)

parser.add_argument(
"--dataset_name",
Contributor


Suggested change
"--dataset_name",
"--dataset-name",

@jettjaniak
Contributor

also rebase required

@jettjaniak jettjaniak changed the title [DRAFT] 10 script to run inference on whole validation dataset inference script Feb 17, 2024
@transcendingvictor transcendingvictor force-pushed the 10-script-to-run-inference-on-whole-validation-dataset branch from e4bd41c to 67bf0f2 on February 17, 2024 12:42
Contributor

@jettjaniak jettjaniak left a comment


A few more comments

val_ds = load_validation_dataset(dataset_name)

# model accepts 2D tensors (batch_size, seq_len)
val_sequences = torch.tensor([s["tokens"] for s in val_ds])
Contributor


This makes a copy of the whole validation dataset, which we don't need.


logprobs_list = []
for i in tqdm(range(0, len(val_sequences), batch_size)):
batch_sequences = val_sequences[i : i + batch_size]
Contributor


Use `val_ds` here to make a batch tensor. `val_ds["tokens"][i : i + batch_size]` should work; otherwise use a for loop.
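The suggestion can be sketched as follows; `val_ds` below is a plain dict standing in for the HF Dataset (assumed to expose a `"tokens"` column of equal-length token lists), and the model call is left as a comment:

```python
# Sketch of the reviewer's suggestion: slice the dataset's "tokens" column per
# batch instead of materializing the whole validation set as one tensor.
import torch

val_ds = {"tokens": [[i, i + 1, i + 2] for i in range(10)]}  # stand-in dataset
batch_size = 4

batch_shapes = []
for i in range(0, len(val_ds["tokens"]), batch_size):
    # Only this batch is converted to a tensor; no full-dataset copy.
    batch_sequences = torch.tensor(val_ds["tokens"][i : i + batch_size])
    batch_shapes.append(tuple(batch_sequences.shape))
    # ... run the model on batch_sequences and collect next_logprobs ...
```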

@transcendingvictor transcendingvictor force-pushed the 10-script-to-run-inference-on-whole-validation-dataset branch from f556811 to 95f4272 on February 21, 2024 16:56
@transcendingvictor transcendingvictor merged commit 75e68aa into main Feb 21, 2024
1 check passed
@transcendingvictor transcendingvictor deleted the 10-script-to-run-inference-on-whole-validation-dataset branch February 21, 2024 17:02