diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 555dd1466..99e132dc2 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -63,7 +63,7 @@ def __init__( self.datasets = [] tag_index: dict[str, int] = {} for i, (dataset, tags) in enumerate(datasets): - if tags is None: + if not tags and len(datasets) > 1: warnings.warn("Dataset has no tags. Giving it an index") tags = [f"domain_{i}"] for tag in tags: @@ -204,14 +204,16 @@ def eval_callback(step: StepInfo): } logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}") - for tag, loss in result.tag_macro_losses.items(): - # don't log leaf tag macro losses because it doesn't mean anything different than micro loss - if tag in evaluator.dataset.tag_to_index: - continue - if not tag: - continue - log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss - logger.info(f"{tag} macro loss: {loss:.3f}") + has_tags = len(evaluator.dataset.tag_to_index) > 1 # 1 tag means there's no difference between micro and macro + if has_tags: + for tag, loss in result.tag_macro_losses.items(): + # don't log leaf tag macro losses because it doesn't mean anything different than micro loss + if tag in evaluator.dataset.tag_to_index: + continue + if not tag: + continue + log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss + logger.info(f"{tag} macro loss: {loss:.3f}") for tag, loss in result.tag_micro_losses.items(): if not tag: @@ -225,11 +227,14 @@ def eval_callback(step: StepInfo): if tokenizer is not None: log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb - log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb + if has_tags: + log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb for tag, bpb in result.tag_micro_bpb.items(): log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb - for tag, bpb in result.tag_macro_bpb.items(): - log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb + + if has_tags: + for tag, bpb in result.tag_macro_bpb.items(): + log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb levanter.tracker.log_metrics(log_dict, step=step.step) @@ -304,26 +309,29 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag] mean = state.token_avg_loss.add(this_loss / this_tokens, this_tokens) - # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag - safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) - mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) + state = dataclasses.replace(state, token_avg_loss=mean) - state = dataclasses.replace(state, token_avg_loss=mean, loss_per_tag=mean_per_tag) + if len(self.dataset.tag_to_index) > 0: + # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag + safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) + mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) + state = dataclasses.replace(state, loss_per_tag=mean_per_tag) if self.bytes_per_token is not None: next_tokens = hax.roll(batch.tokens, -1, m.Pos) # [Batch, Pos], rolled by 1 for next token task bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos] - bytes_per_pos = bytes_per_pos * mask # [Batch, Pos] - bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag] - total_bytes = hax.sum(bytes_per_tag) + bytes_per_tag = hax.einsum("-> tag", mask, bytes_per_pos, tags) # [Tag] + this_bytes = hax.einsum("->", bytes_per_pos, mask) # Scalar # log loss -> bits is log2(e) * loss bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e) - bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e) + bpb = this_loss / hax.maximum(this_bytes, 1) * jnp.log2(jnp.e) bpb_mean = state.bpb.add(bpb, this_tokens) - bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag) - state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean) + state = dataclasses.replace(state, bpb=bpb_mean) + if len(self.dataset.tag_to_index) > 0: + bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag) + state = dataclasses.replace(state, bpb_per_tag=bpb_per_tag_mean) return state