pd: add CINN compiler for dpa2, dpa1 training #4514
base: devel
Conversation
📝 Walkthrough
The changes touch two modules: `deepmd/pd/train/training.py` and `deepmd/pd/utils/env.py`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Runner as training.run()
    participant Env as CINN Flag
    participant Paddle as Paddle API
    participant NV as nvprof_context
    participant Model as Model.forward()
    Runner->>Env: Check if CINN is enabled
    alt CINN is enabled
        Runner->>Paddle: Import jit, static, BuildStrategy
        Runner->>Paddle: Wrap Model.forward() using jit.to_static
    end
    Runner->>NV: Enter profiling context
    Runner->>Model: Execute data retrieval and prediction
    Model-->>Runner: Return predictions
    Runner->>NV: Exit profiling context
```

```mermaid
sequenceDiagram
    participant OS as OS Environment
    participant Env as Environment Setup
    participant TB as to_bool()
    participant Paddle as PaddlePaddle
    OS->>TB: Provide "JIT" value
    TB-->>Env: Return boolean for JIT
    OS->>TB: Provide "CINN" value
    TB-->>Env: Return boolean for CINN
    alt CINN is true
        Env->>Paddle: Check for CINN support
        alt Not Supported
            Env->>Env: Raise assertion error
        end
    end
    Env->>Env: In enable_prim(), if JIT true, enable primitives in JIT mode
    Env->>Env: Otherwise, enable primitives in eager mode with composite op blacklist
```
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pd/train/training.py (2)
402-405: Remove or clarify commented-out code.

These lines comment out a previously raised `NotImplementedError` and a potential `paddle.jit.to_static` call. If you no longer need this logic, removing it altogether might reduce confusion and keep the file tidy. Otherwise, add a comment explaining why these lines are kept for future reference.

```diff
-# if JIT:
-#     raise NotImplementedError("JIT is not supported yet when training with Paddle")
-# self.model = paddle.jit.to_static(self.model)
```

Also applies to: 406-406
925-926: Consider removing the extra commented-out code.

This snippet appears to comment out a JIT debugging break. If it's no longer needed, removing it can avoid potential confusion.

```diff
-# if JIT:
-#     break
```

🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py (8 hunks)
- deepmd/pd/utils/env.py (2 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/pd/utils/env.py
[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.
deepmd/pd/train/training.py
[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.
🔇 Additional comments (7)
deepmd/pd/train/training.py (4)
56-57: Check the new imports for consistency.

The addition of `CINN` and `DEFAULT_PRECISION` is consistent with improvements to JIT compilation and precision handling. Nothing problematic is observed here; just ensure that `CINN` is successfully imported where used and that `DEFAULT_PRECISION` is consistently applied.
636-648: JIT and CINN integration logic looks good.

You're conditionally enabling JIT using `jit.to_static` and setting `build_strategy.build_cinn_pass = CINN`. This is a clean approach, ensuring that CINN is only used if enabled. Just verify upstream usage to avoid unexpected behaviors if `CINN` is disabled at runtime.
688-692: Validate data fetching performance within profiling context.

Wrapping the data loading with `nvprof_context` helps profile overhead. Ensure that exceptions thrown within this block are properly handled so that the profiler is closed gracefully.
706-706: Precision usage for learning rate.

Using `paddle.full([], pref_lr, DEFAULT_PRECISION)` enforces consistent floating-point precision for the learning rate. This is beneficial for uniformity, especially in multi-task scenarios. Good practice!

Also applies to: 765-765, 815-815
deepmd/pd/utils/env.py (3)
36-50: Robust input validation in `to_bool`.

Excellent job handling integers, booleans, and string values thoroughly, with clear error messages and suitable lowercasing of strings. This ensures minimal confusion for environment variable parsing.
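For illustration, the parsing rules praised above can be reproduced in a minimal, framework-free sketch. The actual implementation lives in `deepmd/pd/utils/env.py`; this standalone version only mirrors the behavior described in the review and is not the project's code:

```python
def to_bool_sketch(flag) -> bool:
    # bool is checked before int because bool is an int subclass in Python.
    if isinstance(flag, bool):
        return flag
    if isinstance(flag, int):
        if flag not in (0, 1):
            raise ValueError(f"flag must be either 0 or 1, but received {flag}")
        return bool(flag)
    if isinstance(flag, str):
        lowered = flag.lower()  # accept any casing: "True", "FALSE", ...
        if lowered not in ("0", "1", "true", "false"):
            raise ValueError(
                f"flag must be either '0', '1', 'true', 'false', but received '{flag}'"
            )
        return lowered in ("1", "true")
    raise ValueError(
        f"flag must be either int, bool, or str, but received {type(flag).__name__}"
    )

print(to_bool_sketch("True"), to_bool_sketch(0))  # → True False
```

This matches the environment-variable use case: values arrive as strings like `"1"` or `"true"`, while programmatic callers may pass ints or bools.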
60-61: CINN usage is valid despite static analysis hint.

While static analysis tools suggest that `CINN` is "unused," it is in fact used within `training.py` (via `build_strategy.build_cinn_pass`). You can safely disregard that hint.
🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.
167-182: Flexible prim enabling for JIT and eager modes.

Switching between JIT mode and eager mode, while setting blacklists only in eager mode, is a neat approach to leverage performance improvements. This design offers more control and clarity to developers using either mode.
Actionable comments posted: 1
🧹 Nitpick comments (3)
deepmd/pd/train/training.py (3)
709-709: Confirm precision alignment for learning rate.

Replacing a float with `paddle.full([], pref_lr, DEFAULT_PRECISION)` ensures consistent floating-point precision. Verify that using a 0-dimensional tensor is intentional and that the rest of the forward pass expects a scalar.
768-768: Avoid code duplication in validation logic.

This pattern repeats the same approach for creating a scalar tensor from `pref_lr`. Consider encapsulating the logic to reduce repetitive calls and simplify maintenance.
818-818: Maintain consistent approach for multi-task calls.

Reusing `paddle.full([], pref_lr, DEFAULT_PRECISION)` in multiple task keys is valid. However, centralizing this snippet could prevent future errors in multi-task scenarios.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py (6 hunks)
- deepmd/pd/utils/env.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pd/utils/env.py
🔇 Additional comments (2)
deepmd/pd/train/training.py (2)
56-57: Use environment-driven feature flags carefully.

Importing `CINN` and `DEFAULT_PRECISION` helps toggle experimental optimizations and ensure matching data types throughout training. This is acceptable, but confirm that these environment variables or configurations are properly documented and tested to avoid accidental misconfigurations.
691-695: Validate data fetching within profiler context.

The usage of `nvprof_context` around `get_data()` is beneficial for profiling. Make sure exceptions (e.g., `StopIteration`) in data fetching are handled gracefully to preserve consistent profiling measurements.
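The exception-safety concern above is exactly what a `with`-based context manager guarantees: `__exit__` runs even when the body raises. A generic stdlib sketch (the real `nvprof_context` wraps NVTX profiling ranges; `profiled_range` here is a hypothetical stand-in that just records entry and exit):

```python
from contextlib import contextmanager

events = []

@contextmanager
def profiled_range(name: str):
    # Stand-in for nvprof_context: record entry/exit so we can observe
    # that the range is always closed, even if the body raises.
    events.append(f"enter:{name}")
    try:
        yield
    finally:
        events.append(f"exit:{name}")

def get_data():
    # Simulate an exhausted data iterator.
    raise StopIteration("epoch exhausted")

try:
    with profiled_range("data-load"):
        get_data()
except StopIteration:
    pass

print(events)  # → ['enter:data-load', 'exit:data-load']
```

The `finally` block fires before the `StopIteration` propagates out of the `with` statement, so the profiling range is closed and measurements stay consistent.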
Force-pushed from 95b201d to 5cdd421, then from 5cdd421 to 7ca2a9e.
Codecov Report

Attention: patch coverage details and impacted files:

```
@@            Coverage Diff             @@
##            devel    #4514      +/-   ##
==========================================
- Coverage   84.60%   84.57%   -0.03%
==========================================
  Files         680      680
  Lines       64473    64500      +27
  Branches     3540     3540
==========================================
+ Hits        54546    54554       +8
- Misses       8786     8804      +18
- Partials     1141     1142       +1
```

☔ View full report in Codecov by Sentry.
@njzjz I have a question about the code at deepmd-kit/deepmd/pt/model/network/layernorm.py, lines 97 to 99 in 8d4c27b: under what circumstances can `numel() == 0` occur? I haven't noticed this happening during Python training and testing. Could you please help me with this? Thank you!
When using LAMMPS with MPI.
See also #2668. Some users may encounter the situation where a processor has no atoms.
Thank you for your response. Although 0-size tensors are not very common, we have indeed encountered similar issues with some object detection models. We are going to support training and inference with 0-size tensors in Paddle in the near future.
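The zero-atom situation discussed above reduces to guarding reductions over possibly-empty data. A framework-free sketch of that guard pattern (illustrative only; the real check in `layernorm.py` operates on tensors, not Python lists):

```python
def layer_norm_safe(values: list[float], eps: float = 1e-5) -> list[float]:
    # Mirror the numel() == 0 guard: an empty input returns an empty
    # output instead of dividing by zero in the mean/variance step.
    n = len(values)
    if n == 0:
        return []
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

print(layer_norm_safe([]))          # → []
print(layer_norm_safe([1.0, 3.0]))  # ≈ [-1.0, 1.0]
```

Without the early return, `n == 0` would make `sum(values) / n` raise `ZeroDivisionError`, which is the failure mode the guard prevents when a processor owns no atoms.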
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pd/utils/env.py (2)
40-60: Consider using pattern matching for a more concise implementation.

The function is well-implemented with proper type checking and error handling. However, it could be more concise using Python 3.10+ pattern matching.

Here's a more concise version using pattern matching:

```diff
-def to_bool(flag: int | bool | str) -> bool:
-    if isinstance(flag, int):
-        if flag not in [0, 1]:
-            raise ValueError(f"flag must be either 0 or 1, but received {flag}")
-        return bool(flag)
-
-    elif isinstance(flag, str):
-        flag = flag.lower()
-        if flag not in ["1", "0", "true", "false"]:
-            raise ValueError(
-                f"flag must be either '0', '1', 'true', 'false', but received '{flag}'"
-            )
-        return flag in ["1", "true"]
-
-    elif isinstance(flag, bool):
-        return flag
-
-    else:
-        raise ValueError(
-            f"flag must be either int, bool, or str, but received {type(flag).__name__}"
-        )
+def to_bool(flag: int | bool | str) -> bool:
+    match flag:
+        case bool():
+            return flag
+        case int() if flag in [0, 1]:
+            return bool(flag)
+        case str() as s if s.lower() in ["0", "1", "true", "false"]:
+            return s.lower() in ["1", "true"]
+        case int():
+            raise ValueError(f"flag must be either 0 or 1, but received {flag}")
+        case str():
+            raise ValueError(
+                f"flag must be either '0', '1', 'true', 'false', but received '{flag}'"
+            )
+        case _:
+            raise ValueError(
+                f"flag must be either int, bool, or str, but received {type(flag).__name__}"
+            )
```
177-195: Enhance logging to indicate the specific mode.

The function correctly handles both JIT/CINN and eager modes, but the log message could be more informative.

Consider updating the log message to indicate the specific mode:

```diff
-    log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.")
+    mode = "JIT/CINN" if JIT or CINN else "eager"
+    log.info(f"{'Enable' if enable else 'Disable'} prim in {mode} mode.")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- deepmd/pd/utils/env.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (19)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (2)
deepmd/pd/utils/env.py (2)
2-4: LGTM! Good use of type hints.

The `annotations` import enables modern type hint syntax, improving code readability and type safety.
63-70: LGTM! Good defensive programming with CINN support check.

The code properly validates CINN support at startup with a clear error message, preventing runtime issues.

Note: the past review comment about the unused `CINN` variable is outdated, as the variable is now used in `training.py`.
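The startup guard described can be sketched with the capability check injected as a parameter. In the real `env.py` the check would query Paddle itself (Paddle exposes `paddle.is_compiled_with_cinn()`); the helper name and signature below are hypothetical, chosen so the guard logic is testable standalone:

```python
def check_cinn_support(cinn_requested: bool, compiled_with_cinn: bool) -> None:
    # Hypothetical stand-in for the startup assertion: fail fast, with a
    # clear message, when CINN=1 is set but the installed Paddle build
    # cannot honor it.
    if cinn_requested and not compiled_with_cinn:
        raise AssertionError(
            "CINN=1 was set, but this Paddle build was compiled without CINN support"
        )

check_cinn_support(cinn_requested=True, compiled_with_cinn=True)    # ok
check_cinn_support(cinn_requested=False, compiled_with_cinn=False)  # ok: CINN not requested
```

Failing at import time rather than mid-training is the point of the check: a misconfigured environment is reported before any expensive setup runs.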
We verified the Paddle CINN compiler on a DPA-2 example (single A100-SXM (40G), CUDA 11.8, Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz x 160).
To enable the CINN compiler in training, add one flag, `CINN=1`, before the training command, e.g. `CINN=1 dp --pd train input_torch_medium.json`.

Curves:
dpa2
Performance
We tested with torch==2.6.0.dev20241219+cu118
se_atten
Performance
We tested with torch==2.6.0.dev20241219+cu118
Accuracy details:

dpa2
- PyTorch: [accuracy curve image]
- Paddle (eager mode): [accuracy curve image]
- Paddle (CINN compiler): [accuracy curve image]

se_atten
- PyTorch: [accuracy curve image]
- Paddle (eager mode): [accuracy curve image]
- Paddle (CINN compiler): [accuracy curve image]
TODO:
Summary by CodeRabbit
New Features
Refactor