Revert "[AIR] Avoid checkpoint conversion, move encoding logic to che…
Browse files Browse the repository at this point in the history
…ckpoints (ray-project#28794)" (ray-project#29784)

This added dependencies from the HorovodConfig on TensorFlow and Torch. If either of these is not installed (e.g., if the user is using Horovod with Torch and does not have TensorFlow installed), they will run into a `ModuleNotFoundError`.

https://github.com/ray-project/ray/blob/6b9a56d28e1029741feaa864257d75824fe36622/python/ray/train/horovod/config.py#L16-L17

Reverting this for now.
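
For illustration, a minimal sketch of the failure mode and a guarded alternative (the function names below are hypothetical and this is not the actual content of python/ray/train/horovod/config.py):

# Hypothetical sketch only, not the reverted Ray code.
from typing import Dict


def encode_data_eager(data_dict: Dict) -> Dict:
    # If an import like this sat at module level in horovod/config.py,
    # merely importing HorovodConfig would raise ModuleNotFoundError
    # for a user who has Torch but not TensorFlow installed.
    import tensorflow  # noqa: F401

    return data_dict


def encode_data_guarded(data_dict: Dict) -> Dict:
    # Guarded alternative: the optional dependency is imported lazily,
    # so a Horovod-with-Torch user never needs TensorFlow on this path.
    try:
        import tensorflow  # noqa: F401
    except ImportError:
        return data_dict
    # TensorFlow-specific encoding would go here.
    return data_dict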
matthewdeng authored Oct 27, 2022
1 parent 6b9a56d commit 57ea8bd
Showing 24 changed files with 174 additions and 404 deletions.
20 changes: 0 additions & 20 deletions python/ray/air/_internal/tensorflow_utils.py
@@ -64,26 +64,6 @@ def convert_ndarray_batch_to_tf_tensor_batch(
return batch


# This is not foolproof, but it's better than nothing
# The place it is used in will be deprecated soon
def contains_tensorflow_object(obj):
if hasattr(obj, "__module__") and (
"keras" in obj.__module__ or "tensorflow" in obj.__module__
):
return True
elif isinstance(obj, dict):
for k, v in obj.items():
if contains_tensorflow_object(k):
return True
if contains_tensorflow_object(v):
return True
elif isinstance(obj, (list, tuple)):
for v in obj:
if contains_tensorflow_object(v):
return True
return False


def get_type_spec(
schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"],
columns: Union[str, List[str]],
9 changes: 3 additions & 6 deletions python/ray/air/checkpoint.py
@@ -501,11 +501,6 @@ def _get_temporary_checkpoint_dir(self) -> str:
)
return os.path.join(tmp_dir_path, checkpoint_dir_name)

def _save_checkpoint_metadata_in_directory(self, path: str) -> None:
checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
with open(checkpoint_metadata_path, "wb") as file:
pickle.dump(self._metadata, file)

def _to_directory(self, path: str) -> None:
if self._data_dict or self._obj_ref:
# This is a object ref or dict
@@ -552,7 +547,9 @@ def _to_directory(self, path: str) -> None:
f"No valid location found for checkpoint {self}: {self._uri}"
)

self._save_checkpoint_metadata_in_directory(path)
checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
with open(checkpoint_metadata_path, "wb") as file:
pickle.dump(self._metadata, file)

def to_directory(self, path: Optional[str] = None) -> str:
"""Write checkpoint data to directory.
2 changes: 1 addition & 1 deletion python/ray/train/_internal/backend_executor.py
@@ -346,7 +346,7 @@ def initialize_session(
train_func=train_func,
dataset_shard=self.dataset_shards[index],
checkpoint=checkpoint,
encode_data_fn=self._backend._encode_data,
encode_data_fn=self._backend.encode_data,
)
)

12 changes: 3 additions & 9 deletions python/ray/train/_internal/checkpoint.py
@@ -101,15 +101,13 @@ def _process_checkpoint(
"""Ray Train entrypoint. Perform all processing for a checkpoint."""
# Get checkpoint from first worker.
checkpoint_data = checkpoint_results[0].data
checkpoint_metadata = checkpoint_results[0].metadata or {}

# TODO(ml-team): Remove once we remove Backend.decode_data
checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict()
# Decode checkpoint.
checkpoint_data = decode_checkpoint_fn(checkpoint_data)

score_attr = self._checkpoint_strategy.checkpoint_score_attribute
if (
self._checkpoint_strategy.num_to_keep != 0
and score_attr not in checkpoint_metadata
and score_attr not in checkpoint_data
):
raise ValueError(
@@ -124,11 +122,7 @@
dir_or_data=checkpoint_data,
checkpoint_id=self._latest_checkpoint_id,
storage_mode=CheckpointStorage.MEMORY,
metrics={
score_attr: checkpoint_metadata.get(
score_attr, checkpoint_data.get(score_attr, 0.0)
)
},
metrics={score_attr: checkpoint_data.get(score_attr, 0.0)},
)
self.register_checkpoint(checkpoint=tracked_checkpoint)

51 changes: 15 additions & 36 deletions python/ray/train/_internal/session.py
@@ -2,7 +2,6 @@
import logging
import platform
import queue
import sys
import threading
import time
from dataclasses import dataclass
@@ -52,8 +51,7 @@ class TrialInfo:
@dataclass
class TrainingResult:
type: TrainingResultType
data: Union[Dict, Checkpoint]
metadata: Optional[Dict] = None
data: Dict


# TODO(xwjiang): This needs a better name.
@@ -70,9 +68,8 @@ def __init__(
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
checkpoint: Optional[Checkpoint] = None,
# Deprecated
encode_data_fn: Optional[Callable] = None,
checkpoint: Optional[Union[Dict, Checkpoint]] = None,
encode_data_fn: Callable = None,
detailed_autofilled_metrics: bool = False,
):

@@ -83,7 +80,7 @@ def __init__(
self.world_size = world_size
self.trial_info = trial_info
# TODO(xwjiang): Legacy Ray Train trainer clean up!
self.loaded_checkpoint = checkpoint
self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint

# Function to encode checkpoint dict before sending to the driver.
if not encode_data_fn:
@@ -243,9 +240,9 @@ def _report_legacy(self, **kwargs):
if self.ignore_report:
return

kwargs = self._auto_fill_metrics(kwargs)
kwargs = self._encode_data_fn(self._auto_fill_metrics(kwargs))

result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs)
result = TrainingResult(TrainingResultType.REPORT, kwargs)

# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)
@@ -272,26 +269,22 @@ def _report_thread_runner_error(self, block=False):
except queue.Empty:
pass

def checkpoint(self, checkpoint: Checkpoint):
def checkpoint(self, **kwargs):
"""Adds kwargs to the queue to be consumed by main thread.
Also stores the checkpoint in ``self.loaded_checkpoint``.
"""

# Update session checkpoint to latest checkpoint.
self.loaded_checkpoint = checkpoint
self.loaded_checkpoint = kwargs

# Only store checkpoints on worker with rank 0.
if self.world_rank != 0:
checkpoint = None
elif checkpoint:
checkpoint = self._encode_data_fn(checkpoint)

result = TrainingResult(
type=TrainingResultType.CHECKPOINT,
data=checkpoint,
metadata=self._auto_fill_checkpoint_metrics({}),
)
kwargs = {}
else:
kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs))

result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)

@@ -301,23 +294,9 @@ def checkpoint(self, checkpoint: Checkpoint):

def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): tons of optimizations.

# Special case: early fail for Torch tensors
if "torch" in sys.modules:
from ray.air._internal.torch_utils import contains_tensor

if contains_tensor(metrics):
raise ValueError(
"Passing objects containg Torch tensors as metrics "
"is not supported as it will throw an exception on "
"deserialization. You can either convert the tensors "
"to Python objects or use a `TorchCheckpoint` as the "
"`checkpoint` argument of `ray.air.session.report` to "
"store your Torch objects."
)

if checkpoint:
self.checkpoint(checkpoint)
checkpoint_dict = checkpoint.to_dict()
self.checkpoint(**checkpoint_dict)
self._report_legacy(**metrics)


65 changes: 3 additions & 62 deletions python/ray/train/backend.py
@@ -1,43 +1,16 @@
import logging
import warnings
from typing import Type, TypeVar, Dict
from typing import TypeVar, Dict

from ray.air.checkpoint import Checkpoint
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import Deprecated, DeveloperAPI
from ray.util.annotations import DeveloperAPI

from ray.widgets import make_table_html_repr

EncodedData = TypeVar("EncodedData")

logger = logging.getLogger(__name__)

# This is used in several places to print a warning.
_encode_decode_deprecation_message = (
"``encode_data`` and ``decode_data`` are deprecated in favor of "
"framework-specific ``ray.air.Checkpoint`` subclasses (reported "
"using ``ray.air.session.report()``) which can implement "
"encoding and decoding logic. In the future, ``encode_data`` and "
"``decode_data`` will throw an exception if overriden."
)


def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]):
return
# Do not print warnings in 2.1 yet.
# TODO(ml-team): Change this once we have full API parity with framework
# checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning
# warnings.warn(
# f"You have reported a checkpoint with the `{Checkpoint}` "
# "type, but the intended checkpoint type for the Trainer "
# f"you are using is `{expected_checkpoint_cls}`. "
# "Not using the intended checkpoint type may cause "
# "exceptions or other issues, especially during "
# "serialization and deserialization. The checkpoint "
# "type will be changed automatically. "
# "This behavior may change in the future."
# )


@DeveloperAPI
class BackendConfig:
@@ -73,37 +46,6 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
"""Logic for shutting down the backend."""
pass

@classmethod
def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``encode_data`` is deprecated."""
if cls.encode_data != Backend.encode_data:
warnings.warn(
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
)
# We wrap the return of encode_data in dict in case it is
# not a dict itself.
checkpoint = checkpoint.from_dict(
{"encoded_data": cls.encode_data(checkpoint.to_dict())}
)
return checkpoint

@classmethod
def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``decode_data`` is deprecated."""
if cls.decode_data != Backend.decode_data:
warnings.warn(
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
)
checkpoint_dict = checkpoint.to_dict()
# If "encoded_data" is not in the dict, then the data was
# not encoded, but the user may want to just do decoding
# anyway.
checkpoint = checkpoint.from_dict(
cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict))
)
return checkpoint

@Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def encode_data(data_dict: Dict) -> EncodedData:
"""Logic to encode a data dict before sending to the driver.
@@ -114,7 +56,6 @@ def encode_data(data_dict: Dict) -> EncodedData:

return data_dict

@Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def decode_data(encoded_data: EncodedData) -> Dict:
"""Logic to decode an encoded data dict.
6 changes: 1 addition & 5 deletions python/ray/train/data_parallel_trainer.py
@@ -8,7 +8,6 @@
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
@@ -42,10 +41,7 @@ def __init__(
)

def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint):
if isinstance(checkpoint.dir_or_data, dict):
checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor
else:
save_preprocessor_to_dir(self.preprocessor, checkpoint.dir_or_data)
checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor
super(_DataParallelCheckpointManager, self)._process_persistent_checkpoint(
checkpoint=checkpoint
)
9 changes: 5 additions & 4 deletions python/ray/train/examples/horovod/horovod_pytorch_example.py
@@ -9,9 +9,9 @@
from torchvision import datasets, transforms

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.horovod import HorovodTrainer
from ray.train.torch.torch_checkpoint import TorchCheckpoint
import ray.train.torch


@@ -152,11 +152,12 @@ def train_func(config):
model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
)
if save_model_as_dict:
checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
checkpoint_dict = dict(model=model.state_dict())
else:
checkpoint = TorchCheckpoint.from_model(model)
checkpoint_dict = dict(model=model)
checkpoint_dict = Checkpoint.from_dict(checkpoint_dict)
results.append(loss)
session.report(dict(loss=loss), checkpoint=checkpoint)
session.report(dict(loss=loss), checkpoint=checkpoint_dict)

# Only used for testing.
return results
11 changes: 6 additions & 5 deletions python/ray/train/examples/pytorch/torch_linear_example.py
@@ -5,7 +5,7 @@
import torch
import torch.nn as nn
import ray.train as train
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig


@@ -48,7 +48,8 @@ def validate_epoch(dataloader, model, loss_fn):
import copy

model_copy = copy.deepcopy(model)
return model_copy.cpu().state_dict(), loss
result = {"model": model_copy.cpu().state_dict(), "loss": loss}
return result


def train_func(config):
@@ -75,12 +76,12 @@ def train_func(config):
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

results = []

for _ in range(epochs):
train_epoch(train_loader, model, loss_fn, optimizer)
state_dict, loss = validate_epoch(validation_loader, model, loss_fn)
result = dict(loss=loss)
result = validate_epoch(validation_loader, model, loss_fn)
results.append(result)
session.report(result, checkpoint=TorchCheckpoint.from_state_dict(state_dict))
session.report(result)

return results

