Skip to content
This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

This shouldn't have public changes once the diffbase is submitted. #191

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions mesh_tensorflow/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1563,7 +1563,7 @@ def get_estimator(model_type, vocabulary, mesh_shape,

def train_model(estimator, vocabulary, sequence_length, batch_size,
train_dataset_fn, train_steps, ensemble_inputs,
dataset_split="train"):
dataset_split="train", skip_seen_data=False):
"""Train a Mesh-TF model.

Args:
Expand All @@ -1585,17 +1585,26 @@ def train_model(estimator, vocabulary, sequence_length, batch_size,
configure Unitransformer.ensemble to the right size. If None, then all
models are trained on the same inputs.
dataset_split: str, which dataset split to train on.
skip_seen_data: a boolean, is `False` by default. Used when a training run
restarts to skip already seen data.
"""

def input_fn(params):
del params

dataset = train_dataset_fn(
sequence_length=sequence_length,
vocabulary=vocabulary,
dataset_split=dataset_split)
dataset = dataset.repeat().batch(
batch_size * (ensemble_inputs or 1), drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# On the first time data is read in after relaunching, skip data that has
# already been seen.
if skip_seen_data:
recovered_step = estimator.get_variable_value("global_step")
dataset = dataset.skip(recovered_step)
return dataset

estimator.train(input_fn=input_fn, max_steps=train_steps)
Expand Down Expand Up @@ -2117,7 +2126,8 @@ def run(tpu_job_name,
perplexity_eval_steps=100,
init_checkpoint=None,
ensemble_inputs=None,
train_model_fn=train_model):
train_model_fn=train_model,
skip_seen_data=False):
"""Run training, eval, or inference depending on `mode`.

Args:
Expand Down Expand Up @@ -2173,6 +2183,8 @@ def run(tpu_job_name,
init_checkpoint: a string, see `get_estimator` docstring for details.
ensemble_inputs: an integer, see `train_model` docstring for details.
train_model_fn: an optional train function, is `train_model` by default.
skip_seen_data: a boolean, is `False` by default. Used when a training run
restarts to skip already seen data.
"""
if isinstance(sequence_length, int):
sequence_length = {"inputs": sequence_length,
Expand Down Expand Up @@ -2247,8 +2259,11 @@ def run(tpu_job_name,
# train_model
if train_dataset_fn is None:
raise ValueError("Must provide train_dataset_fn through gin")

train_model_fn(estimator, vocabulary, sequence_length, batch_size,
train_dataset_fn, train_steps, ensemble_inputs)
train_dataset_fn, train_steps, ensemble_inputs,
skip_seen_data=skip_seen_data)

elif mode == "perplexity_eval":
if eval_dataset_fn is None:
if train_dataset_fn is not None:
Expand Down