Add utility functions for text processing and visualization #17
Conversation
I've added an adapted version of https://github.com/jettjaniak/tinyevals/pull/25 by @siwei-li.
Let's move it to delphi/evals/vis.py
As mentioned in the original PR on the old repo, I don't think it's worth optimizing for performance in the visualization code. I feel it could be much simpler if we dropped the vectorization.
You can add a wrapper with squeeze/unsqueeze to utils
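The suggestion above could be sketched roughly as follows. This is a hypothetical illustration, not code from the PR: `single_sequence` and `double_tokens` are made-up names standing in for a wrapper around a batch-first helper like `get_next_logprobs`.

```python
import torch


def single_sequence(fn):
    """Hypothetical wrapper: lift a batch-first helper to accept a single
    1-d token sequence by adding, then removing, the batch dimension."""
    def wrapped(model, sample_tok, *args, **kwargs):
        out = fn(model, sample_tok.unsqueeze(0), *args, **kwargs)
        return out.squeeze(0)
    return wrapped


# Toy stand-in for a batched helper: doubles every token id.
@single_sequence
def double_tokens(model, batch_tok):
    return batch_tok * 2


out = double_tokens(None, torch.tensor([1, 2, 3]))  # 1-d in, 1-d out
```

This keeps a single batched implementation in utils while the visualization code can still pass plain 1-d sequences.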
On Thu, 8 Feb 2024, 16:34, Siwei Li wrote:

In src/delphi/vis/utils.py:
def get_logits(model, sample_tok):
    sample_tok = sample_tok.unsqueeze(0)
    return model(sample_tok).logits[0]


def get_probs(model, sample_tok):
    logits = get_logits(model, sample_tok)
    # drop the value for the last position, as we don't know
    # what the correct next token there is
    # pos, d_vocab
    return torch.softmax(logits, dim=-1)[:-1]


def get_correct_probs(model, sample_tok):
    probs = get_probs(model, sample_tok)
    # out of d_vocab values, take the one that corresponds to the correct next token
    return probs[range(len(probs)), sample_tok[1:]]


def get_correct_and_all_probs(model, sample_tok):
    """Get probabilities for the actual next token and for all predictions."""
    probs = get_probs(model, sample_tok)
    correct_probs = probs[range(len(probs)), sample_tok[1:]]
    return correct_probs, probs
So I tried to use get_next_logprobs(), but we'd need a lot of squeeze() and unsqueeze() calls that way. Could we keep these functions for 1-d token sequences, just for visualization?
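A minimal sketch of how the 1-d helpers quoted above would be used. `DummyModel` and `DummyOutput` are stand-ins invented here so the example is self-contained; a real `PreTrainedModel` returns an output object with a `.logits` attribute of shape (batch, pos, d_vocab).

```python
import torch


class DummyOutput:
    def __init__(self, logits):
        self.logits = logits


class DummyModel:
    """Stand-in model: uniform (all-zero) logits over a 5-token vocab."""
    d_vocab = 5

    def __call__(self, tok):
        return DummyOutput(torch.zeros(*tok.shape, self.d_vocab))


def get_logits(model, sample_tok):
    return model(sample_tok.unsqueeze(0)).logits[0]


def get_probs(model, sample_tok):
    # drop the last position; we don't know the correct next token there
    return torch.softmax(get_logits(model, sample_tok), dim=-1)[:-1]


def get_correct_probs(model, sample_tok):
    probs = get_probs(model, sample_tok)
    return probs[range(len(probs)), sample_tok[1:]]


sample_tok = torch.tensor([0, 1, 2, 3])
probs = get_correct_probs(DummyModel(), sample_tok)
# uniform logits over 5 tokens -> every correct-token probability is 0.2
```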
I think you can use something like pyproject.toml to configure isort in all 3 instances (pre-commit, CI, VS Code), but I'm fine with either solution.
Thanks for making all of these changes! I suggested a bunch of simplifications.

Everything should be addressed.
Great, thank you! Commented on one bug, please make sure to correct it before merging.
src/delphi/eval/compare_models.py
Outdated
-    next_probs_a, probs_a = get_correct_and_all_probs(model_a, sample_tok)
-    next_probs_b, probs_b = get_correct_and_all_probs(model_b, sample_tok)
+    probs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
+    probs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok)
Is this probs or logprobs?
Good catch!
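The bug flagged above is that the renamed helper returns log-probabilities, so code expecting probabilities must exponentiate first. A minimal illustration with toy logits (not data from the PR):

```python
import torch

# toy logits for one position over a 3-token vocab
logits = torch.tensor([[1.0, 2.0, 3.0]])

logprobs = torch.log_softmax(logits, dim=-1)  # what the helper returns
probs = torch.exp(logprobs)                   # ordinary probabilities

# probs sums to 1 along the vocab dimension; logprobs are all <= 0
```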
def get_next_and_top_k_probs(
    model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk]:
    all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
    all_probs = torch.exp(all_logprobs)
    next_probs = torch.exp(next_logprobs)
    top_k = torch.topk(all_probs, k, dim=-1)
    return next_probs, top_k
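For reference, a sketch of what `torch.topk` returns in the function above: a named tuple with `.values` and `.indices`, sorted by descending probability. The toy tensor below stands in for `all_probs` (shape: shorter_seq, d_vocab); it is not data from the PR.

```python
import torch

# one position over a 4-token vocab
all_probs = torch.tensor([[0.1, 0.4, 0.3, 0.2]])

top_k = torch.topk(all_probs, k=3, dim=-1)
# top_k.values[0]  -> tensor([0.4000, 0.3000, 0.2000])
# top_k.indices[0] -> tensor([1, 2, 3])
```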
Are you using this function?
In the demo notebook
This partially addresses #9 (jettjaniak/tinyevals#3)
Includes a notebook to demonstrate usage.