
Experimental DirectML support via torch-directml and onnxruntime-directml #1702

Open · kazssym wants to merge 16 commits into main

Conversation


@kazssym kazssym commented Feb 19, 2024

What does this PR do?

This PR adds experimental DirectML support via torch-directml, which is still in preview and lacks several PyTorch functions (see microsoft/DirectML#449).

If you are interested in this PR, please leave comments below.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

This commit introduces two improvements:

1. DirectML acceleration:

    - Added support for running optimum commands on DirectML hardware (Windows only) using the --device dml flag.
    - Automatically sets the device to torch_directml.device() when the flag is specified.

2. Improved device handling:

    - Ensures the model is initialized directly on the target device only when applicable.

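A minimal sketch of the `--device dml` resolution described in point 1 (the helper name `resolve_device` is hypothetical, not the PR's actual code; torch-directml surfaces DirectML adapters through PyTorch's "privateuseone" device type):

```python
import importlib.util

def resolve_device(device_name: str):
    """Map a --device argument to a device object (hypothetical helper)."""
    if device_name == "dml":
        # Fail early with a clear message when torch-directml is absent.
        if importlib.util.find_spec("torch_directml") is None:
            raise RuntimeError("--device dml requires the torch-directml package")
        import torch_directml
        return torch_directml.device()
    import torch
    return torch.device(device_name)
```

With the flag set, e.g. `optimum-cli export onnx --device dml ...`, the resolved device would be used wherever the exporter otherwise passes `cuda` or `cpu`.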
This commit refines the device handling in optimum/exporters/tasks.py with the following improvements:

  - More precise device check: instead of checking for `not device.type`, the condition is updated to `device.type != "privateuseone"`. This ensures initialization happens directly on the requested device only when it is not a private-use device (e.g., DirectML).
  - Improved clarity: The code comments are updated to better explain the purpose of the device initialization and its benefits for large models.
  - Extends device compatibility to "privateuseone" in export_pytorch for exporting models usable on specific hardware.

This commit allows exporting PyTorch models compatible with the "privateuseone" device, potentially enabling inference on specialized hardware platforms.

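The refined check can be sketched as follows (illustrative only; the helper name and the stand-in device type are not from the actual codebase):

```python
from collections import namedtuple

def should_initialize_on_device(device) -> bool:
    """Initialize weights directly on the target device only when it is not a
    DirectML "privateuseone" device, which lacks some ops used at init time."""
    return device.type != "privateuseone"

# Tiny stand-in for torch.device, so the check can be shown without torch.
FakeDevice = namedtuple("FakeDevice", ["type"])
```

For large models, skipping the meta-device trick only on "privateuseone" keeps the fast direct-to-device initialization for CUDA and CPU.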
This commit adds support for running PyTorch models on the DML device within the Optimum framework.

  - Dynamic DML device handling: Introduces dynamic import of torch_directml for improved maintainability.
  - Consistent device selection: Ensures consistent device selection across optimum/exporters/onnx/convert.py, optimum/exporters/tasks.py, and optimum/onnxruntime/io_binding/io_binding_helper.py.

This change allows users to leverage DML capabilities for efficient PyTorch model inference with Optimum.

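The dynamic-import pattern mentioned above might look like this (a sketch; the cached-module helper is hypothetical, not the PR's actual code):

```python
import importlib

_torch_directml = None

def get_torch_directml():
    """Import torch_directml lazily and cache the module, so code paths that
    never touch the DML device neither pay the import cost nor fail when the
    package is absent."""
    global _torch_directml
    if _torch_directml is None:
        _torch_directml = importlib.import_module("torch_directml")
    return _torch_directml
```

Each of the three modules listed above could then call the same helper, keeping device selection consistent across the codebase.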
This commit removes unnecessary code for handling the DML device in optimum/commands/optimum_cli.py.

  - Redundant import: The code previously imported torch_directml conditionally, which is no longer needed as DML device support is handled in other parts of the codebase.

This change simplifies the code and avoids potential conflicts.

This commit updates `setup.py` to include the following changes:

  - Introduces a new conditional section "exporters-directml" with dependencies required for exporting models for DML inference.
  - This section mirrors the existing "exporters" and "exporters-gpu" sections, adding `onnxruntime-directml` as a dependency.

This update ensures users have the necessary libraries for working with DML devices when installing Optimum with DML support.
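The extras layout the commit describes could look roughly like this in `setup.py` (the dependency lists here are abbreviated placeholders, not optimum's real ones):

```python
# Passed as extras_require= to setup(); users opt in with
#   pip install "optimum[exporters-directml]"
EXTRAS_REQUIRE = {
    "exporters": ["onnx", "onnxruntime", "timm"],
    "exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"],
    # New extra mirroring the two above, swapping in the DirectML build:
    "exporters-directml": ["onnx", "onnxruntime-directml", "timm"],
}
```

Since `onnxruntime`, `onnxruntime-gpu`, and `onnxruntime-directml` all install the same `onnxruntime` import name, only one variant should be installed at a time.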
@kazssym kazssym changed the title Experimental DirectML support via torch-directml Experimental DirectML support via torch-directml and onnxruntime-directml Jan 13, 2025
@kazssym kazssym marked this pull request as ready for review January 14, 2025 02:57
@kazssym kazssym marked this pull request as draft January 19, 2025 02:58
@kazssym (Author) commented Jan 19, 2025

I found a problem in import_utils.py: onnxruntime-directml is not properly detected as an alternative to onnxruntime.

https://github.com/huggingface/optimum/blob/ded71c2a66bf7a41c4350677db8d32b424dc8d09/optimum/utils/import_utils.py#L68C1-L70C1
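One way to fix the detection (a sketch, not the actual import_utils.py code; the list of candidate distribution names is an assumption):

```python
import importlib.metadata

def onnxruntime_available() -> bool:
    """Accept any installed onnxruntime distribution variant, since
    onnxruntime-gpu and onnxruntime-directml ship the same import name."""
    for dist in ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml"):
        try:
            importlib.metadata.version(dist)
            return True
        except importlib.metadata.PackageNotFoundError:
            continue
    return False
```

Checking distribution metadata rather than importing the module avoids paying the import cost just to test availability.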

@kazssym kazssym marked this pull request as ready for review February 2, 2025 08:44
@kazssym (Author) commented Feb 2, 2025

Blocked by microsoft/DirectML#686
