Skip to content

Commit

Permalink
[AIR] Added Ray Logging to MosaicTrainer (ray-project#29620)
Browse files Browse the repository at this point in the history
Added RayLogger to MosaicTrainer to relay all reported information.

RayLogger is a subclass of LoggerDestination, just like all other native composer loggers. The information to be logged is given via log_metrics call, which is saved in the RayLogger object. The logger reports the logged information every batch checkpoint and epoch checkpoint. All other composer loggers besides RayLogger loggers are removed from the trainer.

Note that because at the moment, the result metrics_dataframe will only include the keys that are reported in the very first report call, to have metrics that are not reported every batch in the final metrics dataframe, the keys should be passed in via 'log_keys' in the trainer_init_config.

Co-authored-by: Amog Kamsetty <[email protected]>
Signed-off-by: ilee300a <[email protected]>
  • Loading branch information
ilee300a and amogkam authored Oct 27, 2022
1 parent 57ea8bd commit 28e84b8
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 9 deletions.
2 changes: 2 additions & 0 deletions doc/source/custom_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def update_context(app, pagename, templatename, context, doctree):
"composer.trainer",
"composer.loggers",
"composer.loggers.logger_destination",
"composer.core",
"composer.core.state",
]


Expand Down
8 changes: 4 additions & 4 deletions python/ray/train/examples/mosaic_cifar10_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def trainer_init_per_worker(config):
from composer.models.tasks import ComposerClassifier
import composer.optim

BATCH_SIZE = 32
BATCH_SIZE = 64
# prepare the model for distributed training and wrap with ComposerClassifier for
# Composer Trainer compatibility
model = torchvision.models.resnet18(num_classes=10)
Expand All @@ -37,13 +37,13 @@ def trainer_init_per_worker(config):
datasets.CIFAR10(
data_directory, train=True, download=True, transform=cifar10_transforms
),
list(range(64)),
list(range(BATCH_SIZE * 10)),
)
test_dataset = torch.utils.data.Subset(
datasets.CIFAR10(
data_directory, train=False, download=True, transform=cifar10_transforms
),
list(range(64)),
list(range(BATCH_SIZE * 10)),
)

batch_size_per_worker = BATCH_SIZE // session.get_world_size()
Expand Down Expand Up @@ -82,7 +82,7 @@ def train_mosaic_cifar10(num_workers=2, use_gpu=False):
from ray.train.mosaic import MosaicTrainer

trainer_init_config = {
"max_duration": "1ep",
"max_duration": "2ep",
"algorithms": [LabelSmoothing()],
"should_eval": False,
}
Expand Down
68 changes: 68 additions & 0 deletions python/ray/train/mosaic/_mosaic_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Dict, Optional, List
import torch

from composer.loggers import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.core.state import State

from ray.air import session


class RayLogger(LoggerDestination):
"""A logger to relay information logged by composer models to ray.
This logger allows utilizing all necessary logging and logged data handling provided
by the Composer library. All the logged information is saved in the data dictionary
every time a new information is logged, but to reduce unnecessary reporting, the
most up-to-date logged information is reported as metrics every batch checkpoint and
epoch checkpoint (see Composer's Event module for more details).
Because ray's metric dataframe will not include new keys that is reported after the
very first report call, any logged information with the keys not included in the
first batch checkpoint would not be retrievable after training. In other words, if
the log level is greater than `LogLevel.BATCH` for some data, they would not be
present in `Result.metrics_dataframe`. To allow preserving those information, the
user can provide keys to be always included in the reported data by using `keys`
argument in the constructor. For `MosaicTrainer`, use
`trainer_init_config['log_keys']` to populate these keys.
Note that in the Event callback functions, we remove unused variables, as this is
practiced in Mosaic's composer library.
Args:
keys: the key values that will be included in the reported metrics.
"""

def __init__(self, keys: Optional[List[str]] = None) -> None:
self.data = {}
# report at fit end only if there are additional training batches run after the
# last epoch checkpoint report
self.should_report_fit_end = False
if keys:
for key in keys:
self.data[key] = None

def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
self.data.update(metrics.items())
for key, val in self.data.items():
if isinstance(val, torch.Tensor):
self.data[key] = val.item()

def batch_checkpoint(self, state: State, logger: Logger) -> None:
del logger # unused
self.should_report_fit_end = True

def epoch_checkpoint(self, state: State, logger: Logger) -> None:
del logger # unused
self.should_report_fit_end = False
session.report(self.data)

# flush the data
self.data = {}

def fit_end(self, state: State, logger: Logger) -> None:
# report at close in case the trainer stops in the middle of an epoch.
# this may be double counted with epoch checkpoint.
del logger # unused
if self.should_report_fit_end:
session.report(self.data)
12 changes: 10 additions & 2 deletions python/ray/train/mosaic/mosaic_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.train.mosaic._mosaic_utils import RayLogger
from ray.train.torch import TorchConfig, TorchTrainer
from ray.train.trainer import GenDataset
from ray.util import PublicAPI
Expand Down Expand Up @@ -207,16 +208,23 @@ def _mosaic_train_loop_per_worker(config):
os.environ["WORLD_SIZE"] = str(session.get_world_size())
os.environ["LOCAL_RANK"] = str(session.get_local_rank())

# Replace Composer's Loggers with RayLogger
ray_logger = RayLogger(keys=config.pop("log_keys", []))

