Skip to content

Commit

Permalink
[IFRT] Modify IFRT <-> VIFRT legalization to support escaped SymbolRe…
Browse files Browse the repository at this point in the history
…fAttr.

PiperOrigin-RevId: 700534783
  • Loading branch information
ICGog authored and Google-ML-Automation committed Nov 27, 2024
1 parent 0f6a243 commit 6d52a86
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
26 changes: 24 additions & 2 deletions xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,19 @@ func.func @op_call(
%1, %ctrl_1 = ifrt.Call @add_one::@main(%0) after %ctrl_0 on devices [0,1]
: (!array_op_call) -> !array_op_call

// Verifies that escaped symbol attr is correctly handled.
// CHECK: %[[OUT2:.+]]:2 = "vifrt.CallV1"(%[[ARG0]])
// CHECK-SAME: <{
// CHECK-DAG: callee = "@escaped-module::@main"
// CHECK-DAG: devices = #vifrt<devices_v1[0, 1]>
// CHECK-DAG: donated_input_indices = array<i32>
// CHECK-DAG: io_aliases = []
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%2, %ctrl_2 = ifrt.Call @"escaped-module"::@main(%arg0) on devices [0,1]
: (!array_op_call) -> !array_op_call

// Verifies that the donated input indices attribute is converted.

// CHECK: "vifrt.CallV1"(%[[ARG0]])
Expand All @@ -366,7 +379,7 @@ func.func @op_call(
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%2, %ctrl_2 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]
%3, %ctrl_3 = ifrt.Call @add_one::@main(%arg0) on devices [0,1]
{donated_input_indices=array<i32: 0>} : (!array_op_call) -> !array_op_call

// Verifies that the io_aliases attribute is converted.
Expand All @@ -380,7 +393,7 @@ func.func @op_call(
// CHECK-DAG: operandSegmentSizes = array<i32: 1, 0>
// CHECK-SAME: }>
// CHECK-SAME: (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">) -> (!vifrt.array_v1<tensor<2x2xi32>, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default">, !vifrt.control_v1)
%3, %ctrl_3 = ifrt.Call @add_two::@main(%arg1) on devices [0,1]
%4, %ctrl_4 = ifrt.Call @add_two::@main(%arg1) on devices [0,1]
{io_aliases=[array<i32: 0, 0>]} : (!array_op_call) -> !array_op_call

return %1 : !array_op_call
Expand All @@ -395,6 +408,15 @@ module @add_one attributes {sym_visibility = "private"} {
}
}

// CHECK-NOT @"escaped-module"
module @"escaped-module" attributes {sym_visibility = "private"} {
func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = stablehlo.constant dense<2> : tensor<2x2xi32>
%1 = stablehlo.add %arg0, %0 : tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
}

// CHECK-NOT @add_two
module @add_two attributes {sym_visibility = "private"} {
func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
Expand Down
17 changes: 10 additions & 7 deletions xla/python/ifrt/ir/transforms/ifrt_legalize_to_vifrt_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,17 @@ class IfrtToVifrtOpConverter : public mlir::OpConversionPattern<IfrtOpTy> {
llvm::DenseSet<mlir::StringAttr> already_converted_attrs;
if constexpr (std::is_same<IfrtOpTy, CallOp>::value) {
auto call_op = static_cast<CallOp>(ifrt_op);
// Convert the callee from SymbolRefAttr to SymbolNameAttr so that DCE
// Convert the callee from SymbolRefAttr to StringAttr so that DCE
// can remove the atom programs, which have independently legalized to
// VHLO.
std::string symbol_ref_str;
{
llvm::raw_string_ostream os(symbol_ref_str);
call_op.getCalleeAttr().print(os);
}
// VHLO. Manually to the conversion by merging RootReference and
// NestedReferences to avoid string escaping.
std::string symbol_ref_str = absl::StrCat(
"@", call_op.getCalleeAttr().getRootReference().getValue().str(),
absl::StrJoin(
call_op.getCalleeAttr().getNestedReferences(), "",
[](std::string* out, const mlir::FlatSymbolRefAttr& symbol_ref) {
absl::StrAppend(out, "::@", symbol_ref.getValue().str());
}));
vifrt_attrs.push_back(
{call_op.getCalleeAttrName(),
mlir::StringAttr::get(call_op.getContext(), symbol_ref_str)});
Expand Down
6 changes: 4 additions & 2 deletions xla/python/ifrt/ir/transforms/vifrt_legalize_to_ifrt_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,10 @@ mlir::FailureOr<mlir::SymbolRefAttr> getCalleeSymbolRef(CallOpV1 call_op) {
if (!callee_symbol_ref_str_attr) {
return mlir::failure();
}
std::vector<std::string> symbol_strs =
absl::StrSplit(callee_symbol_ref_str_attr.str(), absl::ByString("::@"));
// It is important to call `getValue()` on the `StringAttr` to get the
// unescaped string instead of the escaped string.
std::vector<std::string> symbol_strs = absl::StrSplit(
callee_symbol_ref_str_attr.getValue().str(), absl::ByString("::@"));
if (symbol_strs.empty()) {
return mlir::failure();
}
Expand Down

0 comments on commit 6d52a86

Please sign in to comment.