(Re)enable torch.compile in the PyTorch trainer for train, predict, and eval (#18569)
Conversation
Force-pushed from 8212c7d to 22ec236.
Codecov Report: All modified lines are covered by tests ✅

@@            Coverage Diff             @@
##           master    #18569     +/-   ##
==========================================
+ Coverage    78.11%    78.30%   +0.18%
==========================================
  Files          334       334
  Lines        32477     32484       +7
  Branches      6339      6342       +3
==========================================
+ Hits         25371     25438      +67
+ Misses        5539      5482      -57
+ Partials      1567      1564       -3
Thanks for the PR!
Force-pushed from 91ddb48 to e166beb.
This is actually very fixable. Instead of using …
LGTM -- I think doing the tree conversion would likely unlock the performance benefits here.
@@ -2,8 +2,8 @@
 tf-nightly==2.15.0.dev20231009  # Pin a working nightly until rc0.

 # Torch.
-torch>=2.0.1
+torch>=2.1.0
 torchvision>=0.15.1
@grasskin - FYI. I remember Gabriel wanted to keep the requirement at torch 2.0.1, so I wanted him to take a look or be in the loop.
Thanks @sampathweb, @grasskin. Let me know if we have a good reason to stay at 2.0.1. I'd like to update to 2.1 if possible, since it has a bunch of fixes (especially to torch.compile).
Yep, I created an issue for this (#18614). I can do this in a fast-follow PR since this one is getting big and the torch backend defaults to eager right now.
Happy to merge this now since CI is passing and we can do the rest in future PRs. If the updated torch version is an issue, we can revert that part later.
Change Summary
- Update `torch` to `torch>=2.1.0` (which has many improvements to dynamo).
- Enable `torch.compile` when `jit_compile=True` for train, eval, and predict (a usage sketch follows this list).
- Add a note to `model.fit()` explaining that `jit_compile="auto"` defaults to eager for the torch backend (`torch.compile` only kicks in if the user explicitly sets `jit_compile=True`).
- Add a `setUp()` that calls `clear_session()` in `testing.TestCase` (required for dynamo).
- Update `naming_test.py:test_uniquify_already_uniquified_name()`.
- Use `jit_compile="auto"` (versus `jit_compile=True`) in `keras.testing.test_case.TestCase.run_layer_test.run_training_step()` so that the backends are tested in their "default" jitted mode (jit for tf and jax, eager for torch).
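Below is a minimal usage sketch of the behavior described above. It is illustrative only (not code from this PR) and assumes the torch backend is selected, e.g. via `KERAS_BACKEND=torch`.

```python
# Hypothetical usage sketch (not part of this PR). Assumes KERAS_BACKEND=torch.
import numpy as np
import keras

inputs = keras.Input(shape=(32,))
x = keras.layers.Dense(64, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

# jit_compile=True opts in to torch.compile on the torch backend;
# the default jit_compile="auto" keeps the torch backend in eager mode.
model.compile(optimizer="adam", loss="mse", jit_compile=True)

data = np.random.rand(256, 32).astype("float32")
labels = np.random.rand(256, 1).astype("float32")
model.fit(data, labels, epochs=1, batch_size=32)  # train
model.evaluate(data, labels)                      # eval
model.predict(data)                               # predict
```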
Note On Dynamo
Currently there are a few caveats to running the torch backend with `jit_compile=True`:

(performance) It is slower than eager because of too many graph breaks, mainly due to the usage of `tree` in `any_symbolic_tensors()` (dynamo will not trace through `tree`, see skipfiles), which in turn is called by pretty much all ops (e.g. numpy, layer, activation, etc.). Therefore, no "deep graph" can be captured, and hence there are no opportunities for optimizations such as op fusion. This can be fixed by not using `tree.flatten` in `any_symbolic_tensors()`; a rough sketch of that idea follows below.
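A hypothetical sketch of that fix, assuming `KerasTensor` is importable from the top-level `keras` namespace; this is not the actual Keras implementation, just an illustration of replacing `tree.flatten` with plain Python recursion that dynamo can trace.

```python
# Hypothetical sketch of the fix described above; not the actual Keras code.
from keras import KerasTensor

def _contains_symbolic(x):
    # Plain Python recursion over common containers, so dynamo can trace it
    # (dynamo skips tracing through the `tree` library).
    if isinstance(x, KerasTensor):
        return True
    if isinstance(x, (list, tuple)):
        return any(_contains_symbolic(v) for v in x)
    if isinstance(x, dict):
        return any(_contains_symbolic(v) for v in x.values())
    return False

def any_symbolic_tensors(args=None, kwargs=None):
    return _contains_symbolic(list(args or ())) or _contains_symbolic(dict(kwargs or {}))
```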
(overhead) `torch.core.convert_to_tensor` needs to be simplified to just calling `torch.as_tensor(x, dtype, device)` rather than using `x.to(dtype, device)`. This won't make things compile better, but it reduces frame-eval overhead, since `convert_to_tensor` is called for each op and tracing through many branches is less than ideal (a sketch follows after the compatibility note below).

(compatibility) There are cases where primitive operators can be traced by dynamo, but using a sequence of them as a higher-order operator such as a layer (e.g. `up_sampling_2d`) causes guard failures on the primitive ops, which in turn makes dynamo trace with dynamic shapes via symbolic variables rather than concretized values, which can often lead to tracing failures due to "missing methods".
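A hypothetical sketch of the simplification described in the overhead note; an assumption about the shape of the fix, not the actual torch backend code.

```python
# Hypothetical sketch; not the actual Keras torch backend implementation.
import torch

def convert_to_tensor(x, dtype=None, device=None):
    # A single call with no Python branching keeps dynamo's frame-evaluation
    # overhead low; torch.as_tensor accepts dtype and device directly.
    return torch.as_tensor(x, dtype=dtype, device=device)
```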
Testing
- CI for unit tests.
- Manual testing on `examples/keras_io/vision/mnist_convnet.py` by explicitly enabling `jit_compile`.

Observations: