[WIP] MLM Training Objective #680

Draft
wants to merge 7 commits into main
39 changes: 39 additions & 0 deletions config/roberta-tiny.yaml
@@ -0,0 +1,39 @@
data:
  id: dlwh/wikitext_103_detokenized
  # train_urls:
  #   - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  # validation_urls:
  #   - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "cache/roberta-tiny"
  tokenizer: "roberta-base"

model:
  type: roberta
  vocab_size: 50265
  hidden_size: 32
  intermediate_size: 64
  num_hidden_layers: 4
  num_attention_heads: 2
  max_position_embeddings: 512
  hidden_act: "gelu"
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  gradient_checkpointing: true

trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: ["openwebtext", "roberta", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 32
  num_train_steps: 20000

optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
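
For orientation, here is a rough back-of-the-envelope parameter count for this tiny config. This is a sketch, not part of the PR: it uses plain pyyaml, ignores biases, layer norms, and the MLM head, and assumes the file lives at config/roberta-tiny.yaml. It mainly shows that at this size the model is dominated by its embedding table (roughly 1.6M of about 1.66M parameters).

```python
# Rough parameter estimate for the tiny RoBERTa config (sketch, not part of the PR).
import yaml

with open("config/roberta-tiny.yaml") as f:
    cfg = yaml.safe_load(f)["model"]

h, i = cfg["hidden_size"], cfg["intermediate_size"]
embed = (cfg["vocab_size"] + cfg["max_position_embeddings"]) * h  # token + position embeddings
per_layer = 4 * h * h + 2 * h * i  # attention projections + feed-forward, ignoring biases and layer norms
total = embed + cfg["num_hidden_layers"] * per_layer

print(f"embeddings ~{embed / 1e6:.2f}M, transformer ~{cfg['num_hidden_layers'] * per_layer / 1e3:.0f}K, total ~{total / 1e6:.2f}M")
```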
38 changes: 38 additions & 0 deletions config/roberta.yaml
@@ -0,0 +1,38 @@
data:
  train_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://levanter-data/tokenized/openwebtext_roberta/"
  tokenizer: "roberta-base"

model:
  type: roberta
  vocab_size: 50265
  hidden_size: 768
  intermediate_size: 3072
  num_hidden_layers: 12
  num_attention_heads: 12
  max_position_embeddings: 512
  hidden_act: "gelu"
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  gradient_checkpointing: true

trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: ["openwebtext", "roberta", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 32
  num_train_steps: 20000

optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
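
The full-size config mirrors the published roberta-base architecture. An optional cross-check against the Hugging Face config could look like the sketch below; it is not part of the PR and needs `transformers` installed plus network access to fetch the config. Note that Hugging Face reports max_position_embeddings as 514 because RoBERTa offsets position ids past the padding index, so that field is expected to differ from the 512 here.

```python
# Optional sanity check of config/roberta.yaml against the Hugging Face roberta-base config (sketch).
import yaml
from transformers import RobertaConfig

hf = RobertaConfig.from_pretrained("roberta-base")
with open("config/roberta.yaml") as f:
    ours = yaml.safe_load(f)["model"]

for field in ["vocab_size", "hidden_size", "intermediate_size", "num_hidden_layers", "num_attention_heads"]:
    print(f"{field}: yaml={ours[field]} hf={getattr(hf, field)}")
```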
91 changes: 81 additions & 10 deletions src/levanter/data/text.py
@@ -14,6 +14,7 @@
import equinox as eqx
import fsspec
import jax
import jax.numpy as jnp
import numpy as np
import pyarrow as pa
import regex
@@ -25,13 +26,11 @@

from levanter.data.mixture import MixtureDataset, StopStrategy

# intercept the logging nonsense here
from levanter.logging import silence_transformer_nag # noqa
from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmExample
from levanter.models.lm_model import MaskedLmExample, LmExample
from levanter.utils.hf_utils import num_cpus_used_by_tokenizer


silence_transformer_nag() # noqa
from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa

@@ -53,7 +52,6 @@
from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa
from levanter.utils.jax_utils import use_cpu_device # noqa


logger = logging.getLogger("levanter.data.text")

# TASKS:
@@ -64,6 +62,83 @@

DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index

class MaskedLmDataset(ShardableDataset[MaskedLmExample]):
    def __init__(
        self,
        dataset: ShardableDataset[np.ndarray],
        QPos: Axis,
        KPos: Axis,
        mask_token_id: int,
        mask_prob: float = 0.15,
        noise_prob: float = 0.1,
        key: Optional[PRNGKeyArray] = None,
        # ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX,
    ):
        self.dataset = dataset
        self.QPos = QPos
        self.KPos = KPos
        self.mask_prob = mask_prob
        self.noise_prob = noise_prob
        self.key = key
        self.mask_token_id = mask_token_id

        if self.mask_prob > 0.0 and self.key is None:
            raise ValueError("must provide key if mask_prob > 0.0")

    def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset":
        return MaskedLmDataset(
            self.dataset.shard(shard_id, num_shards),
            self.QPos,
            self.KPos,
            self.mask_token_id,
            self.mask_prob,
            self.noise_prob,
            self.key,
        )

    def __iter__(self) -> Iterator[MaskedLmExample]:
        key = self.key
        sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])

        with use_cpu_device():

            @functools.partial(eqx.filter_jit, out_shardings=sharding)
            def _create_mlm_example(tokens, key):
                tokens_array = tokens.array
                targets = tokens_array.copy()

                attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0])
                attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos))

                if self.mask_prob > 0:
                    # use independent keys for selecting positions, choosing how to corrupt them, and
                    # drawing replacement tokens, so the three draws are not correlated with each other
                    select_key, corrupt_key, replace_key = jax.random.split(key, 3)
                    mask_shape = tokens_array.shape
                    mask = jax.random.bernoulli(select_key, self.mask_prob, mask_shape)

                    # BERT-style corruption of the selected positions: 80% become the mask token,
                    # noise_prob become a random token, and the remainder keep the original token
                    rand = jax.random.uniform(corrupt_key, mask_shape)
                    mask_token = jnp.where(rand < 0.8, self.mask_token_id, tokens_array)
                    # NB: the replacement range is bounded by the largest id in this sequence, not the vocab size
                    random_tokens = jax.random.randint(replace_key, mask_shape, 0, tokens_array.max() + 1)
                    mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token)
                    masked_tokens = jnp.where(mask, mask_token, tokens_array)

                    # targets keep the original tokens where a position was selected and mask_token_id elsewhere,
                    # so the loss can be restricted to the selected positions
                    targets = jnp.where(mask, tokens_array, self.mask_token_id)

                    masked_tokens_named = hax.named(masked_tokens, self.QPos)
                    targets_named = hax.named(targets, self.QPos)

                    example = MaskedLmExample.masked_lm(
                        tokens=masked_tokens_named,
                        targets=targets_named,
                        mask_token_id=self.mask_token_id,
                        attn_mask=attn_mask,
                    )
                else:
                    targets_named = hax.named(targets, self.QPos)
                    example = MaskedLmExample.masked_lm(
                        tokens=tokens, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask
                    )

                return example

            for tokens in self.dataset:
                # advance the key once per sequence so different examples get different masks
                if key is not None:
                    key, example_key = jax.random.split(key)
                else:
                    example_key = None
                tokens_array = jnp.array(tokens)
                tokens_named = hax.named(tokens_array, self.QPos)
                example = _create_mlm_example(tokens_named, example_key)
                yield example


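To make the corruption scheme in MaskedLmDataset easier to inspect in isolation, here is a minimal standalone sketch of the same 80% / noise_prob / keep split on a toy sequence. The vocabulary size and mask id below are illustrative values, not taken from the PR.

```python
import jax
import jax.numpy as jnp


def corrupt(tokens, key, mask_token_id, vocab_size, mask_prob=0.15, noise_prob=0.1):
    select_key, corrupt_key, replace_key = jax.random.split(key, 3)
    mask = jax.random.bernoulli(select_key, mask_prob, tokens.shape)  # positions the model must predict
    rand = jax.random.uniform(corrupt_key, tokens.shape)
    replacement = jnp.where(rand < 0.8, mask_token_id, tokens)  # 80% of selected positions -> mask token
    random_tokens = jax.random.randint(replace_key, tokens.shape, 0, vocab_size)
    replacement = jnp.where((rand >= 0.8) & (rand < 0.8 + noise_prob), random_tokens, replacement)  # noise_prob -> random token
    masked = jnp.where(mask, replacement, tokens)  # unselected positions are left untouched
    targets = jnp.where(mask, tokens, mask_token_id)  # loss is restricted to the selected positions
    return masked, targets


tokens = jnp.arange(16) % 97
masked, targets = corrupt(tokens, jax.random.PRNGKey(0), mask_token_id=99, vocab_size=100)
print(masked)
print(targets)
```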

class CausalLmDataset(ShardableDataset[LmExample]):
def __init__(
@@ -95,18 +170,13 @@ def __iter__(self) -> Iterator[LmExample]:
        sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])

        with use_cpu_device():

            @functools.partial(eqx.filter_jit, out_shardings=sharding)
            def _create_lm_example(tokens, key):
                tokens = hax.named(tokens, self.QPos)

                example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id)

                if self.fcm_prob > 0:
                    # masks for attention
                    # We support forgetful causal masking (FCM) which is a technique that improves training speed by
                    # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention
                    # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432
                    assert self.key is not None
                    this_key, key = jax.random.split(key)
                    fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key)
@@ -120,6 +190,7 @@ def _create_lm_example(tokens, key):
                yield example



class TokenSeqDataset(ShardableDataset[np.ndarray]):
    """
    A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache.
Expand Down Expand Up @@ -832,4 +903,4 @@ def build_caches(

    @property
    def sources(self) -> dict[str, LMDatasetSourceConfig]:
        return self.configs