(torch) Use pure-python implementation of tree when in dynamo context #18614

Open · kiukchung opened this issue Oct 15, 2023 · 6 comments

@kiukchung (Contributor)

Motivation

any_symbolic_tensors is called on pretty much all ops, and it internally uses tree.flatten() to check whether any positional or keyword argument to the op is a KerasTensor (i.e. a symbolic tensor). torchdynamo skips tracing tree (presumably because it has C bindings), causing a graph break at each op. This results in poor jitted performance for the pytorch backend, since the graph breaks occur between every op and we lose opportunities for any significant compiler optimizations (e.g. operator fusion).

See #18569 for more details.

Proposal

NOTE: making any_symbolic_tensors() dynamo-traceable does not guarantee that everything in Keras will be dynamo compatible. Once we fix this, other issues may arise.

  1. Provide a dynamo-traceable pure-python version of tree.flatten() and use it instead of tree.flatten() to prevent graph breaks at any_symbolic_tensors().
  2. If 1) is not enough, that is, if we still observe graph breaks (albeit less frequently) due to other usages of tree.*, then (as suggested by @fchollet in #18569, "(re)enable torch.compile in the pytorch trainer for train, predict, and eval") we need to create a keras.utils.tree that uses pure-python implementations when in a dynamo context and replace usages of tree.* with keras.utils.tree.*.

My suggestion is to first do 1), then see whether 2) is needed, since 2) is a bigger change that we may not actually need.
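
As a rough reference for 1), here is a minimal sketch of a pure-python, dynamo-traceable flatten (illustrative only; the function names, container handling, and key ordering are assumptions, not the actual Keras implementation):

```python
# Illustrative sketch only -- not the actual Keras implementation.
# A pure-python flatten that torchdynamo can trace, so the check in
# any_symbolic_tensors() does not fall back to dm-tree's C bindings.
def py_flatten(structure):
    if isinstance(structure, dict):
        flat = []
        # Traverse keys in sorted order for a deterministic result.
        for key in sorted(structure):
            flat.extend(py_flatten(structure[key]))
        return flat
    if isinstance(structure, (list, tuple)):
        flat = []
        for item in structure:
            flat.extend(py_flatten(item))
        return flat
    # Anything that is not a recognized container is a leaf.
    return [structure]


# Example: an any_symbolic_tensors-style check over args/kwargs.
def contains_instance(args, kwargs, cls):
    return any(isinstance(x, cls) for x in py_flatten((args, kwargs)))
```

Usage would then mirror the existing check, e.g. `contains_instance(args, kwargs, KerasTensor)`.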

@fchollet (Contributor)

Do we have a way to tell when we're in a Dynamo context?
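
(One possible answer, as a hedged sketch assuming a recent PyTorch: torch._dynamo.is_compiling() reports whether dynamo is currently tracing; the helper name and the fallback below are illustrative.)

```python
import torch


def in_dynamo_context():
    # True while torchdynamo is tracing the current frame.
    try:
        return torch._dynamo.is_compiling()
    except AttributeError:
        # Fall back for builds that don't expose this private helper.
        return False
```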

@AakashKumarNain (Contributor) commented Oct 16, 2023

> we need to create a keras.utils.tree that uses pure-python implementations when in a dynamo context and replace usages of tree.* with keras.utils.tree.*.

Should we start using Optree instead?
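
(For context, a minimal optree usage sketch, assuming the optree package is installed: tree_flatten returns the leaves together with a treespec that tree_unflatten can use to rebuild the structure.)

```python
import optree

nested = {"a": [1, 2], "b": (3, {"c": 4})}
# Flatten into leaves plus a treespec describing the container structure.
leaves, treespec = optree.tree_flatten(nested)
print(leaves)  # [1, 2, 3, 4]
# Rebuild the original nested structure from the leaves.
assert optree.tree_unflatten(treespec, leaves) == nested
```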

@mattdangerw (Member)

If we add a keras.utils.tree, let's make sure to export it. We are using dm-tree in KerasNLP downstream of Keras, and if we have a torch-friendly nested-structure solution, it would be great to be able to leverage it.

@sachinprasadhs added the type:feature and backend:torch labels on Oct 18, 2023
@ASEM000 commented Dec 11, 2023

+1 for optree.

@haifeng-jin (Contributor)

Hi @james77777778,

It seems optree is not implemented in pure Python; it uses some C, too.
We found it is still not compatible with torch dynamo.

Is my understanding correct?
What was the reason for wanting to swap dm-tree for optree?

We may need this info for future work on supporting torch dynamo.
Thanks!

@james77777778 (Contributor)

> Is my understanding correct? What was the reason for wanting to swap dm-tree for optree?

Actually, I just picked the item from: #18442 (a few months ago)

In the PR, I didn't achieve a significant speed-up by replacing dm-tree with optree for the torch backend:
#19306 (comment)

It's a bit strange to me as well, considering that optree has been integrated into torch.

Refs:

Perhaps we need a completely pure-Python implementation for dynamo?
