Skip to content

Commit

Permalink
correct total byte calculation for
Browse files Browse the repository at this point in the history
bpb when there are no tags
  • Loading branch information
dlwh committed Nov 6, 2024
1 parent f53c991 commit 83b6471
Showing 1 changed file with 30 additions and 22 deletions.
52 changes: 30 additions & 22 deletions src/levanter/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self.datasets = []
tag_index: dict[str, int] = {}
for i, (dataset, tags) in enumerate(datasets):
if tags is None:
if not tags and len(datasets) > 1:
warnings.warn("Dataset has no tags. Giving it an index")
tags = [f"domain_{i}"]
for tag in tags:
Expand Down Expand Up @@ -204,14 +204,16 @@ def eval_callback(step: StepInfo):
}

logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}")
for tag, loss in result.tag_macro_losses.items():
# don't log leaf tag macro losses because it doesn't mean anything different than micro loss
if tag in evaluator.dataset.tag_to_index:
continue
if not tag:
continue
log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss
logger.info(f"{tag} macro loss: {loss:.3f}")
has_tags = len(evaluator.dataset.tag_to_index) > 1 # 1 tag means there's no difference between micro and macro
if has_tags:
for tag, loss in result.tag_macro_losses.items():
# don't log leaf tag macro losses because it doesn't mean anything different than micro loss
if tag in evaluator.dataset.tag_to_index:
continue
if not tag:
continue
log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss
logger.info(f"{tag} macro loss: {loss:.3f}")

for tag, loss in result.tag_micro_losses.items():
if not tag:
Expand All @@ -225,11 +227,14 @@ def eval_callback(step: StepInfo):

if tokenizer is not None:
log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb
log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
if has_tags:
log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
for tag, bpb in result.tag_micro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb
for tag, bpb in result.tag_macro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb

if has_tags:
for tag, bpb in result.tag_macro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb

levanter.tracker.log_metrics(log_dict, step=step.step)

Expand Down Expand Up @@ -304,26 +309,29 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample,
this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag]

mean = state.token_avg_loss.add(this_loss / this_tokens, this_tokens)
# careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)
state = dataclasses.replace(state, token_avg_loss=mean)

state = dataclasses.replace(state, token_avg_loss=mean, loss_per_tag=mean_per_tag)
if len(self.dataset.tag_to_index) > 0:
# careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)
state = dataclasses.replace(state, loss_per_tag=mean_per_tag)

if self.bytes_per_token is not None:
next_tokens = hax.roll(batch.tokens, -1, m.Pos) # [Batch, Pos], rolled by 1 for next token task
bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos]
bytes_per_pos = bytes_per_pos * mask # [Batch, Pos]
bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag]
total_bytes = hax.sum(bytes_per_tag)
bytes_per_tag = hax.einsum("-> tag", mask, bytes_per_pos, tags) # [Tag]
this_bytes = hax.einsum("->", bytes_per_pos, mask) # Scalar

# log loss -> bits is log2(e) * loss
bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e)
bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e)
bpb = this_loss / hax.maximum(this_bytes, 1) * jnp.log2(jnp.e)

bpb_mean = state.bpb.add(bpb, this_tokens)
bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean)
state = dataclasses.replace(state, bpb=bpb_mean)
if len(self.dataset.tag_to_index) > 0:
bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
state = dataclasses.replace(state, bpb_per_tag=bpb_per_tag_mean)

return state

Expand Down

0 comments on commit 83b6471

Please sign in to comment.