Refactored DETR implementation on optax API. #1062

Open

wants to merge 10 commits into main
16 changes: 10 additions & 6 deletions scenic/model_lib/matchers/sinkhorn.py
@@ -22,7 +22,8 @@
import jax.numpy as jnp
import numpy as np
from ott.geometry import geometry
from ott.tools import transport
from ott.geometry import epsilon_scheduler
from ott.solvers import linear


def idx2permutation(row_ind, col_ind):
@@ -128,11 +129,14 @@ def sinkhorn_matcher(cost: jnp.ndarray,
"""
def coupling_fn(c):
geom = geometry.Geometry(
cost_matrix=c, epsilon=epsilon, init=init, decay=decay)
return transport.solve(geom,
max_iterations=num_iters,
chg_momentum_from=chg_momentum_from,
threshold=threshold).matrix
cost_matrix=c,
epsilon=epsilon_scheduler.Epsilon(
target=epsilon, init=init, decay=decay))
return linear.solve(
geom,
momentum=linear.acceleration.Momentum(start=chg_momentum_from),
max_iterations=num_iters,
threshold=threshold).matrix

coupling = jax.vmap(coupling_fn)(cost)

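For reference, a minimal sketch of the migrated OTT-JAX call path, assuming an `ott` version that exposes `ott.solvers.linear.solve` and returns an output with a dense `.matrix`; the batch shape, epsilon, and iteration counts below are illustrative and not taken from the Scenic defaults.

```python
import jax
import jax.numpy as jnp
from ott.geometry import geometry
from ott.solvers import linear


def batched_coupling(cost: jnp.ndarray,
                     epsilon: float = 0.05,
                     num_iters: int = 100,
                     threshold: float = 1e-3) -> jnp.ndarray:
  """Solves a batch of entropic OT problems; returns coupling matrices."""
  def coupling_fn(c):
    geom = geometry.Geometry(cost_matrix=c, epsilon=epsilon)
    # linear.solve wraps the Sinkhorn solver; `.matrix` is the dense coupling.
    return linear.solve(
        geom, max_iterations=num_iters, threshold=threshold).matrix
  return jax.vmap(coupling_fn)(cost)


# Toy usage: two 4x4 cost matrices.
cost = jax.random.uniform(jax.random.PRNGKey(0), (2, 4, 4))
print(batched_coupling(cost).shape)  # (2, 4, 4)
```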
2 changes: 1 addition & 1 deletion scenic/projects/baselines/detr/README.md
@@ -13,7 +13,7 @@ In order to train DETR on COCO object detection, you can use the
`detr_config.py` in the [configs directory](configs):

```shell
$ python scenic/projects/baselines/detr/main.py -- \
$ python scenic/projects/baselines/detr/main.py \
--config=scenic/projects/baselines/detr/configs/detr_config.py \
--workdir=./
```
44 changes: 14 additions & 30 deletions scenic/projects/baselines/detr/configs/detr_config.py
@@ -18,7 +18,6 @@
"""
# pylint: enable=line-too-long

import copy
import ml_collections
_COCO_TRAIN_SIZE = 118287
NUM_EPOCHS = 300
@@ -63,38 +62,22 @@ def get_config():
config.eos_coef = 0.1

# Training.
config.trainer_name = 'detr_trainer'
config.optimizer = 'adam'
config.optimizer_configs = ml_collections.ConfigDict()
config.optimizer_configs.weight_decay = 1e-4
config.optimizer_configs.beta1 = 0.9
config.optimizer_configs.beta2 = 0.999
config.max_grad_norm = 0.1
config.num_training_epochs = NUM_EPOCHS
config.batch_size = 64
config.rng_seed = 0

decay_events = {500: 400}

# Learning rate.
# Optimizer.
steps_per_epoch = _COCO_TRAIN_SIZE // config.batch_size
config.lr_configs = ml_collections.ConfigDict()
config.lr_configs.learning_rate_schedule = 'compound'
config.lr_configs.factors = 'constant*piecewise_constant'
config.lr_configs.decay_events = [
decay_events.get(NUM_EPOCHS, NUM_EPOCHS * 2 // 3) * steps_per_epoch,
]
# Note: this is absolute (not relative):
config.lr_configs.decay_factors = [.1]
config.lr_configs.base_learning_rate = 1e-4

# Backbone training configs: optimizer and learning rate.
config.backbone_training = ml_collections.ConfigDict()
config.backbone_training.optimizer = copy.deepcopy(config.optimizer)
config.backbone_training.optimizer_configs = copy.deepcopy(
config.optimizer_configs)
config.backbone_training.lr_configs = copy.deepcopy(config.lr_configs)
config.backbone_training.lr_configs.base_learning_rate = 1e-5
config.optimizer_configs = ml_collections.ConfigDict()
config.optimizer_configs.max_grad_norm = 0.1
config.optimizer_configs.base_learning_rate = 1e-4
config.optimizer_configs.learning_rate_decay_rate = 0.1
config.optimizer_configs.beta1 = 0.9
config.optimizer_configs.beta2 = 0.999
config.optimizer_configs.weight_decay = 1e-4
config.optimizer_configs.learning_rate_reduction = 0.1  # Backbone LR = base_lr * reduction.
config.optimizer_configs.learning_rate_decay_event = (NUM_EPOCHS * 2 // 3 *
steps_per_epoch)

# Pretrained_backbone.
config.load_pretrained_backbone = True
@@ -104,6 +87,9 @@ def get_config():
# https://github.com/google-research/scenic/tree/main/scenic/projects/baselines pylint: disable=line-too-long
config.pretrained_backbone_configs.checkpoint_path = 'path_to_checkpoint_of_resnet_50'

# Eval.
config.annotations_loc = 'scenic/dataset_lib/coco_dataset/data/instances_val2017.json'

# Logging.
config.write_summary = True
config.xprof = True # Profile using xprof.
@@ -115,5 +101,3 @@ def get_config():
config.debug_eval = False # Debug mode during eval.

return config
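For context, a minimal sketch of how the flattened `optimizer_configs` fields map onto an optax learning-rate schedule. The helper name `make_detr_lr_schedule` is hypothetical and only for illustration; the actual wiring lives in `get_detr_optimizer` in `scenic/projects/baselines/detr/train_utils.py` (see the diff below).

```python
import optax


def make_detr_lr_schedule(oc):
  """Hypothetical helper: builds the schedule implied by optimizer_configs.

  After `learning_rate_decay_event` steps the base learning rate is
  multiplied by `learning_rate_decay_rate` (here 1e-4 -> 1e-5).
  """
  return optax.piecewise_constant_schedule(
      init_value=oc.base_learning_rate,
      boundaries_and_scales={
          oc.learning_rate_decay_event: oc.learning_rate_decay_rate,
      })


# With NUM_EPOCHS = 300, the decay event falls at epoch 200 (300 * 2 // 3),
# i.e. at step 200 * steps_per_epoch.
```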


46 changes: 15 additions & 31 deletions scenic/projects/baselines/detr/configs/detr_sinkhorn_config.py
@@ -18,7 +18,6 @@
"""
# pylint: enable=line-too-long

import copy
import ml_collections
_COCO_TRAIN_SIZE = 118287
NUM_EPOCHS = 300
@@ -80,46 +79,33 @@ def get_config():
config.eos_coef = 0.1

# Training.
config.trainer_name = 'detr_trainer'
config.optimizer = 'adam'
config.optimizer_configs = ml_collections.ConfigDict()
config.optimizer_configs.weight_decay = 1e-4
config.optimizer_configs.beta1 = 0.9
config.optimizer_configs.beta2 = 0.999
config.max_grad_norm = 0.1
config.num_training_epochs = NUM_EPOCHS
config.batch_size = 64
config.rng_seed = 0

decay_events = {500: 400}

# Learning rate.
# Optimizer.
steps_per_epoch = _COCO_TRAIN_SIZE // config.batch_size
config.lr_configs = ml_collections.ConfigDict()
config.lr_configs.learning_rate_schedule = 'compound'
config.lr_configs.factors = 'constant*piecewise_constant'
config.lr_configs.decay_events = [
decay_events.get(NUM_EPOCHS, NUM_EPOCHS * 2 // 3) * steps_per_epoch,
]
# Note: this is absolute (not relative):
config.lr_configs.decay_factors = [.1]
config.lr_configs.base_learning_rate = 1e-4

# Backbone training configs: optimizer and learning rate.
config.backbone_training = ml_collections.ConfigDict()
config.backbone_training.optimizer = copy.deepcopy(config.optimizer)
config.backbone_training.optimizer_configs = copy.deepcopy(
config.optimizer_configs)
config.backbone_training.lr_configs = copy.deepcopy(config.lr_configs)
config.backbone_training.lr_configs.base_learning_rate = 1e-5
config.optimizer_configs = ml_collections.ConfigDict()
config.optimizer_configs.max_grad_norm = 0.1
config.optimizer_configs.base_learning_rate = 1e-4
config.optimizer_configs.learning_rate_decay_rate = 0.1
config.optimizer_configs.beta1 = 0.9
config.optimizer_configs.beta2 = 0.999
config.optimizer_configs.weight_decay = 1e-4
config.optimizer_configs.learning_rate_reduction = 0.1  # Backbone LR = base_lr * reduction.
config.optimizer_configs.learning_rate_decay_event = (NUM_EPOCHS * 2 // 3 *
steps_per_epoch)

# Pretrained_backbone.
config.load_pretrained_backbone = True
config.freeze_backbone_batch_stats = True
config.pretrained_backbone_configs = ml_collections.ConfigDict()
# Download pretrained ResNet50 checkpoints from here:
# https://github.com/google-research/scenic/tree/main/scenic/projects/baselines pylint: disable=line-too-long
config.init_from.checkpoint_path = 'path_to_checkpoint_of_resnet_50'
config.pretrained_backbone_configs.checkpoint_path = 'path_to_checkpoint_of_resnet_50'

# Eval.
config.annotations_loc = 'scenic/dataset_lib/coco_dataset/data/instances_val2017.json'

# Logging.
config.write_summary = True
@@ -132,5 +118,3 @@ def get_config():
config.debug_eval = False # Debug mode during eval.

return config


2 changes: 1 addition & 1 deletion scenic/projects/baselines/detr/main.py
@@ -24,7 +24,7 @@
from scenic import app
from scenic.projects.baselines.detr import model as detr_model
from scenic.projects.baselines.detr import trainer
from scenic.train_lib_deprecated import train_utils
from scenic.train_lib import train_utils

FLAGS = flags.FLAGS

94 changes: 41 additions & 53 deletions scenic/projects/baselines/detr/train_utils.py
@@ -20,9 +20,8 @@
from typing import Any, Dict, Optional, Set

from absl import logging
from flax import core as flax_core
from flax import optim as optimizers
from flax import traverse_util
import flax
import optax
import jax
from jax.example_libraries import optimizers as experimental_optimizers
import jax.numpy as jnp
@@ -33,8 +32,7 @@
from scenic.common_lib import image_utils
from scenic.dataset_lib.coco_dataset import coco_eval
from scenic.model_lib.base_models import box_utils
from scenic.train_lib_deprecated import optimizers as scenic_optimizers
from scenic.train_lib_deprecated import train_utils
from scenic.train_lib import train_utils
import scipy.special
import tensorflow as tf

@@ -319,62 +317,52 @@ def draw_boxes_side_by_side(pred: Dict[str, Any], batch: Dict[str, Any],
return np.stack(viz, axis=0)


def get_detr_optimizer(config):
"""Makes a Flax MultiOptimizer for DETR."""
other_optim = scenic_optimizers.get_optimizer(config)
def get_detr_optimizer(config, params):
"""Makes an Optax optimizer for DETR."""
oc = config.optimizer_configs

if config.get('backbone_training'):
backbone_optim = scenic_optimizers.get_optimizer(config.backbone_training)
else:
backbone_optim = other_optim

def is_bn(path):
def bn_and_freeze_batch_stats(path):
# For DETR we need to skip the BN affine transforms as well.
if not config.freeze_backbone_batch_stats:
return False
names = ['/bn1/', '/bn2/', '/bn3/', '/init_bn/', '/proj_bn/']
for s in names:
if s in path:
return True
return False
backbone_traversal = optimizers.ModelParamTraversal(
lambda path, param: ('backbone' in path) and (not is_bn(path)))
other_traversal = optimizers.ModelParamTraversal(
lambda path, param: 'backbone' not in path)

return MultiOptimizerWithLogging((backbone_traversal, backbone_optim),
(other_traversal, other_optim))


class MultiOptimizerWithLogging(optimizers.MultiOptimizer):
"""Adds logging to MultiOptimizer to show which params are trained."""

def init_state(self, params):
self.log(params)
return super().init_state(params)

def log(self, inputs):
for i, traversal in enumerate(self.traversals):
params = _get_params_dict(inputs)
flat_dict = traverse_util.flatten_dict(params)
for key, value in _sorted_items(flat_dict):
path = '/' + '/'.join(key)
if traversal._filter_fn(path, value): # pylint: disable=protected-access
logging.info(
'ParamTraversalLogger (opt %d): %s, %s', i, value.shape, path)


def _sorted_items(x):
"""Returns items of a dict ordered by keys."""
return sorted(x.items(), key=lambda x: x[0])


def _get_params_dict(inputs):
if isinstance(inputs, (dict, flax_core.FrozenDict)):
return flax_core.unfreeze(inputs)
else:
raise ValueError(
'Can only traverse a flax Model instance or a nested dict, not '
f'{type(inputs)}')

backbone_traversal = flax.traverse_util.ModelParamTraversal(
lambda path, _: 'backbone' in path)
bn_traversal = flax.traverse_util.ModelParamTraversal(
lambda path, _: bn_and_freeze_batch_stats(path))

all_false = jax.tree_util.tree_map(lambda _: False, params)

def get_mask(traversal: flax.traverse_util.ModelParamTraversal):
return traversal.update(lambda _: True, all_false)

backbone_mask = get_mask(backbone_traversal)
bn_mask = get_mask(bn_traversal)
weight_decay_mask = jax.tree_map(lambda p: p.ndim != 1, params)


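# The chain below clips gradients globally, applies AdamW with a
# piecewise-constant learning-rate schedule and a weight-decay mask that
# skips 1-D (bias/scale) parameters, then rescales backbone updates by
# `learning_rate_reduction` and zeroes updates for BN affine parameters
# when `freeze_backbone_batch_stats` is set.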
tx = optax.chain(
optax.clip_by_global_norm(oc.max_grad_norm),
optax.adamw(
learning_rate=optax.piecewise_constant_schedule(
oc.base_learning_rate,
{oc.learning_rate_decay_event: oc.learning_rate_decay_rate}
),
b1=oc.beta1,
b2=oc.beta2,
weight_decay=oc.weight_decay,
mask=weight_decay_mask,
mu_dtype=oc.get('mu_dtype', jnp.float32)),
optax.masked(optax.scale(oc.learning_rate_reduction), backbone_mask),
optax.masked(optax.scale(0), bn_mask),
)

return tx

def clip_grads(grad_tree, max_norm):
"""Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
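For reference, a minimal sketch of how the returned optax transformation would typically be driven in a train step; `params`, `batch`, and `loss_fn` are illustrative placeholders, not names from this diff.

```python
import jax
import optax

tx = get_detr_optimizer(config, params)  # optax.GradientTransformation
opt_state = tx.init(params)


@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss_fn)(params, batch)
  # Clipping, AdamW with the piecewise-constant schedule, the backbone
  # learning-rate reduction and the frozen-BN mask are all applied inside
  # the chained transformation.
  updates, opt_state = tx.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state
```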