Commit
add bits-per-byte calculation to levanter (#729)
dlwh authored Sep 18, 2024
1 parent 79fa64c commit 07b3f16
Showing 8 changed files with 235 additions and 30 deletions.
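At its core, the change converts the evaluator's summed next-token loss, which is measured in nats, into a tokenizer-independent bits-per-byte number. A minimal sketch of that arithmetic, mirroring the formula used in src/levanter/eval.py below (the function and variable names here are illustrative, not part of the diff):

import math

def bits_per_byte(total_loss_nats: float, total_bytes: int) -> float:
    # nats -> bits is a factor of log2(e); dividing by the UTF-8 byte count
    # of the evaluated text removes the dependence on the tokenizer's vocabulary.
    return total_loss_nats / max(total_bytes, 1) * math.log2(math.e)

# e.g. a summed loss of 2048 nats over 1024 bytes of text gives about 2.885 bits per byte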
136 changes: 115 additions & 21 deletions src/levanter/eval.py
@@ -5,6 +5,7 @@
from collections import defaultdict
from typing import Callable, Mapping, Optional, Sequence, TypeVar

import equinox as eqx
import jax.numpy as jnp
import jmp
import numpy as np
@@ -19,7 +20,8 @@
from levanter.logging import LoadingTimeTrackerIterator
from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss
from levanter.trainer import StepInfo
from levanter.utils.stat_utils import RunningMean
from levanter.utils.hf_utils import HfTokenizer, byte_length_of_token
from levanter.utils.stat_utils import Arrayish, RunningMean
from levanter.utils.tree_utils import inference_mode


@@ -37,6 +39,10 @@ class EvalResult:
tag_macro_losses: dict[str, float] # per tag average-per-token loss
tag_micro_losses: dict[str, float] # per tag total loss, for "parent" tags
total_eval_loading_time: float
micro_bpb: Optional[float] = None
macro_bpb: Optional[float] = None
tag_macro_bpb: Optional[dict[str, float]] = None
tag_micro_bpb: Optional[dict[str, float]] = None


# This class doesn't try to be async or work with incomplete datasets, because it's eval
@@ -152,6 +158,7 @@ def _join_prefix(prefix: str, tag: str) -> str:
def cb_tagged_lm_evaluate(
EvalBatch: hax.Axis,
tagged_eval_sets: Sequence[tuple[AsyncDataset[LmExample], Sequence[str]]],
tokenizer: Optional[HfTokenizer] = None,
device_mesh: Optional[Mesh] = None,
axis_mapping: ResourceMapping = None,
max_examples_per_dataset: Optional[int] = None,
@@ -173,12 +180,15 @@ def cb_tagged_lm_evaluate(
Args:
EvalBatch: The axis for the evaluation batch (mostly for the batch size)
tagged_eval_sets: A list of datasets, each with its own domain tag
tokenizer: The tokenizer to use for bits-per-byte evaluation (optional)
device_mesh: The mesh to use for evaluation
axis_mapping: The axis mapping to use for evaluation
max_examples_per_dataset: The maximum number of examples to use from each dataset
prefix: The prefix to use for logging the losses
"""

evaluator = TaggedEvaluator(
EvalBatch, tagged_eval_sets, device_mesh, axis_mapping, max_examples_per_dataset, mp=mp
EvalBatch, tagged_eval_sets, tokenizer, device_mesh, axis_mapping, max_examples_per_dataset, mp=mp
)

def eval_callback(step: StepInfo):
@@ -213,6 +223,14 @@ def eval_callback(step: StepInfo):
log_dict[_join_prefix(prefix, tag) + "/micro_loss"] = loss
logger.info(f"{tag} micro loss: {loss:.3f}")

if tokenizer is not None:
log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb
log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
for tag, bpb in result.tag_micro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb
for tag, bpb in result.tag_macro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb

levanter.tracker.log_metrics(log_dict, step=step.step)

return result
@@ -225,6 +243,8 @@ class TaggedEvaluator:
Evaluates multiple tagged datasets using a given evaluation function.
Scores for each tag are aggregated and logged separately, as well as getting an overall score.
TaggedEvaluator computes both log-perplexity and bits-per-byte for each tag, if a tokenizer is provided.
Tags are arranged hierarchically with "/" as separator, and we log both a micro and macro average loss
for each tag.
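As a quick illustration of the micro/macro distinction described in the docstring, a numpy sketch with made-up tag names and counts (not part of this diff):

import numpy as np

tag_mean_loss = np.array([2.0, 4.0])      # hypothetical per-tag mean losses, e.g. "web", "code"
tag_token_count = np.array([9000, 1000])  # hypothetical token counts per tag

macro_loss = tag_mean_loss.mean()                                # 3.0 -- every tag weighted equally
micro_loss = np.average(tag_mean_loss, weights=tag_token_count)  # 2.2 -- token-weighted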
@@ -234,6 +254,7 @@ def __init__(
self,
EvalBatch: hax.Axis,
tagged_eval_sets: Sequence[tuple[AsyncDataset, Sequence[str]]],
tokenizer: Optional[HfTokenizer] = None,
device_mesh=None,
axis_mapping=None,
max_examples_per_dataset=None,
@@ -249,6 +270,8 @@ def __init__(
axis_resources=axis_mapping,
)
self.mp = mp
self.tokenizer = tokenizer
self.bytes_per_token = self._calculate_bytes_per_token_type(tokenizer)

# tags are arranged hierarchically with "/" as separator. We want to log the average loss for each tag.
hierarchy: dict[str, list[int]] = {}
@@ -264,37 +287,54 @@ def __init__(
self.hierarchy = hierarchy

@hax.named_jit(out_axis_resources=axis_mapping)
def accum_for_batch(
m: LmHeadModel, state: tuple[RunningMean, RunningMean], batch: LmExample, tags: hax.NamedArray
):
def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, tags: hax.NamedArray):
m = inference_mode(m, True)

if self.mp is not None:
m = self.mp.cast_to_compute(m)

with hax.axis_mapping(axis_mapping):
total_mean, mean_per_tag = state
losses = compute_next_token_loss(m, batch, reduction=None, reduction_axis=())
mask = batch.loss_mask # [Batch, Token]
mask = batch.loss_mask # [Batch, Pos]
this_tokens = hax.sum(mask)
this_loss = hax.einsum("->", losses, mask) # to scalar

# all the *_per_tag variables are [Tag]
this_tokens_per_tag = hax.einsum("-> tag", mask, tags)
this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag]

mean = total_mean.add(this_loss / this_tokens, this_tokens)
mean = state.token_avg_loss.add(this_loss / this_tokens, this_tokens)
# careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
mean_per_tag = mean_per_tag.add(safe_mean, this_tokens_per_tag)
mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)

state = dataclasses.replace(state, token_avg_loss=mean, loss_per_tag=mean_per_tag)

if self.bytes_per_token is not None:
next_tokens = hax.roll(batch.tokens, -1, m.Pos) # [Batch, Pos], rolled by 1 for next token task
bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos]
bytes_per_pos = bytes_per_pos * mask # [Batch, Pos]
bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag]
total_bytes = hax.sum(bytes_per_tag)

return mean, mean_per_tag
# log loss -> bits is log2(e) * loss
bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e)
bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e)

bpb_mean = state.bpb.add(bpb, this_tokens)
bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean)

return state

self.accum_for_batch = accum_for_batch

def evaluate(self, m: LmHeadModel):
total_loss = jnp.zeros(())
mean_losses_per_tag = hax.zeros(self.dataset.Tag, dtype=np.float32)

state = (RunningMean.zeros_like(total_loss), RunningMean.zeros_like(mean_losses_per_tag))
state = _EvalRunningMeans.zeros_like(total_loss, mean_losses_per_tag)
del total_loss, mean_losses_per_tag
state = hax.shard(state)

iterator = LoadingTimeTrackerIterator(self.loader)
@@ -304,19 +344,30 @@ def evaluate(self, m: LmHeadModel):
state = self.accum_for_batch(m, state, batch, tags)
n += 1

total_loss, losses_per_tag = state

micro_avg_loss = total_loss.mean.item()
tag_avg_loss = losses_per_tag.mean
micro_avg_loss = state.token_avg_loss.mean.item()
tag_avg_loss = state.loss_per_tag.mean

# TODO: why do i have to jit this
macro_avg_loss = hax.named_jit(lambda x: hax.mean(x).array)(tag_avg_loss).item()

tag_macro_loss = {}
tag_micro_loss = {}
if self.bytes_per_token is not None:
micro_bpb = state.bpb.mean.item()
tag_avg_bpb = state.bpb_per_tag.mean
macro_avg_bpb = hax.named_jit(lambda x: hax.mean(x).array)(tag_avg_bpb).item()
else:
micro_bpb = None
macro_avg_bpb = None

tag_macro_loss: dict[str, float] = {}
tag_micro_loss: dict[str, float] = {}
tag_macro_bpb: dict[str, float] = {}
tag_micro_bpb: dict[str, float] = {}

mean_loss_per_tag_cpu = np.array(losses_per_tag.mean.array) # type: ignore
total_tokens_per_tag_cpu = np.array(losses_per_tag.total.array) # type: ignore
mean_loss_per_tag_cpu = np.array(state.loss_per_tag.mean.array)
total_tokens_per_tag_cpu = np.array(state.loss_per_tag.total.array)

mean_bits_per_tag_cpu = np.array(state.bpb_per_tag.mean.array)
total_bytes_per_tag_cpu = np.array(state.bpb_per_tag.total.array)

# add in the hierarchy
for parent, children in self.hierarchy.items():
@@ -333,8 +384,51 @@ def evaluate(self, m: LmHeadModel):
# (average doesn't support where directly so we just 0 out the weights)
tag_micro_loss[parent] = np.average(mean_loss_per_tag_cpu, weights=total_tokens_per_tag_cpu * mask)

if self.bytes_per_token is not None:
tag_macro_bpb[parent] = np.mean(mean_bits_per_tag_cpu, where=mask)
tag_micro_bpb[parent] = np.average(mean_bits_per_tag_cpu, weights=total_bytes_per_tag_cpu * mask)

for tag, index in self.dataset.tag_to_index.items():
tag_micro_loss[tag] = mean_loss_per_tag_cpu[index]
tag_micro_loss[tag] = float(mean_loss_per_tag_cpu[index])
# no macro loss for the leaf tags

return EvalResult(micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time)
if self.bytes_per_token is not None:
tag_micro_bpb[tag] = float(mean_bits_per_tag_cpu[index])

return EvalResult(
micro_avg_loss,
macro_avg_loss,
tag_macro_loss,
tag_micro_loss,
iterator.total_time,
micro_bpb,
macro_avg_bpb,
tag_macro_bpb,
tag_micro_bpb,
)

def _calculate_bytes_per_token_type(self, tokenizer: HfTokenizer) -> Optional[hax.NamedArray]:
if tokenizer is None:
return None
else:
# calculate the number of bytes in each token
Vocab = hax.Axis("vocab", len(tokenizer.get_vocab()))
bytes = np.ndarray((Vocab.size,), dtype=np.int32)

for i in range(Vocab.size):
bytes[i] = byte_length_of_token(tokenizer, i)

return hax.named(jnp.array(bytes), Vocab)


class _EvalRunningMeans(eqx.Module):
token_avg_loss: RunningMean # average loss averaged over all tokens
loss_per_tag: RunningMean # average loss per tag
bpb: RunningMean # bits per byte averaged over all tokens
bpb_per_tag: RunningMean # bits per byte per tag

@staticmethod
def zeros_like(total: Arrayish, per_tag: Arrayish) -> "_EvalRunningMeans":
z = RunningMean.zeros_like(total)
per_tag = RunningMean.zeros_like(per_tag)
return _EvalRunningMeans(z, per_tag, z, per_tag)
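RunningMean itself lives in levanter.utils.stat_utils and is not shown in this diff; a simplified sketch of the weighted-mean behaviour the evaluator relies on (assuming add(value, weight) folds a weight-weighted value into the average and total tracks the accumulated weight, which is how the calls above use it):

import dataclasses

@dataclasses.dataclass(frozen=True)
class RunningMeanSketch:
    mean: float = 0.0
    total: float = 0.0  # accumulated weight: tokens or bytes in the evaluator

    def add(self, value: float, weight: float) -> "RunningMeanSketch":
        new_total = self.total + weight
        if new_total == 0.0:
            return self
        new_mean = self.mean + (value - self.mean) * weight / new_total
        return RunningMeanSketch(new_mean, new_total)

# RunningMeanSketch().add(2.0, 10).add(4.0, 30).mean == 3.5, i.e. (2*10 + 4*30) / 40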
1 change: 1 addition & 0 deletions src/levanter/main/train_lm.py
@@ -174,6 +174,7 @@ def main(config: TrainLmConfig):
cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch,
causal_datasets,
tokenizer,
trainer.device_mesh,
compute_axis_mapping,
max_eval_examples_per_ds,
30 changes: 30 additions & 0 deletions src/levanter/utils/hf_utils.py
@@ -1,4 +1,8 @@
import os
import re
from typing import TypeAlias

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from levanter.logging import silence_transformer_nag
from levanter.utils.py_utils import logical_cpu_core_count
@@ -8,6 +12,8 @@

_HF_TOKENIZER_OFF_VALUES = {"off", "false", "f", "no", "n", "0"}

HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer


def num_cpus_used_by_tokenizer(tokenizer) -> int:
if getattr(tokenizer, "is_fast", False):
@@ -20,3 +26,27 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int:
return min(max(1, logical_cpu_core_count() - 2), 12)
else:
return 1


def byte_length_of_token(tokenizer, idx: int) -> int:
# this is a pain because we want the prefix spaces, but we don't want extra noise for bytes
# e.g. in llama
# >>> t.convert_ids_to_tokens(q[2])
# '▁this'
# >>> t.convert_ids_to_tokens(25)
# '<0x16>'
# We want the ▁ counted as a single byte (not the 3 bytes of its UTF-8 encoding), while <0x16> should count as the single byte \x16, not the 6-character token string
# decode strips the prefix spaces, but does correctly handle the <0x16> case
# we can avoid prefix-space issues by prepending another token before decoding, then subtracting that token's byte length
repr = tokenizer.convert_ids_to_tokens(idx)
if idx in tokenizer.all_special_ids:
# NB: special tokens don't have bytes, but they contribute to perplexity/bits
return 0
# handle bytes specially. This is a bit of a hack, but there's no other way
elif m := re.match(r"<0x([0-9A-Fa-f]+)>", repr):
return len(bytes.fromhex(m.group(1)))
else:
extra_token = tokenizer(".", add_special_tokens=False)["input_ids"][0]
excess_bytes = len(".".encode("utf-8"))
decoded = tokenizer.decode([extra_token, idx]).encode("utf-8")
return len(decoded) - excess_bytes
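A usage sketch of the three cases byte_length_of_token distinguishes, assuming a SentencePiece-style tokenizer such as Llama's (the model name and the concrete byte counts in the comments are illustrative assumptions, not taken from this commit):

from transformers import AutoTokenizer

from levanter.utils.hf_utils import byte_length_of_token

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed example tokenizer

byte_length_of_token(tok, tok.bos_token_id)  # special token -> 0 bytes
# a raw-byte token such as '<0x16>'          -> 1 byte
# a word piece such as '▁this'               -> 5 bytes (the leading space plus "this")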
2 changes: 1 addition & 1 deletion src/levanter/utils/stat_utils.py
@@ -7,7 +7,7 @@
import haliax as hax


Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray | float
Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray


class RunningMean(eqx.Module):
23 changes: 20 additions & 3 deletions src/levanter/utils/thread_utils.py
@@ -34,24 +34,41 @@ class AsyncIteratorWrapper(Iterator):
def __init__(self, async_iter):
self.async_iter = async_iter
self.loop = asyncio.new_event_loop()
self.executor = ThreadPoolExecutor(max_workers=1)
self.thread = threading.Thread(target=self._run_loop, daemon=True)
self.thread.start()
self._exhausted = False # Flag to indicate if the iterator is exhausted

def _run_loop(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

def _run_async_task(self, coro):
return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
if not self.loop.is_running() or not self.thread.is_alive():
raise StopIteration
try:
future = asyncio.run_coroutine_threadsafe(coro, self.loop)
return future.result()
except (RuntimeError, asyncio.CancelledError):
raise StopIteration

def __iter__(self):
return self

def __next__(self):
if self._exhausted:
raise StopIteration
try:
return self._run_async_task(self.async_iter.__anext__())
except StopAsyncIteration:
self.loop.call_soon_threadsafe(self.loop.stop)
self._exhausted = True # Mark the iterator as exhausted
if self.loop.is_running():
self.loop.call_soon_threadsafe(self.loop.stop)
self.thread.join()
raise StopIteration

def close(self):
"""Close the event loop and thread gracefully."""
if self.loop.is_running():
self.loop.call_soon_threadsafe(self.loop.stop)
self.thread.join()
self.loop.close()
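A small usage sketch of the wrapper's lifecycle with the stricter shutdown handling added here (the async generator is made up for illustration):

import asyncio

from levanter.utils.thread_utils import AsyncIteratorWrapper

async def numbers():
    for i in range(3):
        await asyncio.sleep(0)  # stand-in for real async work
        yield i

it = AsyncIteratorWrapper(numbers())
print(list(it))  # expected [0, 1, 2]; exhaustion now stops the event loop and joins the worker thread
it.close()       # close() is now safe even after the loop has already been stopped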