# initialize Composer trainer
config["progress_bar"] = False
trainer: Trainer = trainer_init_per_worker(config)

# Remove Composer's Loggers
# Remove Composer's Loggers if there are any added in the trainer_init_per_worker
# this removes the logging part of the loggers
filtered_callbacks = list()
for callback in trainer.state.callbacks:
if not isinstance(callback, LoggerDestination):
filtered_callbacks.append(callback)
filtered_callbacks.append(ray_logger)
trainer.state.callbacks = filtered_callbacks

# this prevents data to be routed to all the Composer Loggers
trainer.logger.destinations = (ray_logger,)

# call the trainer
trainer.fit()
129 changes: 126 additions & 3 deletions python/ray/train/tests/test_mosaic_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def trainer_init_per_worker(config):
weight_decay=2.0e-3,
)

if config.pop("eval", False):
if config.pop("should_eval", False):
config["eval_dataloader"] = evaluator

return composer.trainer.Trainer(
Expand All @@ -85,9 +85,17 @@ def trainer_init_per_worker(config):
def test_mosaic_cifar10(ray_start_4_cpus):
from ray.train.examples.mosaic_cifar10_example import train_mosaic_cifar10

_ = train_mosaic_cifar10()
result = train_mosaic_cifar10().metrics_dataframe

# TODO : add asserts once reporting has been integrated
# check the max epoch value
assert result["epoch"][result.index[-1]] == 1

# check train_iterations
assert result["_training_iteration"][result.index[-1]] == 2

# check metrics/train/Accuracy has increased
acc = list(result["metrics/train/Accuracy"])
assert acc[-1] > acc[0]


def test_init_errors(ray_start_4_cpus):
Expand Down Expand Up @@ -149,6 +157,10 @@ class DummyCallback(Callback):
def fit_start(self, state: State, logger: Logger) -> None:
raise ValueError("Composer Callback object exists.")

class DummyMonitorCallback(Callback):
def fit_start(self, state: State, logger: Logger) -> None:
logger.log_metrics({"dummy_callback": "test"})

# DummyLogger should not throw an error since it should be removed before `fit` call
trainer_init_config = {
"max_duration": "1ep",
Expand All @@ -175,6 +187,117 @@ def fit_start(self, state: State, logger: Logger) -> None:
trainer.fit()
assert e == "Composer Callback object exists."

trainer_init_config["callbacks"] = DummyMonitorCallback()
trainer = MosaicTrainer(
trainer_init_per_worker=trainer_init_per_worker,
trainer_init_config=trainer_init_config,
scaling_config=scaling_config,
)

result = trainer.fit()

assert "dummy_callback" in result.metrics
assert result.metrics["dummy_callback"] == "test"


def test_log_count(ray_start_4_cpus):
from ray.train.mosaic import MosaicTrainer

trainer_init_config = {
"max_duration": "1ep",
"should_eval": False,
}

trainer = MosaicTrainer(
trainer_init_per_worker=trainer_init_per_worker,
trainer_init_config=trainer_init_config,
scaling_config=scaling_config,
)

result = trainer.fit()

assert len(result.metrics_dataframe) == 1

trainer_init_config["max_duration"] = "1ba"

trainer = MosaicTrainer(
trainer_init_per_worker=trainer_init_per_worker,
trainer_init_config=trainer_init_config,
scaling_config=scaling_config,
)

result = trainer.fit()

assert len(result.metrics_dataframe) == 1


def test_metrics_key(ray_start_4_cpus):
from ray.train.mosaic import MosaicTrainer

"""Tests if `log_keys` defined in `trianer_init_config` appears in result
metrics_dataframe.
"""
trainer_init_config = {
"max_duration": "1ep",
"should_eval": True,
"log_keys": ["metrics/my_evaluator/Accuracy"],
}

trainer = MosaicTrainer(
trainer_init_per_worker=trainer_init_per_worker,
trainer_init_config=trainer_init_config,
scaling_config=scaling_config,
)

result = trainer.fit()

# check if the passed in log key exists
assert "metrics/my_evaluator/Accuracy" in result.metrics_dataframe.columns


def test_monitor_callbacks(ray_start_4_cpus):
from ray.train.mosaic import MosaicTrainer

# Test Callbacks involving logging (SpeedMonitor, LRMonitor)
from composer.callbacks import SpeedMonitor, LRMonitor, GradMonitor

trainer_init_config = {
"max_duration": "1ep",
"should_eval": True,
}
trainer_init_config["log_keys"] = [
"grad_l2_norm/step",
]
trainer_init_config["callbacks"] = [
SpeedMonitor(window_size=3),
LRMonitor(),
GradMonitor(),
]

trainer = MosaicTrainer(
trainer_init_per_worker=trainer_init_per_worker,
trainer_init_config=trainer_init_config,
scaling_config=scaling_config,
)

result = trainer.fit()

assert len(result.metrics_dataframe) == 1

metrics_columns = result.metrics_dataframe.columns
columns_to_check = [
"wall_clock/train",
"wall_clock/val",
"wall_clock/total",
"lr-DecoupledSGDW/group0",
"grad_l2_norm/step",
]
for column in columns_to_check:
assert column in metrics_columns, column + " is not found"
assert result.metrics_dataframe[column].isnull().sum() == 0, (
column + " column has a null value"
)


if __name__ == "__main__":
import sys
Expand Down

0 comments on commit 28e84b8

Please sign in to comment.