diff --git a/xla/hlo/ir/hlo_instructions.cc b/xla/hlo/ir/hlo_instructions.cc index a9624eafbb6e84..12188e79b30135 100644 --- a/xla/hlo/ir/hlo_instructions.cc +++ b/xla/hlo/ir/hlo_instructions.cc @@ -920,15 +920,6 @@ HloCollectiveInstruction::HloCollectiveInstruction( } } -HloCollectiveInstruction::HloCollectiveInstruction( - HloOpcode opcode, const Shape& shape, - absl::Span operands, - absl::Span replica_groups, bool constrain_layout, - const std::optional& channel_id) - : HloCollectiveInstruction(opcode, shape, operands, - CollectiveDeviceList(replica_groups), - constrain_layout, channel_id) {} - HloInstructionProto HloCollectiveInstruction::ToProto() const { HloInstructionProto proto = HloChannelInstruction::ToProto(); *proto.mutable_collective_device_list() = device_list_.ToProto(); @@ -1040,17 +1031,6 @@ HloAllReduceInstructionBase::HloAllReduceInstructionBase( reduce_computation->SetCollectiveCallInstruction(this); } -HloAllReduceInstructionBase::HloAllReduceInstructionBase( - HloOpcode opcode, const Shape& shape, - absl::Span operands, - HloComputation* reduce_computation, - absl::Span replica_groups, bool constrain_layout, - const std::optional& channel_id, bool use_global_device_ids) - : HloAllReduceInstructionBase(opcode, shape, operands, reduce_computation, - CollectiveDeviceList(replica_groups), - constrain_layout, channel_id, - use_global_device_ids) {} - HloInstructionProto HloAllReduceInstructionBase::ToProto() const { HloInstructionProto proto = HloCollectiveInstruction::ToProto(); proto.set_use_global_device_ids(use_global_device_ids_); diff --git a/xla/hlo/ir/hlo_instructions.h b/xla/hlo/ir/hlo_instructions.h index 2dc884b5d1c390..dbb1beaa529663 100644 --- a/xla/hlo/ir/hlo_instructions.h +++ b/xla/hlo/ir/hlo_instructions.h @@ -645,8 +645,6 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { class HloCollectiveInstruction : public HloChannelInstruction { public: - // TODO(b/316622399): Remove usages of this method and replace with - // device_list()->replica_groups(). const std::vector& replica_groups() const { return device_list_.replica_groups(); } @@ -677,13 +675,6 @@ class HloCollectiveInstruction : public HloChannelInstruction { const CollectiveDeviceList& collective_device_list, bool constrain_layout, const std::optional& channel_id); - ABSL_DEPRECATED("Use CollectiveDeviceList instead of list of ReplicaGroup.") - explicit HloCollectiveInstruction( - HloOpcode opcode, const Shape& shape, - absl::Span operands, - absl::Span replica_groups, bool constrain_layout, - const std::optional& channel_id); - HloInstructionProto ToProto() const override; void PrintExtraAttributesImpl(AttributePrinter& printer, @@ -760,14 +751,6 @@ class HloAllReduceInstructionBase : public HloCollectiveInstruction { const CollectiveDeviceList& device_list, bool constrain_layout, const std::optional& channel_id, bool use_global_device_ids); - ABSL_DEPRECATED("Use CollectiveDeviceList instead of list of ReplicaGroup.") - explicit HloAllReduceInstructionBase( - HloOpcode opcode, const Shape& shape, - absl::Span operands, - HloComputation* reduce_computation, - absl::Span replica_groups, bool constrain_layout, - const std::optional& channel_id, bool use_global_device_ids); - // Returns true if the ids in the ReplicaGroup config represent a global id of // (replica_id * partition_count + partition_id) instead of a replica id. // This enables more flexible grouping of devices if this all-reduce is both