
PyTorch/XLA 2.4 Release

@bhavya01 released this 25 Jul 00:10
f4d0333

Cloud TPUs now support the PyTorch 2.4 release via the PyTorch/XLA integration. On top of the underlying improvements and bug fixes in PyTorch 2.4, this release introduces several new features and PyTorch/XLA-specific bug fixes.

🚀 The PyTorch/XLA 2.4 release delivers a 4% speedup (geometric mean) on torchbench evaluation benchmarks using the openxla_eval dynamo backend on TPUs, compared to the 2.3 release.
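A geometric mean of per-benchmark speedup ratios summarizes a suite without letting a single outlier dominate. The sketch below shows how such a figure is computed; the ratios are made up for illustration (they are not torchbench results), and the exact speedup definition behind the 4% figure is an assumption here.

```python
import math


def geometric_mean(ratios):
    """Geometric mean of positive ratios, computed in log space for stability."""
    if not ratios:
        raise ValueError("need at least one ratio")
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))


# Made-up per-benchmark speedups (old time / new time); NOT real benchmark data.
speedups = [1.0, 1.0816]
gm = geometric_mean(speedups)
print(f"{(gm - 1) * 100:.1f}% geometric-mean speedup")
```

Here one benchmark is unchanged and the other is about 8% faster, which averages to 4% in log space rather than the 4.08% an arithmetic mean would give.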

Highlights

We are excited to announce the release of PyTorch/XLA 2.4! PyTorch/XLA 2.4 offers improved support for custom kernels written with Pallas, including kernels such as FlashAttention and Group Matrix Multiplication that can be used like any other torch operator, as well as inference support for the PagedAttention kernel. We also add experimental support for eager mode, which compiles and executes each operator individually for a better debugging and development experience.

Stable Features

PJRT

  • Enable dynamic plugins by default #7270

GSPMD

  • Support manual sharding and introduce high-level manual sharding APIs #6915, #6931
  • Support SPMDFullToShardShape, SPMDShardToFullShape #6922, #6925
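SPMDFullToShardShape and SPMDShardToFullShape convert between a tensor's global (full) shape and the local shape each device sees inside a manually sharded region. The plain-Python sketch below only computes that local shape. The helper `shard_shape` is hypothetical; its partition-spec encoding (one mesh axis, tuple of axes, or `None` per dimension) is modeled loosely on PyTorch/XLA-style partition specs, and ceil-division padding for uneven dimensions is an assumption about how uneven tiling is handled.

```python
import math


def shard_shape(full_shape, mesh_shape, partition_spec):
    """Per-device shape of a tensor tiled over a device mesh (hypothetical helper).

    partition_spec[i] is a mesh axis index, a tuple of mesh axis indices, or None
    (replicated) for tensor dimension i. Uneven dimensions use ceil division,
    i.e. the last shard along that dimension is assumed to be padded.
    """
    if len(full_shape) != len(partition_spec):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    out = []
    for size, axes in zip(full_shape, partition_spec):
        if axes is None:
            out.append(size)  # replicated: every device holds the full dimension
            continue
        if isinstance(axes, int):
            axes = (axes,)
        num_shards = math.prod(mesh_shape[a] for a in axes)
        out.append(-(-size // num_shards))  # integer ceil division
    return tuple(out)


mesh = (2, 4)  # 8 devices arranged as a 2 x 4 mesh
print(shard_shape((16, 128), mesh, (0, 1)))          # (8, 32)
print(shard_shape((16, 128), mesh, (None, 1)))       # (16, 32)
print(shard_shape((10, 128), mesh, ((0, 1), None)))  # (2, 128)
```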

Torch Compile

  • Add a DynamoSyncInputExecuteTime counter #6813
  • Fix runtime error when running dynamo with a profiler scope #6913

Export

  • Add fx passes to support unbounded dynamism #6653
  • Add dynamism support to conv1d, view, softmax #6653
  • Add dynamism support to aten.embedding and aten.split_with_sizes #6781
  • Inline all scalars by default in export path #6803
  • Run shape propagation for inserted fx nodes #6805
  • Add an option to not generate weights #6909
  • Support exporting custom ops to StableHLO custom calls #7017
  • Support array attribute in stablehlo composite #6840
  • Add option to export FX Node metadata to StableHLO #7046

Beta Features

Pallas

  • Support FlashAttention backward kernels #6870
  • Make FlashAttention a torch.autograd.Function #6886
  • Remove torch.empty in tracing to avoid allocating extra memory #6897
  • Integrate FlashAttention with SPMD #6935
  • Support scaling factor for attention weights in FlashAttention #7035
  • Support segment ids in FlashAttention #6943
  • Enable PagedAttention through Pallas #6912
  • Properly support PagedAttention dynamo code path #7022
  • Support megacore_mode in PagedAttention #7060
  • Add Megablocks’ Group Matrix Multiplication kernel #6940, #7117, #7120, #7119, #7133, #7151
  • Support histogram #7115, #7202
  • Support tgmm #7137
  • Prevent repeat_with_fixed_output_size from running out of VMEM (OOM) #7145
  • Introduce GMM as a torch.autograd.Function #7152
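Megablocks-style Group Matrix Multiplication (gmm) multiplies each contiguous group of rows by that group's own weight matrix, as in mixture-of-experts layers; tgmm is the transposed variant, typically used for the per-group weight gradient, and a histogram of expert assignments yields the group sizes. The sketch below is a plain-Python reference of those semantics only, assuming rows are already sorted by group. The function names and argument orders are illustrative and are not the Pallas kernels' API.

```python
def matmul(a, b):
    """Plain-Python (m x k) @ (k x n) product on lists of lists."""
    n = len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(len(b))) for j in range(n)] for row in a]


def histogram(expert_ids, num_experts):
    """Count the rows routed to each expert; these counts are the group sizes."""
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    return counts


def gmm(lhs, rhs, group_sizes):
    """Grouped matmul: lhs (m x k) is split into contiguous row groups and
    group g is multiplied by its own weight rhs[g] (k x n). Returns (m x n)."""
    if sum(group_sizes) != len(lhs):
        raise ValueError("group sizes must sum to the number of rows")
    out, start = [], 0
    for g, size in enumerate(group_sizes):
        out.extend(matmul(lhs[start:start + size], rhs[g]))
        start += size
    return out


def tgmm(lhs, rhs, group_sizes):
    """Transposed grouped matmul: for each group g, lhs[rows_g]^T @ rhs[rows_g],
    giving one (k x n) matrix per group (the shape of a per-expert weight grad)."""
    k, n = len(lhs[0]), len(rhs[0])
    out, start = [], 0
    for size in group_sizes:
        rows = range(start, start + size)
        out.append([[sum(lhs[r][i] * rhs[r][j] for r in rows) for j in range(n)]
                    for i in range(k)])
        start += size
    return out


# Four tokens with two features each, already sorted by expert (two per expert).
tokens = [[1, 2], [3, 4], [5, 6], [7, 8]]
sizes = histogram([0, 0, 1, 1], num_experts=2)
weights = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]  # expert 0: identity, expert 1: 2x identity
print(gmm(tokens, weights, sizes))        # [[1, 2], [3, 4], [10, 12], [14, 16]]
print(tgmm(tokens, [[1, 1]] * 4, sizes))  # [[[4, 4], [6, 6]], [[12, 12], [14, 14]]]
```

A real kernel avoids the Python loop over groups by tiling rows and looking up each tile's group, but the numbers it produces should match a reference like this one.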

CoreAtenOpSet

  • Lower embedding_bag_forward_only #6951
  • Implement Repeat with fixed output shape #7114
  • Add int8 per channel weight-only quantized matmul #7201
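Weight-only quantization stores the weights as int8 with one float scale per output channel, while activations stay in floating point. The plain-Python sketch below shows that arithmetic, assuming symmetric quantization with `scale = max|w| / 127` per row of an `nn.Linear`-style (out_features x in_features) weight. The function names are hypothetical, and the lowering in #7201 may pick scales differently.

```python
def quantize_per_channel(weight):
    """Symmetric int8 quantization with one scale per output channel (row).

    weight is an (out_features x in_features) list of lists, as in nn.Linear.
    Returns (int8 values stored as Python ints, per-row float scales).
    """
    q, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0
        q.append([max(-128, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q, scales


def weight_only_int8_matmul(x, q, scales):
    """y = x @ W^T with W approximated by q * scale; x stays in float.

    Because each scale belongs to one output channel, it can be applied once
    after the integer-weight dot product instead of dequantizing every weight.
    """
    return [[sum(xi * qi for xi, qi in zip(x_row, q_row)) * s
             for q_row, s in zip(q, scales)]
            for x_row in x]


weight = [[0.5, -1.0, 0.25], [2.0, 0.0, -2.0]]
x = [[1.0, 2.0, 3.0]]
q, scales = quantize_per_channel(weight)
print(weight_only_int8_matmul(x, q, scales))  # close to the float result [[-0.75, -4.0]]
```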

FSDP via SPMD

  • Support multislice #7044
  • Allow sharding on the maximal dimension of the weights #7134
  • Apply optimization-barrier to all params and buffers during grad checkpointing #7206
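Sharding a weight along its largest dimension, rather than always dimension 0, can help when the leading dimension is small or does not divide evenly across devices. The helper below is a hypothetical illustration of building such a partition spec; the axis name `"fsdp"` and the `shard_maximal` flag are illustrative names, not the option added in #7134.

```python
def fsdp_partition_spec(shape, axis_name="fsdp", shard_maximal=True):
    """Partition spec that shards one weight dimension over the FSDP mesh axis.

    With shard_maximal=True the largest dimension is sharded (ties go to the
    lowest index); otherwise dimension 0 is sharded. Other dims stay replicated.
    """
    if not shape:
        return ()  # scalars stay fully replicated
    dim = max(range(len(shape)), key=lambda i: shape[i]) if shard_maximal else 0
    return tuple(axis_name if i == dim else None for i in range(len(shape)))


print(fsdp_partition_spec((1024, 4096)))                       # (None, 'fsdp')
print(fsdp_partition_spec((1024, 4096), shard_maximal=False))  # ('fsdp', None)
```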

Distributed Checkpoint

  • Add optimizer priming for distributed checkpointing #6572
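Distributed checkpointing restores state in place into already-allocated storage, but optimizers typically create their per-parameter state lazily on the first `step()`. Priming runs one step with all-zero gradients so that state exists before the checkpoint is loaded. The framework-free toy below illustrates the idea; `ToyMomentumSGD` and `prime` are hypothetical and are not the torch_xla priming function.

```python
class ToyMomentumSGD:
    """Momentum SGD whose per-parameter state is created lazily on first step()."""

    def __init__(self, params, lr=0.1, momentum=0.9):
        self.params = params  # name -> list of floats, updated in place
        self.lr, self.momentum = lr, momentum
        self.state = {}       # name -> momentum buffer; empty until first step

    def step(self, grads):
        for name, p in self.params.items():
            buf = self.state.setdefault(name, [0.0] * len(p))
            for i, g in enumerate(grads[name]):
                buf[i] = self.momentum * buf[i] + g
                p[i] -= self.lr * buf[i]

    def load_state_in_place(self, saved):
        # Mimics an in-place restore: destination buffers must already exist.
        for name, values in saved.items():
            self.state[name][:] = values


def prime(opt):
    """Step once with zero gradients so every state buffer is allocated.
    With plain momentum SGD (no weight decay) the parameters are unchanged."""
    opt.step({name: [0.0] * len(p) for name, p in opt.params.items()})


opt = ToyMomentumSGD({"w": [1.0, 2.0]})
saved = {"w": [0.5, -0.5]}
try:
    opt.load_state_in_place(saved)
except KeyError:
    print("no optimizer state to load into yet")
prime(opt)
opt.load_state_in_place(saved)
print(opt.params["w"], opt.state["w"])  # [1.0, 2.0] [0.5, -0.5]
```

Note that priming is only side-effect free when a zero-gradient step leaves parameters untouched; an optimizer with weight decay, for example, would still move them.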

Usability

  • Add xla.sync as a better name for mark_step. See #6399. #6914
  • Add xla.step context manager to handle exceptions better. See #6751. #7068
  • Implement ComputationClient::GetMemoryInfo for getting TPU memory allocation #7086
  • Dump HLO HBM usage info #7085
  • Add function for retrieving fallback operations #7116
  • Deprecate XLA_USE_BF16 and XLA_USE_FP16 #7150
  • Add PT_XLA_DEBUG_LEVEL to make it easier to distinguish between execution cause and compilation cause #7149
  • Warn when using persistent cache with debug env vars #7175
  • Add experimental MLIR debuginfo writer API #6799
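PyTorch/XLA traces operations lazily and only compiles and runs the accumulated graph when the graph is cut, which is what `xla.sync` (formerly `mark_step`) does; `xla.step` wraps a training step so the graph is cut at its boundaries. The toy model below illustrates only that control flow in plain Python. `LazyGraph`, `step` and the flush-in-`finally` behavior when the body raises are this sketch's assumptions, not a description of `torch_xla.step`'s exact exception handling.

```python
import contextlib


class LazyGraph:
    """Records ops instead of running them; sync() 'executes' the pending batch."""

    def __init__(self):
        self.pending = []
        self.executed = []  # one list of op names per executed graph

    def record(self, op):
        self.pending.append(op)

    def sync(self):
        if self.pending:
            self.executed.append(self.pending)
            self.pending = []


graph = LazyGraph()


@contextlib.contextmanager
def step():
    graph.sync()      # flush work recorded before the step begins
    try:
        yield graph
    finally:
        graph.sync()  # cut the graph at the step boundary, even if the body raised


with step() as g:
    g.record("matmul")
    g.record("add")

try:
    with step() as g:
        g.record("loss")
        raise RuntimeError("simulated failure mid-step")
except RuntimeError:
    pass

print(graph.executed)  # [['matmul', 'add'], ['loss']]
```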

GPU CUDA Fallback

  • Add dlpack support #7025
  • Make from_dlpack handle cuda synchronization implicitly for input tensors that have __dlpack__ and __dlpack_device__ attributes. #7125

Distributed

  • Switch all_reduce to use the new functional collective op #6887
  • Allow user to configure distributed runtime service. #7204
  • Use dest_offsets directly in LoadPlanner #7243

Experimental Features

Eager Mode

  • Enable Eager mode for PyTorch/XLA #7611
  • Support eager mode with torch.compile #7649
  • Eagerly execute inplace ops in eager mode #7666
  • Support eager mode for multi-process training #7668
  • Handle random seed for eager mode #7669
  • Enable SPMD with eager mode #7673

Triton

  • Add support for Triton GPU kernels #6798
  • Make Triton kernels work with CUDA plugin #7303

While Loop

  • Prepare for torch while_loop signature change. #6872
  • Implement fori_loop as a wrapper around while_loop #6850
  • Complete fori_loop/while_loop and add test cases #7306
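`fori_loop` is a counted loop expressed on top of the more general `while_loop`, which threads a tuple of carried values through a condition function and a body function. The plain-Python sketch below shows that layering using JAX-style argument orders; this is an assumption, since torch_xla's `fori_loop` and `while_loop` signatures may differ, and a compiled implementation would trace the body into a single XLA While op rather than running Python iterations.

```python
def while_loop(cond_fn, body_fn, carried):
    """Apply body_fn to the carried tuple for as long as cond_fn is true."""
    while cond_fn(*carried):
        carried = body_fn(*carried)
    return carried


def fori_loop(lower, upper, body_fn, init_val):
    """Run body_fn(i, val) for i in [lower, upper), threading val through."""

    def cond(i, val):
        return i < upper

    def body(i, val):
        return (i + 1, body_fn(i, val))

    _, result = while_loop(cond, body, (lower, init_val))
    return result


print(fori_loop(0, 5, lambda i, acc: acc + i, 0))  # 10
```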

Bug Fixes and Improvements

  • Fix type promotion for pow #6745
  • Fix vector norm lowering #6883
  • Manually init absl log to avoid log spam #6890
  • Fix pixel_shuffle return empty #6907
  • Make nms fallback to CPU implementation by default #6933
  • Fix torch.full scalar type #7010
  • Handle multiple inplace update input output aliasing #7023
  • Fix overflow for div arguments. #7081
  • Add data-type promotion to gelu_backward and stack #7090, #7091
  • Fix index of 0-element tensor by 0-element tensor #7113
  • Fix output data-type for upsample_bilinear #7168
  • Fix a data-type related problem for mul operation by converting inputs to result type #7130
  • Make clip_grad_norm_ follow input’s dtype #7205