Skip to content

Commit

Permalink
[TRAX] Change train.lr_schedule to train.lr_schedule_fn and update comments.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 319319214
  • Loading branch information
afrozenator authored and copybara-github committed Jul 1, 2020
1 parent 714fb8e commit a305230
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 9 deletions.
1 change: 1 addition & 0 deletions trax/supervised/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Supervised learning imports in Trax."""

from trax.supervised import inputs
from trax.supervised import lr_schedules
from trax.supervised import tf_inputs
from trax.supervised import trainer_lib
from trax.supervised import training
Expand Down
2 changes: 1 addition & 1 deletion trax/supervised/configs/lstm_seq2seq_wmt_ende.gin
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ train.eval_frequency = 1000
train.eval_steps = 10
train.model = @trax.models.LSTMSeq2SeqAttn
train.optimizer = @trax.optimizers.Adam
train.lr_schedule = @lr_schedules.warmup()
train.lr_schedule_fn = @lr_schedules.warmup
train.steps = 250000

# Parameters for LSTMSeq2SeqAttn:
Expand Down
2 changes: 1 addition & 1 deletion trax/supervised/configs/mlp_mnist.gin
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ train.eval_steps = 10
train.model = @trax.models.MLP
train.steps = 2000
train.checkpoint_highest = 'accuracy'
train.lr_schedule = @lr_schedules.constant()
train.lr_schedule_fn = @lr_schedules.constant
1 change: 0 additions & 1 deletion trax/supervised/configs/resnet50_frn_imagenet_8gb.gin
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,3 @@ train.eval_steps = 20
train.model = @trax.models.Resnet50
train.optimizer = @trax.optimizers.Momentum
train.steps = 300000
train.lr_schedule = @lr_schedules.multifactor()
1 change: 0 additions & 1 deletion trax/supervised/configs/resnet50_imagenet_8gb_testing.gin
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,3 @@ train.eval_steps = 20
train.model = @trax.models.Resnet50
train.optimizer = @trax.optimizers.Momentum
train.steps = 100000
train.lr_schedule = @lr_schedules.multifactor()
1 change: 0 additions & 1 deletion trax/supervised/configs/wide_resnet_cifar10_8gb.gin
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,3 @@ train.eval_steps = 10
train.model = @trax.models.WideResnet
train.optimizer = @trax.optimizers.Momentum
train.steps = 10000
train.lr_schedule = @lr_schedules.multifactor()
8 changes: 4 additions & 4 deletions trax/supervised/trainer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def train(output_dir,
loss_fn=tl.CrossEntropyLoss(),
inputs=trax_inputs.batcher,
optimizer=trax_opt.Adafactor,
lr_schedule=lr.multifactor(),
lr_schedule_fn=lr.multifactor,
trainer_class=Trainer,
steps=1000,
checkpoints_at=None,
Expand All @@ -559,8 +559,8 @@ def train(output_dir,
rng -> loss.
inputs: callable returning trax.inputs.Inputs.
optimizer: The optimizer (see optimizers/base.py for signature).
lr_schedule: A learning rate schedule as a function that takes history and
returns a function from step to learning rate (a float).
lr_schedule_fn: A learning rate schedule function, that when called returns
a function from step to learning rate (a float).
trainer_class: The trainer class to use.
steps: int, total number of training steps.
checkpoints_at: list of integers. Save a checkpoint for each training step
Expand All @@ -582,7 +582,7 @@ def train(output_dir,
return custom_train_fn(output_dir, model=model)

n_devices = num_devices()
trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
output_dir,
random_seed=random_seed,
n_devices=n_devices,
Expand Down

0 comments on commit a305230

Please sign in to comment.