-
Notifications
You must be signed in to change notification settings - Fork 491
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lowering Aten op to composite op instead of small ops #8502
base: master
Are you sure you want to change the base?
Conversation
torch_xla/csrc/ops/ops.cpp
Outdated
|
||
// Building composite computation. | ||
const std::string name = "composite.gelu"; | ||
const std::string attr = "{approximate = \"none\"}"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a dummy str for testing purpose?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a real op attribution for GELU
: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html#torch.nn.GELU
The available value of approximate
is none
or tanh
. The lowering process checks this attribution and decides the sub lower function here. As my changes are in the sub lower function, I manually set this attribution.
It's a common process for composite op which has attributions (defined as non-tensor inputs for composite op, e.g. dim
for Softmax
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe I can get the attribution from XlaOp instead of manually setting strings, I will try.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What information is important for optimizations on this op? One option could be to have the composites implied value be none
, and not generate a composite if tanh
is used, of if the value is needed, then this looks good as-is, not a great API for making these in XLA, since they're an MLIR-first feature currently. If an MLIR dep is allowed in this file, you could build an MLIR dict and then dump to string before calling the XLA builder method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I believe that the approximate
attribute here serves little purpose other than enhancing IR readability. Nevertheless, I've retained all the relevant information according to the public guidance stablehlo.composite. If no more further comments or feedback, I can remove this information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
approximate
for Gelu
and GeluBackward
is removed.
torch_xla/csrc/ops/ops.cpp
Outdated
// Building call to computation. | ||
std::vector<xla::XlaOp> inputs{xla_input}; | ||
xla::XlaOp output = xla::CompositeCall(loctx->builder(), computation, inputs, name, | ||
attr, /*version=*/1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto for version
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes this is for testing, I learned this setting from this XLA UT. I can remove it if it makes no scense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
version
is removed.
Hi @Zantares, thanks for the PR! As long as the normal compilation/execution flow is not affected, I think this actually makes the HLO/StableHLO cleaner. Not sure if you have access to TPU to see if the code example in the PR description would run on TPU as well. I applied TPU CI tag, could you please push again to see if TPU CI pass? |
cc @GleasonK in case you know if the composite HLO op affects complication flow |
Thanks! I will fix the format error and push it again to trigger TPU CI. |
torch_xla/csrc/ops/ops.cpp
Outdated
return node.ReturnOp(BuildGelu(xla_input), loctx); | ||
|
||
// Building composite computation. | ||
const std::string name = "composite.gelu"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The purpose of namespacing is to be able to tell who the maintainer or origin of a given composite is. If something changes about GELU (new attr, etc) who is on the hook to maintain it (for composites this is usually intended to be a vendor who has a library nvidia.some_op
, aws.some_op
, litert.some_op
etc). The name composite
doesn't answer this question.
For this I'd recommend ptxla.gelu
or aten.gelu
as names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As an aside, I'd be careful representing converting every (or even many) aten operations using composites. They have a maintenance overhead in terms of needing to consider forward/backward compatibilty (i.e. the above thing about "what if gelu changes, who fixes?"), for some ops like gelu/softmax its probably ok, they don't tend to change much and usually look somewhat uniform.
For other aten ops that have very specific HW support, I'd recommend an approach that decentralizes composite ownership/maintenance, i.e. FX graph rewrite-as-composite API or make composite builder work for these use cases. This is how Google AI Edge uses composites today, they own the library and the compatibility for the (very small) subset of ATen ops that they have HW support for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(cc @lsy323)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestions, here are some of my thoughts:
- The namespace
ptxla
sounds quite good, and I will address it. - Regarding the compatibility/ownership of these ops in this PR, I believe it still belongs to
torch-xla
. The reason is that I didn't re-implement these ops from scratch, I simply wrapped the original implementation (such asBuildGelu
andLowerSoftmax
) with a composite call. W/O this PR, those implementations would need to be fixed if there were any changes in the semantics. At the current stage, I don't have any plans to introduce new composite ops that don't have an original implementation intorch-xla
- Since this PR is aimed at resolving training issues, I'm uncertain whether the composite builder will work or not. Judging from the discussions in the attached issues, it appears that it might not work for training purposes. Could you please share some examples or guidance on:
FX graph rewrite-as-composite API or make composite builder work for these use cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Namespace is addressed.
Hello, @Zantares, sorry for the delay, we are deprecating the current workflow of generating StableHLO from PyTorch programs. In #8544 we have the updated document to export PyTorch program using torch_xla2. I would encourage you to explore along the new workflow, however I have no problem landing this change since there is no side effect. |
Thanks. I'll go ahead and check out the new path. Before that, I have a quick question here: I have followed the official docs <OP Lowering Guide> and <Codegen migration Guide> to understand the lowering process, will the new |
Hi @Zantares, here it the guide https://github.com/pytorch/xla/tree/master/experimental/torch_xla2 |
This PR is to solve the 2nd question in this issue: supports composite op in training.
Motivation
Composite op is beneficial for performance optimization and we aim to apply it to training too. According to the response in the issue, the community has no plan to extend this to training currently... Thus, I created this draft PR to demonstrate our intention.
Detail
This PR alters the Aten op lowering process when there isn't a 1:1 mapping to XLA op. It uses composite call instead of small XLA ops. Later, in the optimization process, the composite call can be easily replaced with a custom kernel or decomposed.
If this PR gets accepted, here are some further suggestions:
XLA_COMPOSITE_OP
) to enable this feature. Also, add an op list setting to define which ops can be composed.Example
With this PR, the generated StableHLO is: