forked from NVIDIA/apex
-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm (NVIDIA#1274) * FusedRMSNorm based on FusedLayerNorm * refactor duplicated kernels * delete comments * delete comments * cleanup * cleanup * cleanup, fixed clobbering forward_affine_mixed_dtypes * fix pybind naming and add MixedFused test * undo skipping * check elementwise_affine * Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py Oof, nice catch, thanks Co-authored-by: Masaki Kozuki <[email protected]> Co-authored-by: Masaki Kozuki <[email protected]> * fix and generate docs for FusedRMSNorm (NVIDIA#1285) * [FusedRMSNorm doc] document where epsilon is added (NVIDIA#1295) * [FusedRMSNorm doc] add epsilon to formula * correct * better wording * Fix some bugs * Optimize HostRMSNormGradient and HostApplyRMSNorm for AMD GPUs * Fix NaN issues in FusedRMSNorm * Update test_fused_layer_norm.py * Skip test_fused_layer_norm.TestAutocastFusedRMSNorm on ROCm * Use at::cuda::warp_size() instead of at::cuda::getCurrentDeviceProperties()->warpSize Co-authored-by: eqy <[email protected]> Co-authored-by: Masaki Kozuki <[email protected]> Co-authored-by: Stas Bekman <[email protected]>
- Loading branch information
1 parent
cf77e9b
commit c97ebfa
Showing
6 changed files
with
1,074 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm | ||
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.