Skip to content

Commit

Permalink
Fix training data shifting bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Apr 27, 2024
1 parent 3281227 commit 161c3f0
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/delphi/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def get_xy_batch(
end = (batch_num + 1) * batch_size
batch_indices = indices[start:end]
data = dataset[batch_indices][feature_name].to(device)
return data[:, :-1], data[:, 1:]
# *ForCausalLM models do shifting internally, so input and labels are the same
return data, data


def gen_minibatches(
Expand Down

0 comments on commit 161c3f0

Please sign in to comment.