
pd: add CINN compiler for dpa2, dpa1 training #4514

Open · wants to merge 9 commits into base: devel
Conversation

@HydrogenSulfate (Contributor) commented Dec 27, 2024

We verified the Paddle CINN compiler on the DPA-2 example (single A100-SXM (40G), CUDA 11.8, Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz x 160).

To enable the CINN compiler during training, prepend the environment flag CINN=1 to the training command, e.g. CINN=1 dp --pd train input_torch_medium.json.

Curves:

dpa2: [training-curve figure]

Performance

We tested with torch==2.6.0.dev20241219+cu118

| PyTorch (eager) | PyTorch (compile) | Paddle (eager) | Paddle (CINN compile) |
| --- | --- | --- | --- |
| 0.1078 s/step | compile does not support double-backward | 0.1130 s/step | 0.0638 s/step |

se_atten: [training-curve figure]

Performance

We tested with torch==2.6.0.dev20241219+cu118

| PyTorch (eager) | PyTorch (compile) | Paddle (eager) | Paddle (CINN compile) |
| --- | --- | --- | --- |
| 0.0342 s/step | compile does not support double-backward | 0.0444 s/step | 0.0244 s/step |

Accuracy details:

dpa2

  • PyTorch: [figure]

  • Paddle (eager mode): [figure]

  • Paddle (CINN compiler): [figure]

se_atten

  • PyTorch: [figure]

  • Paddle (eager mode): [figure]

  • Paddle (CINN compiler): [figure]

TODO:

Summary by CodeRabbit

  • New Features

    • Enhanced training performance through optimization techniques that leverage compilation and profiling for better execution monitoring.
    • Introduced a new function for converting various input types to boolean values, improving environment variable handling.
  • Refactor

    • Improved precision management for training parameters and updated environment configuration handling to ensure robust optimization support.

coderabbitai bot (Contributor) commented Dec 27, 2024

📝 Walkthrough

The changes modify two modules. In deepmd/pd/train/training.py, a conditional check for the CINN flag has been added. When enabled, Paddle’s jit.to_static function is used to convert the model’s forward pass to a static graph, a BuildStrategy is set up, and the learning rate is initialized with a specified precision. Data retrieval and prediction are now wrapped within a profiling context. In deepmd/pd/utils/env.py, a new to_bool function converts various input types to booleans; the JIT and CINN environment variables now use this function, and enable_prim() is updated to branch based on the JIT flag and Paddle’s CINN support.

Changes

File(s) Change Summary
deepmd/pd/train/training.py Added imports for CINN and DEFAULT_PRECISION from deepmd/pd/utils/env. In the run() method, a check for the CINN flag is introduced to import Paddle’s jit and static modules, set up a BuildStrategy, wrap the model’s forward pass with jit.to_static, change learning rate initialization using paddle.full, and encapsulate data fetching and prediction within nvprof_context.
deepmd/pd/utils/env.py Introduced the to_bool() function to cast integers, booleans, or strings to booleans. Updated the assignment of the JIT and CINN variables to use to_bool(). Added an assertion to verify PaddlePaddle’s CINN support and modified enable_prim() to branch its behavior based on the JIT flag.
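The env-flag handling the walkthrough describes (a `to_bool()` cast applied to the `JIT` and `CINN` environment variables) follows a common pattern. A minimal, paddle-free sketch of that pattern, under the semantics stated above (accepting 0/1, booleans, and case-insensitive "true"/"false" strings):

```python
import os


def to_bool(flag) -> bool:
    """Cast an int, bool, or str flag to bool.

    Sketch of the semantics described in the walkthrough; the exact
    error messages are illustrative.
    """
    # bool must be checked before int, since bool is a subclass of int
    if isinstance(flag, bool):
        return flag
    if isinstance(flag, int):
        if flag not in (0, 1):
            raise ValueError(f"flag must be 0 or 1, but received {flag}")
        return bool(flag)
    if isinstance(flag, str):
        lowered = flag.lower()
        if lowered not in ("0", "1", "true", "false"):
            raise ValueError(f"invalid boolean flag: {flag!r}")
        return lowered in ("1", "true")
    raise ValueError(f"unsupported flag type: {type(flag).__name__}")


# Environment-variable usage, mirroring how JIT/CINN are read:
CINN = to_bool(os.environ.get("CINN", "0"))
```

With this in place, `CINN=1 dp --pd train ...` flips the module-level flag that the training loop later consults.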

Sequence Diagram(s)

sequenceDiagram
    participant Runner as training.run()
    participant Env as CINN Flag
    participant Paddle as Paddle API
    participant NV as nvprof_context
    participant Model as Model.forward()

    Runner->>Env: Check if CINN is enabled
    alt CINN is enabled
        Runner->>Paddle: Import jit, static, BuildStrategy
        Runner->>Paddle: Wrap Model.forward() using jit.to_static
    end
    Runner->>NV: Enter profiling context
    Runner->>Model: Execute data retrieval and prediction
    Model-->>Runner: Return predictions
    Runner->>NV: Exit profiling context
sequenceDiagram
    participant OS as OS Environment
    participant Env as Environment Setup
    participant TB as to_bool()
    participant Paddle as PaddlePaddle

    OS->>TB: Provide "JIT" value
    TB-->>Env: Return boolean for JIT
    OS->>TB: Provide "CINN" value
    TB-->>Env: Return boolean for CINN
    alt CINN is true
        Env->>Paddle: Check for CINN support
        alt Not Supported
            Env->>Env: Raise assertion error
        end
    end
    Env->>Env: In enable_prim(), if JIT true, enable primitives in JIT mode
    Env->>Env: Otherwise, enable primitives in eager mode with composite op blacklist
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
deepmd/pd/train/training.py (2)

402-405: Remove or clarify commented-out code.

These lines comment out a previously raised NotImplementedError and a potential paddle.jit.to_static call. If you no longer need this logic, removing it altogether might reduce confusion and keep the file tidy. Otherwise, add a comment explaining why these lines are kept for future reference.

-# if JIT:
-#     raise NotImplementedError("JIT is not supported yet when training with Paddle")
-#     self.model = paddle.jit.to_static(self.model)

Also applies to: 406-406


925-926: Consider removing the extra commented-out code.

This snippet appears to comment out a JIT debugging break. If it's no longer needed, removing it can avoid potential confusion.

-# if JIT:
-#     break
🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bf79cc6 and 55f7ef6.

📒 Files selected for processing (2)
  • deepmd/pd/train/training.py (8 hunks)
  • deepmd/pd/utils/env.py (2 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/pd/utils/env.py

[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.

deepmd/pd/train/training.py

[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.

🔇 Additional comments (7)
deepmd/pd/train/training.py (4)

56-57: Check the new imports for consistency.

The addition of CINN and DEFAULT_PRECISION is consistent with improvements to JIT compilation and precision handling. Nothing problematic is observed here; just ensure that CINN is successfully imported where used and that DEFAULT_PRECISION is consistently applied.


636-648: JIT and CINN integration logic looks good.

You're conditionally enabling JIT via jit.to_static and setting build_strategy.build_cinn_pass = CINN. This is a clean approach, ensuring that CINN is only used when enabled. Just verify upstream usage to avoid unexpected behaviors if CINN is disabled at runtime.


688-692: Validate data fetching performance within profiling context.

Wrapping the data loading with nvprof_context helps profile overhead. Ensure that exceptions thrown within this block are properly handled so that the profiler is closed gracefully.


706-706: Precision usage for learning rate.

Using paddle.full([], pref_lr, DEFAULT_PRECISION) enforces consistent floating-point precision for the learning rate. This is beneficial for uniformity, especially in multi-task scenarios. Good practice!

Also applies to: 765-765, 815-815

deepmd/pd/utils/env.py (3)

36-50: Robust input validation in to_bool.

Excellent job handling integers, booleans, and string values thoroughly, with clear error messages and suitable lowercasing of strings. This ensures minimal confusion for environment variable parsing.


60-61: CINN usage is valid despite static analysis hint.

While static analysis tools suggest that CINN is “unused,” it is in fact used within training.py (via build_strategy.build_cinn_pass). You can safely disregard that hint.

🧰 Tools
🪛 GitHub Check: CodeQL

[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.


167-182: Flexible prim enabling for JIT and eager modes.

Switching between JIT mode and eager mode, while setting blacklists only in eager mode, is a neat approach to leverage performance improvements. This design offers more control and clarity to developers using either mode.

@HydrogenSulfate HydrogenSulfate changed the title pd: add CINN compiler for dpa2 [WIP] pd: add CINN compiler for dpa2 Dec 27, 2024
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
deepmd/pd/train/training.py (3)

709-709: Confirm precision alignment for learning rate.
Replacing a float with paddle.full([], pref_lr, DEFAULT_PRECISION) ensures consistent floating-point precision. Verify that using a 0-dimensional tensor is intentional and that the rest of the forward pass expects a scalar.


768-768: Avoid code duplication in validation logic.
This pattern repeats the same approach for creating a scalar tensor from pref_lr. Consider encapsulating the logic to reduce repetitive calls and simplify maintenance.


818-818: Maintain consistent approach for multi-task calls.
Reusing paddle.full([], pref_lr, DEFAULT_PRECISION) in multiple task keys is valid. However, centralizing this snippet could prevent future errors in multi-task scenarios.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 55f7ef6 and 95b201d.

📒 Files selected for processing (2)
  • deepmd/pd/train/training.py (6 hunks)
  • deepmd/pd/utils/env.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pd/utils/env.py
🔇 Additional comments (2)
deepmd/pd/train/training.py (2)

56-57: Use environment-driven feature flags carefully.
Importing CINN and DEFAULT_PRECISION helps toggle experimental optimizations and ensure matching data types throughout training. This is acceptable, but confirm that these environment variables or configurations are properly documented and tested to avoid accidental misconfigurations.


691-695: Validate data fetching within profiler context.
The usage of nvprof_context around get_data() is beneficial for profiling. Make sure exceptions (e.g., StopIteration) in data fetching are handled gracefully to preserve consistent profiling measurements.

@HydrogenSulfate HydrogenSulfate changed the title [WIP] pd: add CINN compiler for dpa2 [WIP] pd: add CINN compiler for dpa2 training Dec 27, 2024
@HydrogenSulfate HydrogenSulfate changed the title [WIP] pd: add CINN compiler for dpa2 training [WIP] pd: add CINN compiler for dpa2, dpa1 training Dec 29, 2024
codecov bot commented Dec 29, 2024

Codecov Report

Attention: Patch coverage is 47.05882% with 18 lines in your changes missing coverage. Please review.

Project coverage is 84.57%. Comparing base (f01fa53) to head (2cf1c65).

Files with missing lines Patch % Lines
deepmd/pd/utils/env.py 50.00% 13 Missing ⚠️
deepmd/pd/train/training.py 37.50% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4514      +/-   ##
==========================================
- Coverage   84.60%   84.57%   -0.03%     
==========================================
  Files         680      680              
  Lines       64473    64500      +27     
  Branches     3540     3540              
==========================================
+ Hits        54546    54554       +8     
- Misses       8786     8804      +18     
- Partials     1141     1142       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@HydrogenSulfate (Contributor, Author)

@njzjz I have a question about the code at

    # See https://github.com/pytorch/pytorch/issues/85792
    if xx.numel() > 0:
        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)

Under what circumstances would one encounter a 0-size (empty) input that causes numel() == 0? I haven't noticed this happening during Python training and testing. Could you please help me with this? Thank you!
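The guard in question can be illustrated without torch: skip the variance/mean computation when the input is empty (the analogue of numel() == 0), which arises when a rank receives no atoms. A framework-free sketch; the function name and the (0.0, 0.0) fallback are hypothetical, chosen only to show the control flow:

```python
def var_mean_or_default(xx: list[float]) -> tuple[float, float]:
    """Return (variance, mean); fall back to (0.0, 0.0) for empty input."""
    if len(xx) == 0:  # analogue of xx.numel() == 0
        return 0.0, 0.0
    mean = sum(xx) / len(xx)
    # biased variance (unbiased=False), matching the torch.var_mean call above
    variance = sum((v - mean) ** 2 for v in xx) / len(xx)
    return variance, mean


print(var_mean_or_default([1.0, 2.0, 3.0]))  # variance 2/3, mean 2.0
print(var_mean_or_default([]))               # (0.0, 0.0)
```

Without the guard, the reduction over an empty last dimension would produce NaNs (or raise), which is exactly the failure mode the linked PyTorch issue documents.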

@njzjz (Member) commented Jan 6, 2025

Under what circumstances would one encounter a 0-size(empty) input that causes numel()==0? I haven't noticed this happening during Python training and testing. Could you please help me with this? Thank you!

When using LAMMPS with MPI.

@njzjz (Member) commented Jan 6, 2025

See also #2668. Some users may encounter the situation where a processor has no atom.

@HydrogenSulfate (Contributor, Author)

See also #2668. Some users may encounter the situation where a processor has no atom.

Thank you for your response. Although 0-size tensors are not very common, we have indeed encountered similar issues with some object detection models. We are going to add support for training and inference with 0-size tensors in Paddle in the near future.

@HydrogenSulfate HydrogenSulfate changed the title [WIP] pd: add CINN compiler for dpa2, dpa1 training pd: add CINN compiler for dpa2, dpa1 training Jan 23, 2025
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
deepmd/pd/utils/env.py (2)

40-60: Consider using pattern matching for a more concise implementation.

The function is well-implemented with proper type checking and error handling. However, it could be more concise using Python 3.10+ pattern matching.

Here's a more concise version using pattern matching:

-def to_bool(flag: int | bool | str) -> bool:
-    if isinstance(flag, int):
-        if flag not in [0, 1]:
-            raise ValueError(f"flag must be either 0 or 1, but received {flag}")
-        return bool(flag)
-
-    elif isinstance(flag, str):
-        flag = flag.lower()
-        if flag not in ["1", "0", "true", "false"]:
-            raise ValueError(
-                f"flag must be either '0', '1', 'true', 'false', but received '{flag}'"
-            )
-        return flag in ["1", "true"]
-
-    elif isinstance(flag, bool):
-        return flag
-
-    else:
-        raise ValueError(
-            f"flag must be either int, bool, or str, but received {type(flag).__name__}"
-        )
+def to_bool(flag: int | bool | str) -> bool:
+    match flag:
+        case bool():
+            return flag
+        case int() if flag in [0, 1]:
+            return bool(flag)
+        case str() as s if s.lower() in ["0", "1", "true", "false"]:
+            return s.lower() in ["1", "true"]
+        case int():
+            raise ValueError(f"flag must be either 0 or 1, but received {flag}")
+        case str():
+            raise ValueError(
+                f"flag must be either '0', '1', 'true', 'false', but received '{flag}'"
+            )
+        case _:
+            raise ValueError(
+                f"flag must be either int, bool, or str, but received {type(flag).__name__}"
+            )

177-195: Enhance logging to indicate the specific mode.

The function correctly handles both JIT/CINN and eager modes, but the log message could be more informative.

Consider updating the log message to indicate the specific mode:

-    log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.")
+    mode = "JIT/CINN" if JIT or CINN else "eager"
+    log.info(f"{'Enable' if enable else 'Disable'} prim in {mode} mode.")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 87d8f35 and 2cf1c65.

📒 Files selected for processing (1)
  • deepmd/pd/utils/env.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (19)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (2)
deepmd/pd/utils/env.py (2)

2-4: LGTM! Good use of type hints.

The annotations import enables modern type hint syntax, improving code readability and type safety.


63-70: LGTM! Good defensive programming with CINN support check.

The code properly validates CINN support at startup with a clear error message, preventing runtime issues.

Note: The past review comment about unused CINN variable is outdated as the variable is now used in training.py.
