(re)enable torch.compile in the pytorch trainer for train, predict, and eval #18569

Merged 1 commit on Oct 15, 2023
6 changes: 6 additions & 0 deletions keras/backend/common/global_state.py
@@ -77,3 +77,9 @@ def clear_session():
from tensorflow.python.eager import context

context.context().clear_kernel_cache()
elif backend.backend() == "torch":
import torch._dynamo as dynamo

# Resets torchdynamo's cache so that cached guards, compiled
# functions, etc. do not persist between clear_session() calls.
dynamo.reset()
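
Below is a minimal sketch, assuming the torch backend and an environment where torch.compile works, of the behavior this enables; only dynamo.reset() mirrors the change above, the rest is illustrative:

import torch
import torch._dynamo as dynamo

@torch.compile
def double(x):
    return x * 2

double(torch.ones(3))  # first call triggers a dynamo compilation and caches guards
dynamo.reset()         # what clear_session() now does: drop cached guards and compiled functions
double(torch.ones(3))  # the next call recompiles from scratch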
11 changes: 10 additions & 1 deletion keras/backend/torch/core.py
@@ -149,6 +149,7 @@ def convert_to_tensor(x, dtype=None, sparse=False):
return torch.as_tensor(x, dtype=torch.int32, device=get_device())
if isinstance(x, float):
return torch.as_tensor(x, dtype=torch.float32, device=get_device())

# Convert to np in case of any array-like that is not list or tuple.
if not isinstance(x, (list, tuple)):
x = np.array(x)
@@ -180,7 +181,15 @@ def transform(x):


def is_tensor(x):
return torch.is_tensor(x)
# Using the built-in `isinstance` is recommended by PyTorch
# over `torch.is_tensor`; see:
# https://pytorch.org/docs/stable/generated/torch.is_tensor.html
#
# Also, `torch.is_tensor()` causes issues with dynamo caching when
# a torch.Tensor and a numpy.ndarray of the same size, shape, and dtype
# are passed: if it is called on a Tensor first, a subsequent call with
# an ndarray will also return `True`, and vice-versa.
return isinstance(x, torch.Tensor)
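
A small sketch, assuming torch and numpy are installed, of the guarantee the isinstance-based check provides: tensors and ndarrays are classified correctly regardless of which is checked first.

import numpy as np
import torch

t = torch.zeros((2, 3), dtype=torch.float32)
a = np.zeros((2, 3), dtype=np.float32)

# Order of the checks does not matter with isinstance: the ndarray is
# never misclassified as a Tensor, and vice-versa.
assert isinstance(t, torch.Tensor)
assert not isinstance(a, torch.Tensor)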


def shape(x):
19 changes: 18 additions & 1 deletion keras/backend/torch/numpy.py
@@ -63,15 +63,32 @@ def mean(x, axis=None, keepdims=False):
if axis == () or axis == []:
# Torch handles the empty axis case differently from numpy.
return x
elif isinstance(axis, int):
axis = (axis,) # see [NB] below

ori_dtype = standardize_dtype(x.dtype)
# torch.mean only supports floating point inputs
compute_dtype = dtypes.result_type(x.dtype, "float32")
if "int" in ori_dtype or ori_dtype == "bool":
result_dtype = compute_dtype
else:
result_dtype = ori_dtype

# [NB] the python torch op torch.mean() is generated into
# `torch._C._VariableFunctions.pyi`, and the method
# signature is overloaded.
# Dynamo won't actually find the correct signature of
# `torch.mean()` if arguments are passed via kwargs,
# so we have to pass the arguments via positional args,
# EXCEPT for those that are forced as kwargs via the `*`
# delimiter in the overloaded method signatures.
# Additionally, we have to create a singleton tuple
# when `axis` is an int to match the existing fn signature.
result = torch.mean(
x, axis=axis, keepdims=keepdims, dtype=to_torch_dtype(compute_dtype)
x,
axis,
keepdims,
dtype=to_torch_dtype(compute_dtype),
)
return cast(result, result_dtype)
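
A short sketch of the resulting call pattern in plain torch, with made-up example values: dim and keepdim are passed positionally, the keyword-only dtype stays a kwarg, and an int axis is wrapped in a singleton tuple.

import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
axis = 1
if isinstance(axis, int):
    axis = (axis,)  # singleton tuple so the overloaded signature matches
out = torch.mean(x, axis, False, dtype=torch.float32)
print(out)  # tensor([1., 4.])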

3 changes: 3 additions & 0 deletions keras/backend/torch/random.py
@@ -10,6 +10,9 @@
from keras.random.seed_generator import make_default_seed


# torch.Generator not supported with dynamo
# see: https://github.com/pytorch/pytorch/issues/88576
@torch.compiler.disable()
def torch_seed_generator(seed):
first_seed, second_seed = draw_seed(seed)
device = get_device()
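
A sketch, assuming torch >= 2.1, of what the decorator does; the helper names below are illustrative, only the @torch.compiler.disable() usage mirrors the change above.

import torch

@torch.compiler.disable()
def seeded_normal(shape, seed):
    # Runs eagerly even when called from compiled code, since
    # torch.Generator is not supported by dynamo.
    g = torch.Generator()
    g.manual_seed(seed)
    return torch.randn(shape, generator=g)

@torch.compile
def add_noise(x):
    # dynamo graph-breaks around the disabled call and stitches the
    # eager result back into the compiled function.
    return x + seeded_normal(x.shape, seed=1337)

print(add_noise(torch.zeros(4)))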
17 changes: 9 additions & 8 deletions keras/backend/torch/trainer.py
@@ -102,12 +102,7 @@ def one_step_on_data(data):
return self.train_step(data)

if self.jit_compile:
raise ValueError(
"`jit_compile` is not yet enabled for the PyTorch backend."
)
# Temporarily disabled torch compile due to failed unit tests.
# TODO: Uncomment the following line when unit tests passes.
# self.train_function = torch.compile(one_step_on_data)
self.train_function = torch.compile(one_step_on_data)
else:
self.train_function = one_step_on_data

@@ -127,7 +122,10 @@ def one_step_on_data(data):
with torch.no_grad():
return self.test_step(data)

self.test_function = one_step_on_data
if self.jit_compile:
self.test_function = torch.compile(one_step_on_data)
else:
self.test_function = one_step_on_data

def make_predict_function(self, force=False):
if self.predict_function is not None and not force:
@@ -145,7 +143,10 @@ def one_step_on_data(data):
with torch.no_grad():
return self.predict_step(data)

self.predict_function = one_step_on_data
if self.jit_compile:
self.predict_function = torch.compile(one_step_on_data)
else:
self.predict_function = one_step_on_data

def _symbolic_build(self, data_batch):
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
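
A condensed sketch of the pattern now shared by make_train_function, make_test_function, and make_predict_function; the helper below is illustrative rather than Keras API, and assumes an environment where torch.compile works.

import torch

def make_step_function(step_fn, jit_compile):
    def one_step_on_data(data):
        with torch.no_grad():
            return step_fn(data)
    # Wrap with torch.compile only when jit_compile is enabled;
    # otherwise fall back to plain eager execution.
    return torch.compile(one_step_on_data) if jit_compile else one_step_on_data

step = make_step_function(lambda data: data.sum(), jit_compile=True)
print(step(torch.ones(8)))  # tensor(8.)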
7 changes: 6 additions & 1 deletion keras/layers/reshaping/flatten.py
@@ -57,7 +57,12 @@ def compute_output_shape(self, input_shape):
non_batch_dims = input_shape[1:]
if len(non_batch_dims) == 0:
flattened_dim = 1
elif None in non_batch_dims:
elif any(d is None for d in non_batch_dims):
# NB: we cannot use the shorter `None in non_batch_dims` here because
# torchdynamo errors when calling the `__contains__` op with
# a constant (in this case `None`) operand, since it assumes
# that the elements in the collection are also `ConstantVariable`s,
# but tensor shapes can be `SymNodeVariable`s (e.g. `SymInt`).
flattened_dim = None
else:
flattened_dim = math.prod(non_batch_dims)
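
A tiny sketch, using plain Python values, of why the two checks agree for ordinary shape tuples; the generator form only sidesteps the __contains__ call that dynamo rejects for symbolic shapes.

non_batch_dims = (32, None, 3)
assert any(d is None for d in non_batch_dims) == (None in non_batch_dims)

non_batch_dims = (32, 16, 3)
assert any(d is None for d in non_batch_dims) == (None in non_batch_dims)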
17 changes: 11 additions & 6 deletions keras/layers/reshaping/up_sampling2d.py
@@ -1,5 +1,3 @@
import numpy as np

from keras import backend
from keras import ops
from keras.api_export import keras_export
@@ -149,10 +147,6 @@ def _resize_images(
else:
raise ValueError(f"Invalid `data_format` argument: {data_format}")

new_shape = x.shape[rows : cols + 1]
new_shape *= np.array([height_factor, width_factor])
new_shape = new_shape.tolist()

if data_format == "channels_first":
x = ops.transpose(x, [0, 2, 3, 1])
# https://github.com/keras-team/keras/issues/294
@@ -161,6 +155,17 @@ def _resize_images(
x = ops.repeat(x, height_factor, axis=1)
x = ops.repeat(x, width_factor, axis=2)
else:
# Multiply the height and width factors on each dim
# by hand (versus using element-wise multiplication
# by np.array([height_factor, width_factor]) and then
# list-ifying the tensor by calling `.tolist()`),
# since when running under torchdynamo, `new_shape`
# will be traced as a symbolic variable (specifically
# a `FakeTensor`) which does not have a `tolist()` method.
new_shape = (
x.shape[rows] * height_factor,
x.shape[cols] * width_factor,
)
x = ops.image.resize(x, new_shape, interpolation=interpolation)
if data_format == "channels_first":
x = ops.transpose(x, [0, 3, 1, 2])
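
A small sketch of the per-dimension arithmetic with made-up example values; it keeps new_shape as a plain tuple instead of a numpy array that would need .tolist().

height_factor, width_factor = 2, 3
rows, cols = 1, 2                # channels_last layout assumed for this sketch
x_shape = (1, 4, 5, 8)           # (batch, rows, cols, channels)

new_shape = (
    x_shape[rows] * height_factor,
    x_shape[cols] * width_factor,
)
print(new_shape)  # (8, 15)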
21 changes: 20 additions & 1 deletion keras/testing/test_case.py
@@ -11,6 +11,7 @@
from keras import utils
from keras.backend.common import is_float_dtype
from keras.backend.common import standardize_dtype
from keras.backend.common.global_state import clear_session
from keras.backend.common.keras_tensor import KerasTensor
from keras.models import Model
from keras.utils import traceback_utils
@@ -24,6 +25,12 @@ def __init__(self, *args, **kwargs):
if traceback_utils.is_traceback_filtering_enabled():
traceback_utils.disable_traceback_filtering()

def setUp(self):
# Clear global state so that test cases are independent.
# This is required for the jit-enabled torch tests since dynamo has
# a global cache for guards, compiled functions, etc.
clear_session()

def get_temp_dir(self):
temp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(temp_dir))
@@ -329,7 +336,19 @@ def call(self, x):
output_data = tree.map_structure(
lambda x: backend.convert_to_numpy(x), output_data
)
model.compile(optimizer="sgd", loss="mse", jit_compile=True)
# test the "default" path for each backend by setting
# jit_compile="auto.
# for tensorflow and jax backends auto is jitted
# for torch backend auto is eager
#
# NB: for torch, jit_compile=True turns on torchdynamo
# which may not always succeed in tracing depending
# on the model. Run your program with these env vars
# to get debug traces of dynamo:
# TORCH_LOGS="+dynamo"
# TORCHDYNAMO_VERBOSE=1
# TORCHDYNAMO_REPORT_GUARD_FAILURES=1
model.compile(optimizer="sgd", loss="mse", jit_compile="auto")
model.fit(input_data, output_data, verbose=0)

# Build test.
Expand Down
2 changes: 1 addition & 1 deletion keras/trainers/epoch_iterator.py
@@ -122,7 +122,7 @@ def enumerate_epoch(self, return_type="auto"):
if buffer:
yield step - len(buffer) + 1, buffer
if not self._num_batches:
# Infer the number of batches returned by the data_adater.
# Infer the number of batches returned by the data_adapter.
# Assumed static.
self._num_batches = step + 1
self.data_adapter.on_epoch_end()
31 changes: 16 additions & 15 deletions keras/trainers/trainer.py
@@ -104,7 +104,7 @@ def compile(
and to set it to `True` when debugging.
steps_per_execution: Int. The number of batches to run
during each a single compiled function call. Running multiple
batches inside a single a single compiled function call can
batches inside a single compiled function call can
greatly improve performance on TPUs or small models with a large
Python overhead. At most, one full epoch will be run each
execution. If a number larger than the size of the epoch is
@@ -115,9 +115,12 @@
each compiled function execution).
Not supported with the PyTorch backend.
jit_compile: Bool or `"auto"`. Whether to use XLA compilation when
compiling a model. Not supported with the PyTorch backend.
If `"auto"`, XLA compilation will be enabled if the
the model supports it, and disabled otherwise.
compiling a model. For the `jax` and `tensorflow` backends,
`jit_compile="auto"` enables XLA compilation if the model
supports it, and disables it otherwise.
For the `torch` backend, `"auto"` will default to eager
execution and `jit_compile=True` will run with `torch.compile`
using the `"inductor"` backend.
auto_scale_loss: Bool. If `True` and the model dtype policy is
`"mixed_float16"`, the passed optimizer will be automatically
wrapped in a `LossScaleOptimizer`, which will dynamically
@@ -162,12 +165,7 @@ def compile(
"cannot also be True. Disabling `jit_compile`.",
stacklevel=2,
)
if jit_compile and backend.backend() == "torch":
warnings.warn(
"`jit_compile` is not yet enabled for the PyTorch backend. "
"Proceeding with `jit_compile=False`."
)
jit_compile = False

self.jit_compile = jit_compile
self.run_eagerly = run_eagerly
self.stop_training = False
Expand All @@ -194,7 +192,10 @@ def compile(
def jit_compile(self):
if self._jit_compile is None:
# Value was never set. Resolve it now.
jit_compile = model_supports_jit(self)
# torch always runs in eager unless jit_compile is explicitly set
jit_compile = (
model_supports_jit(self) and backend.backend() != "torch"
)
self._jit_compile = jit_compile
return self._jit_compile

@@ -866,11 +867,11 @@ def _assert_compile_called(self, method_name=None):


def resolve_auto_jit_compile(model):
if backend.backend() == "torch":
# jit_compile = "auto" with the pytorch backend defaults to eager
return False

if model_supports_jit(model):
if backend.backend() == "torch":
# Torch defaults to eager mode
# until torch compile is reliable
return False
return True
return False

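
A standalone sketch of the resolution rule implemented above; the function below uses illustrative names rather than the Keras API: "auto" resolves to eager on the torch backend and to jit elsewhere when the model supports it.

def resolve_auto_jit(backend_name, model_supports_jit):
    if backend_name == "torch":
        # jit_compile="auto" with the pytorch backend defaults to eager.
        return False
    return model_supports_jit

assert resolve_auto_jit("torch", True) is False
assert resolve_auto_jit("tensorflow", True) is True
assert resolve_auto_jit("jax", False) is False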
20 changes: 20 additions & 0 deletions keras/trainers/trainer_test.py
@@ -148,6 +148,18 @@ def __init__(self, units):
)
self.assertEqual(len(model_weighted.metrics), 3)

@pytest.mark.skipif(
backend.backend() != "torch",
reason="torch backend runs in eager mode for jit_compile='auto'",
)
def test_compile_eager_vs_jit_torch(self):
model = ExampleModel(units=3)
model.compile(jit_compile="auto")
# The torch trainer enables/disables torch.compile based only on
# the value of model.jit_compile (not model.run_eagerly).
self.assertFalse(model.run_eagerly)
self.assertFalse(model.jit_compile)

@parameterized.named_parameters(
[
("eager", True, False, False),
@@ -292,6 +304,14 @@ def test_predict_flow(self, run_eagerly, jit_compile):
outputs = model.predict(x, batch_size=batch_size)
self.assertAllClose(outputs, 4 * np.ones((100, 3)))

@parameterized.named_parameters(
[
("eager", True, False),
("graph_fn", False, False),
("jit", False, True),
]
)
def test_predict_flow_struct(self, run_eagerly, jit_compile):
# Test with input/output structs
model = StructModel(units=3)
model.run_eagerly = run_eagerly
8 changes: 7 additions & 1 deletion keras/utils/naming_test.py
@@ -61,7 +61,13 @@ def test_uniquify_already_uniquified_name(self):
name = "unique_name"
unique_name = naming.uniquify(name)
new_unique_name = naming.uniquify(unique_name)
self.assertEqual(new_unique_name, unique_name)

# The first time `name` is uniquified, the same name is returned.
self.assertEqual(name, unique_name)

# The second time `name` is uniquified, the result should differ
# from the first output.
self.assertNotEqual(new_unique_name, unique_name)

def test_to_snake_case_capital_after_any_character(self):
name = "myVariableNameHere"
4 changes: 2 additions & 2 deletions requirements.txt
@@ -2,8 +2,8 @@
tf-nightly==2.15.0.dev20231009 # Pin a working nightly until rc0.

# Torch.
torch>=2.0.1
torchvision>=0.15.1
torch>=2.1.0
Collaborator:
@grasskin - FYI. I remember Gabriel wanting to keep the requirements at torch 2.0.1, so I wanted him to take a look or be in the loop.

Contributor Author:
Thanks @sampathweb, @grasskin. Let me know if we have a good reason to stay at 2.0.1. I'd like to update to 2.1 if possible since it has a bunch of fixes (especially to torch.compile).

torchvision>=0.16.0

# Jax.
jax[cpu]