Release Notes – Release 1.5
Key Features and Enhancements
- [pyTorch] Added support for non-reentrant mode for activation recompute in the `checkpoint` API (see the sketch after this list).
- [pyTorch] Added support for rectangular matrices in the unfused softmax backend in order to support speculative decoding.
- [pyTorch] Added the `inference_params` argument to the `DotProductAttention` API to support kv-caching.
- [JAX] Added the `DotProductAttention` API.
- [JAX] Expanded RoPE support using the `rotary_pos_emb_group_method` argument.
- [paddle] Added support for RMSNorm.
- [paddle] Added support for RoPE.
- [paddle] Added support for SwiGLU.
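
The following is a minimal sketch of the non-reentrant recompute mode, assuming the toggle is exposed as a `use_reentrant` keyword modeled on `torch.utils.checkpoint`; consult the pyTorch API reference for the exact signature.

```python
# Minimal sketch, not the authoritative API: activation recompute through
# Transformer Engine's checkpoint wrapper. The use_reentrant keyword name is
# an assumption borrowed from torch.utils.checkpoint.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(8, 1024, device="cuda", requires_grad=True)

# Recompute the layer's activations during the backward pass using the
# non-reentrant code path (assumed keyword name).
out = te.checkpoint(layer, inp, use_reentrant=False)
out.sum().backward()
```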
Fixed Issues
- [pyTorch] Fixed a numerical issue with storing weights in FP8 via the `fp8_model_init` API (see the sketch below).
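
For context, this is a usage sketch of the API named above; the shapes and flags are illustrative and not taken from this release.

```python
# Minimal sketch: modules created under fp8_model_init keep their weights
# directly in FP8 rather than maintaining higher-precision copies.
import torch
import transformer_engine.pytorch as te

with te.fp8_model_init(enabled=True):
    fp8_linear = te.Linear(4096, 4096)  # weights stored as FP8 tensors

inp = torch.randn(16, 4096, device="cuda")
with te.fp8_autocast(enabled=True):
    out = fp8_linear(inp)
```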
Known Issues in This Release
- FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue by setting the environment variable MAX_JOBS=1 during Transformer Engine installation.
- [pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine's behavior consistent across versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
- [JAX] The arguments `num_heads`, `dropout_rate`, `output_layernorm`, `apply_residual_connection_post_layernorm`, and `fuse_qkv` are deprecated in the `MultiHeadAttention` API. They are replaced respectively with `num_attention_heads`, `attention_dropout`, `input_layernorm`, `return_layernorm_output`, and `fused_qkv_params` (see the sketch below).
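
The sketch below illustrates the renamed arguments. The `transformer_engine.jax.flax` module path and the `head_dim` field are assumptions; only the renames themselves come from this note, so verify against the JAX API reference.

```python
# Minimal sketch of the MultiHeadAttention argument renames.
import transformer_engine.jax.flax as te_flax

# Deprecated spelling:
#   te_flax.MultiHeadAttention(num_heads=16, dropout_rate=0.1, fuse_qkv=True)

# Replacement spelling:
mha = te_flax.MultiHeadAttention(
    head_dim=64,              # assumed field, unrelated to the renames
    num_attention_heads=16,   # replaces num_heads
    attention_dropout=0.1,    # replaces dropout_rate
    fused_qkv_params=True,    # replaces fuse_qkv
)
```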
Miscellaneous Changes
There are no miscellaneous changes in this release.