From a3052302ec725242a01c767117c4f7fe3859c72d Mon Sep 17 00:00:00 2001
From: Afroz Mohiuddin
Date: Wed, 1 Jul 2020 16:37:35 -0700
Subject: [PATCH] [TRAX] Change train.lr_schedule to train.lr_schedule_fn and update comments.

PiperOrigin-RevId: 319319214
---
 trax/supervised/__init__.py                               | 1 +
 trax/supervised/configs/lstm_seq2seq_wmt_ende.gin         | 2 +-
 trax/supervised/configs/mlp_mnist.gin                     | 2 +-
 trax/supervised/configs/resnet50_frn_imagenet_8gb.gin     | 1 -
 trax/supervised/configs/resnet50_imagenet_8gb_testing.gin | 1 -
 trax/supervised/configs/wide_resnet_cifar10_8gb.gin       | 1 -
 trax/supervised/trainer_lib.py                            | 8 ++++----
 7 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/trax/supervised/__init__.py b/trax/supervised/__init__.py
index b07cc1b5a..74ef780da 100644
--- a/trax/supervised/__init__.py
+++ b/trax/supervised/__init__.py
@@ -16,6 +16,7 @@
 """Supervised learning imports in Trax."""
 
 from trax.supervised import inputs
+from trax.supervised import lr_schedules
 from trax.supervised import tf_inputs
 from trax.supervised import trainer_lib
 from trax.supervised import training

diff --git a/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin b/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin
index a211a9651..1fe9a9c16 100644
--- a/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin
+++ b/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin
@@ -54,7 +54,7 @@ train.eval_frequency = 1000
 train.eval_steps = 10
 train.model = @trax.models.LSTMSeq2SeqAttn
 train.optimizer = @trax.optimizers.Adam
-train.lr_schedule = @lr_schedules.warmup()
+train.lr_schedule_fn = @lr_schedules.warmup
 train.steps = 250000
 
 # Parameters for LSTMSeq2SeqAttn:

diff --git a/trax/supervised/configs/mlp_mnist.gin b/trax/supervised/configs/mlp_mnist.gin
index 448c95e8d..9163dced9 100644
--- a/trax/supervised/configs/mlp_mnist.gin
+++ b/trax/supervised/configs/mlp_mnist.gin
@@ -49,4 +49,4 @@ train.eval_steps = 10
 train.model = @trax.models.MLP
 train.steps = 2000
 train.checkpoint_highest = 'accuracy'
-train.lr_schedule = @lr_schedules.constant()
+train.lr_schedule_fn = @lr_schedules.constant

diff --git a/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin b/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin
index 953bc237a..9825d53bf 100644
--- a/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin
+++ b/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin
@@ -61,4 +61,3 @@ train.eval_steps = 20
 train.model = @trax.models.Resnet50
 train.optimizer = @trax.optimizers.Momentum
 train.steps = 300000
-train.lr_schedule = @lr_schedules.multifactor()

diff --git a/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin b/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin
index 3e6974c3c..a137608eb 100644
--- a/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin
+++ b/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin
@@ -53,4 +53,3 @@ train.eval_steps = 20
 train.model = @trax.models.Resnet50
 train.optimizer = @trax.optimizers.Momentum
 train.steps = 100000
-train.lr_schedule = @lr_schedules.multifactor()

diff --git a/trax/supervised/configs/wide_resnet_cifar10_8gb.gin b/trax/supervised/configs/wide_resnet_cifar10_8gb.gin
index f30566b7e..eac303ef2 100644
--- a/trax/supervised/configs/wide_resnet_cifar10_8gb.gin
+++ b/trax/supervised/configs/wide_resnet_cifar10_8gb.gin
@@ -57,4 +57,3 @@ train.eval_steps = 10
 train.model = @trax.models.WideResnet
 train.optimizer = @trax.optimizers.Momentum
 train.steps = 10000
-train.lr_schedule = @lr_schedules.multifactor()

diff --git a/trax/supervised/trainer_lib.py b/trax/supervised/trainer_lib.py
index df85e9305..75624d41e 100644
--- a/trax/supervised/trainer_lib.py
+++ b/trax/supervised/trainer_lib.py
@@ -537,7 +537,7 @@ def train(output_dir,
           loss_fn=tl.CrossEntropyLoss(),
           inputs=trax_inputs.batcher,
           optimizer=trax_opt.Adafactor,
-          lr_schedule=lr.multifactor(),
+          lr_schedule_fn=lr.multifactor,
           trainer_class=Trainer,
           steps=1000,
           checkpoints_at=None,
@@ -559,8 +559,8 @@ def train(output_dir,
       rng -> loss.
     inputs: callable returning trax.inputs.Inputs.
     optimizer: The optimizer (see optimizers/base.py for signature).
-    lr_schedule: A learning rate schedule as a function that takes history and
-      returns a function from step to learning rate (a float).
+    lr_schedule_fn: A learning rate schedule function that, when called,
+      returns a function from step to learning rate (a float).
     trainer_class: The trainer class to use.
     steps: int, total number of training steps.
     checkpoints_at: list of integers. Save a checkpoint for each training step
@@ -582,7 +582,7 @@ def train(output_dir,
     return custom_train_fn(output_dir, model=model)
 
   n_devices = num_devices()
-  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
+  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                           output_dir,
                           random_seed=random_seed,
                           n_devices=n_devices,
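
The net effect of the rename, as a minimal Python sketch (assuming the
trax.supervised.lr_schedules module imported above; train_sketch is a
hypothetical stand-in for trainer_lib.train, not code from this patch):

    from trax.supervised import lr_schedules as lr

    def train_sketch(lr_schedule_fn=lr.multifactor):
      # Before this patch, callers passed an already-instantiated schedule:
      #   train(..., lr_schedule=lr.multifactor())
      # Now they pass the schedule function itself, and train() calls it:
      #   train(..., lr_schedule_fn=lr.multifactor)
      lr_schedule = lr_schedule_fn()  # a function: step -> learning rate
      return lr_schedule(1)           # the learning rate (a float) at step 1

The gin configs change accordingly: @lr_schedules.multifactor binds the
function itself, while the old @lr_schedules.multifactor() bound the result
of calling it.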