diff --git a/keras/backend/common/global_state.py b/keras/backend/common/global_state.py
index 9ebcacd51a1..8d72538ebdc 100644
--- a/keras/backend/common/global_state.py
+++ b/keras/backend/common/global_state.py
@@ -77,3 +77,9 @@ def clear_session():
         from tensorflow.python.eager import context

         context.context().clear_kernel_cache()
+    elif backend.backend() == "torch":
+        import torch._dynamo as dynamo
+
+        # Resets torchdynamo's cache so that cached guards, compiled
+        # functions, etc. do not persist between clear_session() calls.
+        dynamo.reset()
diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py
index 70ba7cda7b3..af26646561d 100644
--- a/keras/backend/torch/core.py
+++ b/keras/backend/torch/core.py
@@ -149,6 +149,7 @@ def convert_to_tensor(x, dtype=None, sparse=False):
         return torch.as_tensor(x, dtype=torch.int32, device=get_device())
     if isinstance(x, float):
         return torch.as_tensor(x, dtype=torch.float32, device=get_device())
+    # Convert to np in case of any array-like that is not a list or tuple.
     if not isinstance(x, (list, tuple)):
         x = np.array(x)
@@ -180,7 +181,15 @@ def transform(x):


 def is_tensor(x):
-    return torch.is_tensor(x)
+    # Using the built-in `isinstance` is recommended by pytorch over
+    # using torch.is_tensor; see:
+    # https://pytorch.org/docs/stable/generated/torch.is_tensor.html
+    #
+    # Also, `torch.is_tensor()` causes issues with dynamo caching when
+    # a torch.Tensor and a numpy.ndarray of the same size, shape, and
+    # dtype are passed: if it is called on a Tensor first, the second
+    # call with an ndarray will return `True`, and vice-versa.
+    return isinstance(x, torch.Tensor)


 def shape(x):
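A minimal sketch (not part of the patch) of the type distinction the `isinstance` check preserves; the arrays here are illustrative:

```python
import numpy as np
import torch


def is_tensor(x):
    # isinstance() keeps Tensor and ndarray distinct under dynamo, whereas
    # torch.is_tensor() can be folded into a stale cached guard.
    return isinstance(x, torch.Tensor)


assert is_tensor(torch.zeros(2, 3))
assert not is_tensor(np.zeros((2, 3)))  # same shape/dtype, still not a Tensor
```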
diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py
index 17a2bf35c5a..a2290cf50c5 100644
--- a/keras/backend/torch/numpy.py
+++ b/keras/backend/torch/numpy.py
@@ -63,6 +63,9 @@ def mean(x, axis=None, keepdims=False):
     if axis == () or axis == []:
         # Torch handles the empty axis case differently from numpy.
         return x
+    elif isinstance(axis, int):
+        axis = (axis,)  # see [NB] below
+
     ori_dtype = standardize_dtype(x.dtype)
     # torch.mean only supports floating point inputs
     compute_dtype = dtypes.result_type(x.dtype, "float32")
@@ -70,8 +73,22 @@ def mean(x, axis=None, keepdims=False):
         result_dtype = compute_dtype
     else:
         result_dtype = ori_dtype
+
+    # [NB] The python torch op torch.mean() is generated into
+    # `torch._C._VariableFunctions.pyi`, and the method
+    # signature is overloaded.
+    # Dynamo won't actually find the correct signature of
+    # `torch.mean()` if arguments are passed via kwargs,
+    # so we have to pass the arguments via positional args
+    # EXCEPT for those that are forced to be kwargs via the `*`
+    # delimiter in the overloaded method signatures.
+    # Additionally, we have to create a singleton tuple
+    # when `axis` is an int to match the existing fn signature.
     result = torch.mean(
-        x, axis=axis, keepdims=keepdims, dtype=to_torch_dtype(compute_dtype)
+        x,
+        axis,
+        keepdims,
+        dtype=to_torch_dtype(compute_dtype),
    )
     return cast(result, result_dtype)
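A quick sketch, with made-up values, of the calling convention the [NB] comment describes: `dim` and `keepdim` passed positionally, `dtype` kept as a keyword argument:

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

axis = 1
if isinstance(axis, int):
    axis = (axis,)  # normalize an int axis to the tuple overload

# `dim` and `keepdim` go positionally; `dtype` sits after the `*`
# delimiter in the overload, so it must stay a kwarg.
result = torch.mean(x, axis, False, dtype=torch.float32)
print(result)  # tensor([1., 4.])
```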
diff --git a/keras/backend/torch/random.py b/keras/backend/torch/random.py
index b4365c02dac..f85732ca37a 100644
--- a/keras/backend/torch/random.py
+++ b/keras/backend/torch/random.py
@@ -10,6 +10,9 @@
 from keras.random.seed_generator import make_default_seed


+# torch.Generator is not supported with dynamo;
+# see: https://github.com/pytorch/pytorch/issues/88576
+@torch.compiler.disable()
 def torch_seed_generator(seed):
     first_seed, second_seed = draw_seed(seed)
     device = get_device()
diff --git a/keras/backend/torch/trainer.py b/keras/backend/torch/trainer.py
index 34f0bd782d7..56f7eb6cf81 100644
--- a/keras/backend/torch/trainer.py
+++ b/keras/backend/torch/trainer.py
@@ -102,12 +102,7 @@ def one_step_on_data(data):
             return self.train_step(data)

         if self.jit_compile:
-            raise ValueError(
-                "`jit_compile` is not yet enabled for the PyTorch backend."
-            )
-            # Temporarily disabled torch compile due to failed unit tests.
-            # TODO: Uncomment the following line when unit tests passes.
-            # self.train_function = torch.compile(one_step_on_data)
+            self.train_function = torch.compile(one_step_on_data)
         else:
             self.train_function = one_step_on_data

@@ -127,7 +122,10 @@ def one_step_on_data(data):
             with torch.no_grad():
                 return self.test_step(data)

-        self.test_function = one_step_on_data
+        if self.jit_compile:
+            self.test_function = torch.compile(one_step_on_data)
+        else:
+            self.test_function = one_step_on_data

     def make_predict_function(self, force=False):
         if self.predict_function is not None and not force:
@@ -145,7 +143,10 @@ def one_step_on_data(data):
             with torch.no_grad():
                 return self.predict_step(data)

-        self.predict_function = one_step_on_data
+        if self.jit_compile:
+            self.predict_function = torch.compile(one_step_on_data)
+        else:
+            self.predict_function = one_step_on_data

     def _symbolic_build(self, data_batch):
         model_unbuilt = not all(layer.built for layer in self._flatten_layers())
diff --git a/keras/layers/reshaping/flatten.py b/keras/layers/reshaping/flatten.py
index da72ce961dd..0923f33da4b 100644
--- a/keras/layers/reshaping/flatten.py
+++ b/keras/layers/reshaping/flatten.py
@@ -57,7 +57,12 @@ def compute_output_shape(self, input_shape):
         non_batch_dims = input_shape[1:]
         if len(non_batch_dims) == 0:
             flattened_dim = 1
-        elif None in non_batch_dims:
+        elif any(d is None for d in non_batch_dims):
+            # NB: we cannot use the shorter `None in non_batch_dims` here
+            # b/c torchdynamo errors when calling `__contains__` with a
+            # constant operand (in this case `None`), since it assumes the
+            # elements in the collection are also `ConstantVariable`s, but
+            # tensor shapes can be `SymNodeVariable`s (e.g. `SymInt`).
             flattened_dim = None
         else:
             flattened_dim = math.prod(non_batch_dims)
diff --git a/keras/layers/reshaping/up_sampling2d.py b/keras/layers/reshaping/up_sampling2d.py
index 0e593de09e7..969839f9ab7 100644
--- a/keras/layers/reshaping/up_sampling2d.py
+++ b/keras/layers/reshaping/up_sampling2d.py
@@ -1,5 +1,3 @@
-import numpy as np
-
 from keras import backend
 from keras import ops
 from keras.api_export import keras_export
@@ -149,10 +147,6 @@ def _resize_images(
     else:
         raise ValueError(f"Invalid `data_format` argument: {data_format}")

-    new_shape = x.shape[rows : cols + 1]
-    new_shape *= np.array([height_factor, width_factor])
-    new_shape = new_shape.tolist()
-
     if data_format == "channels_first":
         x = ops.transpose(x, [0, 2, 3, 1])
     # https://github.com/keras-team/keras/issues/294
@@ -161,6 +155,17 @@ def _resize_images(
         x = ops.repeat(x, height_factor, axis=1)
         x = ops.repeat(x, width_factor, axis=2)
     else:
+        # Multiply the height and width factors on each dim by hand
+        # (versus using element-wise multiplication by
+        # np.array([height_factor, width_factor]) and then list-ifying
+        # the tensor by calling `.tolist()`), since when running under
+        # torchdynamo, `new_shape` would be traced as a symbolic
+        # variable (specifically a `FakeTensor`), which does not have
+        # a `tolist()` method.
+        new_shape = (
+            x.shape[rows] * height_factor,
+            x.shape[cols] * width_factor,
+        )
         x = ops.image.resize(x, new_shape, interpolation=interpolation)
     if data_format == "channels_first":
         x = ops.transpose(x, [0, 3, 1, 2])
diff --git a/keras/testing/test_case.py b/keras/testing/test_case.py
index 91684f6dab7..b4c738819ba 100644
--- a/keras/testing/test_case.py
+++ b/keras/testing/test_case.py
@@ -11,6 +11,7 @@
 from keras import utils
 from keras.backend.common import is_float_dtype
 from keras.backend.common import standardize_dtype
+from keras.backend.common.global_state import clear_session
 from keras.backend.common.keras_tensor import KerasTensor
 from keras.models import Model
 from keras.utils import traceback_utils
@@ -24,6 +25,12 @@ def __init__(self, *args, **kwargs):
         if traceback_utils.is_traceback_filtering_enabled():
             traceback_utils.disable_traceback_filtering()

+    def setUp(self):
+        # Clear the global state so that test cases are independent.
+        # Required for the jit-enabled torch tests since dynamo has a
+        # global cache for guards, compiled functions, etc.
+        clear_session()
+
     def get_temp_dir(self):
         temp_dir = tempfile.mkdtemp()
         self.addCleanup(lambda: shutil.rmtree(temp_dir))
@@ -329,7 +336,19 @@ def call(self, x):
             output_data = tree.map_structure(
                 lambda x: backend.convert_to_numpy(x), output_data
             )
-            model.compile(optimizer="sgd", loss="mse", jit_compile=True)
+            # Test the "default" path for each backend by setting
+            # jit_compile="auto":
+            # for the tensorflow and jax backends, "auto" is jitted;
+            # for the torch backend, "auto" is eager.
+            #
+            # NB: for torch, jit_compile=True turns on torchdynamo,
+            # which may not always succeed in tracing depending
+            # on the model. Run your program with these env vars
+            # to get debug traces of dynamo:
+            # TORCH_LOGS="+dynamo"
+            # TORCHDYNAMO_VERBOSE=1
+            # TORCHDYNAMO_REPORT_GUARD_FAILURES=1
+            model.compile(optimizer="sgd", loss="mse", jit_compile="auto")
             model.fit(input_data, output_data, verbose=0)

             # Build test.
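A simplified sketch of the jit path the torch trainer now takes when `jit_compile=True`; the step function and data below are stand-ins, not the trainer's real `train_step`:

```python
import torch


def one_step_on_data(data):
    # stand-in for train_step: a bare MSE, no optimizer or metrics
    x, y = data
    return ((x - y) ** 2).mean()


# what make_train_function() produces when jit_compile=True
train_function = torch.compile(one_step_on_data)
loss = train_function((torch.randn(4, 3), torch.randn(4, 3)))
print(float(loss))
```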
diff --git a/keras/trainers/epoch_iterator.py b/keras/trainers/epoch_iterator.py
index 4e1a776d3dc..b5d7606c077 100644
--- a/keras/trainers/epoch_iterator.py
+++ b/keras/trainers/epoch_iterator.py
@@ -122,7 +122,7 @@ def enumerate_epoch(self, return_type="auto"):
             if buffer:
                 yield step - len(buffer) + 1, buffer
         if not self._num_batches:
-            # Infer the number of batches returned by the data_adater.
+            # Infer the number of batches returned by the data_adapter.
             # Assumed static.
             self._num_batches = step + 1
         self.data_adapter.on_epoch_end()
diff --git a/keras/trainers/trainer.py b/keras/trainers/trainer.py
index 8606b4cbf5d..b70a3a31bc4 100644
--- a/keras/trainers/trainer.py
+++ b/keras/trainers/trainer.py
@@ -104,7 +104,7 @@ def compile(
                 and to set it to `True` when debugging.
             steps_per_execution: Int. The number of batches to run
                 during each compiled function call. Running multiple
-                batches inside a single a single compiled function call can
+                batches inside a single compiled function call can
                 greatly improve performance on TPUs or small models with a
                 large Python overhead. At most, one full epoch will be run
                 each execution. If a number larger than the size of the
@@ -115,9 +115,12 @@ def compile(
                 each compiled function execution).
                 Not supported with the PyTorch backend.
             jit_compile: Bool or `"auto"`. Whether to use XLA compilation when
-                compiling a model. Not supported with the PyTorch backend.
-                If `"auto"`, XLA compilation will be enabled if the
-                the model supports it, and disabled otherwise.
+                compiling a model. For the `jax` and `tensorflow` backends,
+                `jit_compile="auto"` enables XLA compilation if the model
+                supports it, and disables it otherwise.
+                For the `torch` backend, `"auto"` defaults to eager
+                execution and `jit_compile=True` runs with `torch.compile`
+                using the `"inductor"` backend.
             auto_scale_loss: Bool. If `True` and the model dtype policy is
                 `"mixed_float16"`, the passed optimizer will be automatically
                 wrapped in a `LossScaleOptimizer`, which will dynamically
@@ -162,12 +165,7 @@ def compile(
                 "cannot also be True. Disabling `jit_compile`.",
                 stacklevel=2,
             )
-        if jit_compile and backend.backend() == "torch":
-            warnings.warn(
-                "`jit_compile` is not yet enabled for the PyTorch backend. "
-                "Proceeding with `jit_compile=False`."
-            )
-            jit_compile = False
+
         self.jit_compile = jit_compile
         self.run_eagerly = run_eagerly
         self.stop_training = False
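A hedged end-to-end usage sketch of the resulting behavior; the toy model and data are illustrative, not from the patch:

```python
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])

# On the torch backend, jit_compile=True now wraps the step functions in
# torch.compile, while jit_compile="auto" (the default) stays eager there
# but still enables XLA on the tensorflow and jax backends.
model.compile(optimizer="sgd", loss="mse", jit_compile=True)
model.fit(np.ones((8, 4)), np.zeros((8, 1)), verbose=0)
```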
@@ -194,7 +192,10 @@ def compile(
     def jit_compile(self):
         if self._jit_compile is None:
             # Value was never set. Resolve it now.
-            jit_compile = model_supports_jit(self)
+            # torch always runs eagerly unless jit_compile is explicitly set.
+            jit_compile = (
+                model_supports_jit(self) and backend.backend() != "torch"
+            )
             self._jit_compile = jit_compile
         return self._jit_compile

@@ -866,11 +867,11 @@ def _assert_compile_called(self, method_name=None):


 def resolve_auto_jit_compile(model):
+    if backend.backend() == "torch":
+        # jit_compile = "auto" with the pytorch backend defaults to eager.
+        return False
+
     if model_supports_jit(model):
-        if backend.backend() == "torch":
-            # Torch defaults to eager mode
-            # until torch compile is reliable
-            return False
         return True
     return False
diff --git a/keras/trainers/trainer_test.py b/keras/trainers/trainer_test.py
index 80f108b3ae9..ee96a0e5f8c 100644
--- a/keras/trainers/trainer_test.py
+++ b/keras/trainers/trainer_test.py
@@ -148,6 +148,18 @@ def __init__(self, units):
         )
         self.assertEqual(len(model_weighted.metrics), 3)

+    @pytest.mark.skipif(
+        backend.backend() != "torch",
+        reason="torch backend runs in eager mode for jit_compile='auto'",
+    )
+    def test_compile_eager_vs_jit_torch(self):
+        model = ExampleModel(units=3)
+        model.compile(jit_compile="auto")
+        # The torch trainer enables/disables torch.compile based only on
+        # the value of model.jit_compile (not model.run_eagerly).
+        self.assertFalse(model.run_eagerly)
+        self.assertFalse(model.jit_compile)
+
     @parameterized.named_parameters(
         [
             ("eager", True, False, False),
@@ -292,6 +304,14 @@ def test_predict_flow(self, run_eagerly, jit_compile):
         outputs = model.predict(x, batch_size=batch_size)
         self.assertAllClose(outputs, 4 * np.ones((100, 3)))

+    @parameterized.named_parameters(
+        [
+            ("eager", True, False),
+            ("graph_fn", False, False),
+            ("jit", False, True),
+        ]
+    )
+    def test_predict_flow_struct(self, run_eagerly, jit_compile):
         # Test with input/output structs
         model = StructModel(units=3)
         model.run_eagerly = run_eagerly
diff --git a/keras/utils/naming_test.py b/keras/utils/naming_test.py
index c5b0752a191..6be61fdbefe 100644
--- a/keras/utils/naming_test.py
+++ b/keras/utils/naming_test.py
@@ -61,7 +61,13 @@ def test_uniquify_already_uniquified_name(self):
         name = "unique_name"
         unique_name = naming.uniquify(name)
         new_unique_name = naming.uniquify(unique_name)
-        self.assertEqual(new_unique_name, unique_name)
+
+        # The first time `name` is uniquified, the same name is returned.
+        self.assertEqual(name, unique_name)
+
+        # The second time `name` is uniquified, the result should differ
+        # from the first output.
+        self.assertNotEqual(new_unique_name, unique_name)

     def test_to_snake_case_capital_after_any_character(self):
         name = "myVariableNameHere"
diff --git a/requirements.txt b/requirements.txt
index 39b351e3481..33eb38e34d9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,8 @@
 tf-nightly==2.15.0.dev20231009  # Pin a working nightly until rc0.

 # Torch.
-torch>=2.0.1
-torchvision>=0.15.1
+torch>=2.1.0
+torchvision>=0.16.0

 # Jax.
 jax[cpu]
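The version bump appears load-bearing: `torch.compiler.disable()`, used on the seed generator above, was only introduced around torch 2.1, as far as I know. A standalone sketch of that graph-break pattern, with a simplified generator body:

```python
import torch


# torch.Generator can't be traced by dynamo
# (https://github.com/pytorch/pytorch/issues/88576), so the decorated
# function always runs eagerly, even when its caller is compiled.
@torch.compiler.disable()
def torch_seed_generator(seed):
    generator = torch.Generator(device="cpu")
    generator.manual_seed(int(seed))
    return generator
```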