
Replace the UNET custom attention processors #1608

Open · wants to merge 2 commits into base: main
Conversation

@yafshar (Contributor) commented Dec 13, 2024

  • Replace the UNET custom attention processors with the default implementation on HPU.

What does this PR do?

Fixes # (issue) on G3

>>> python -m pytest tests/test_diffusers.py -v -s -k test_deterministic_image_generation --junitxml=report4.xml

tests/test_diffusers.py:3752:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:594: in __call__
    noise_pred = self.unet_hpu(
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:732: in unet_hpu
    return self.capture_replay(latent_model_input, timestep, encoder_hidden_states)
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:756: in capture_replay
    graph.capture_end()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <habana_frameworks.torch.hpu.graphs.HPUGraph object at 0x7f8eb9186d40>

    def capture_end(self):
        r"""
        Ends HPU graph capture on the current stream.
        After ``capture_end``, ``replay`` may be called on this instance.
        """
>       _hpu_C.capture_end(self.hpu_graph)
E       RuntimeError: Graph compile failed. synStatus=synStatus 26 [Generic failure].

/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py:64: RuntimeError

It looks like the UNET custom attention processors are the issue: replacing them with the default implementation makes the error go away.
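
For illustration, a minimal sketch of what "replace the custom attention processors with the default implementation" means, assuming the standard diffusers set_default_attn_processor() API; the pipeline class and checkpoint below are illustrative only, and the actual change in this PR lives in the optimum-habana UNET setup code:

    from diffusers import StableDiffusionPipeline

    # Illustrative checkpoint; any Stable Diffusion checkpoint behaves the same way.
    pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Drop any custom attention processors attached to the UNET and fall back
    # to diffusers' built-in default AttnProcessor implementation.
    pipeline.unet.set_default_attn_processor()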

This PR fixes the issue.

>>> python -m pytest tests/test_diffusers.py -v -s -k test_deterministic_image_generation --junitxml=report4.xml

-------------------------------------- generated xml file: /root/optimum-habana/report4.xml ---------------------------------------
======================================= 2 passed, 163 deselected, 4 warnings in 154.12s (0:02:34) ========================================

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

- Replace the custom attention processors with the default
  implementation.
@yafshar yafshar requested a review from regisss as a code owner December 13, 2024 18:36
@yafshar yafshar changed the title Replace the custom attention processors Replace the UNET custom attention processors Dec 13, 2024
@yafshar (Contributor, Author) commented Dec 13, 2024

With this PR, the test passes on both G2 and G3.

@libinta libinta added the run-test Run CI for PRs from external contributors label Dec 13, 2024
@yeonsily yeonsily requested a review from libinta December 13, 2024 21:06
@yafshar (Contributor, Author) commented Dec 13, 2024

@imangohari1 I added your patch. Please add the test results.

@imangohari1 (Contributor) commented

@regisss @astachowiczhabana @libinta

PR #1545 added the pipeline.unet.set_default_attn_processor(pipeline.unet) call to text_generation.
Here we add the same to the CI tests.
Below are the test runs on G1/G2/G3. They no longer show GC errors.

p.s. upstream discussion: huggingface/diffusers#7300
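
For reference, a hedged sketch of how a test could apply the same reset before exercising a pipeline; the test name, checkpoint, and prompt are illustrative and not the actual optimum-habana test code:

    from diffusers import StableDiffusionPipeline

    def test_generation_with_default_attn_processor():
        # Illustrative only: load a pipeline, reset the UNET attention
        # processors to the defaults, and check that generation still runs.
        pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        pipeline.unet.set_default_attn_processor()
        image = pipeline("An astronaut riding a horse", num_inference_steps=2).images[0]
        assert image is not None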

G3

GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/test_diffusers.py -s -v -k test_no_throughput_regression_bf16
.
.

tests/test_diffusers.py::GaudiStableDiffusionPipelineTester::test_no_throughput_regression_bf16
  /usr/local/lib/python3.10/dist-packages/diffusers/models/unets/unet_2d_blocks.py:2628: FutureWarning: `scale` is deprecated and will be removed in version 1.0.0. The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.
    deprecate("scale", "1.0.0", deprecation_message)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================================== 2 passed, 163 deselected, 4 warnings in 141.07s (0:02:21) =========
GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/test_diffusers.py -s -v -k test_no_generation_regression
.
.

========================================================================================================== warnings summary ===========================================================================================================
../../usr/lib/python3.10/inspect.py:288
  /usr/lib/python3.10/inspect.py:288: FutureWarning: `torch.distributed.reduce_op` is deprecated, please use `torch.distributed.ReduceOp` instead
    return isinstance(object, types.FunctionType)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================================== 3 passed, 162 deselected, 1 warning in 160.62s (0:02:40) =======================

G2

GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/test_diffusers.py -s -v -k test_no_throughput_regression_bf16


============================================================== short test summary info ===============================================================
FAILED tests/test_diffusers.py::GaudiDDPMPipelineTester::test_no_throughput_regression_bf16 - AssertionError: 0.142 not greater than or equal to 0.1442744619890604
======================================== 1 failed, 1 passed, 163 deselected, 4 warnings in 295.15s (0:04:55) =========================================

GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/test_diffusers.py -s -v -k test_no_generation_regression


==================================================================== 3 passed, 162 deselected, 1 warning in 248.95s (0:04:08) ===========================

G1

GAUDI2_CI=0 RUN_SLOW=true python -m pytest tests/test_diffusers.py -s -v -k test_no_throughput_regression_bf16 



============================================================== short test summary info ===============================================================
FAILED tests/test_diffusers.py::GaudiStableDiffusionPipelineTester::test_no_throughput_regression_bf16 - AssertionError: 0.276 not greater than or equal to 0.29355
FAILED tests/test_diffusers.py::GaudiDDPMPipelineTester::test_no_throughput_regression_bf16 - AssertionError: 0.044 not greater than or equal to 0.047698229228712884
============================================= 2 failed, 162 deselected, 4 warnings in 663.18s (0:11:03) ===================

GAUDI2_CI=0 RUN_SLOW=true python -m pytest tests/test_diffusers.py -s -v -k test_no_generation_regression

================================================================== warnings summary ==================================================================
../../usr/lib/python3.10/inspect.py:288
  /usr/lib/python3.10/inspect.py:288: FutureWarning: `torch.distributed.reduce_op` is deprecated, please use `torch.distributed.ReduceOp` instead
    return isinstance(object, types.FunctionType)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================== 3 passed, 161 deselected, 1 warning in 383.21s (0:06:23) ============
