From 79ffbf19214738d94e44e45d042f31339b673315 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 16 Jan 2025 09:16:01 -0800 Subject: [PATCH] #sdy support StableHLO from refining Shardy ops with polymorphic shapes PiperOrigin-RevId: 716261489 --- third_party/stablehlo/temporary.patch | 229 ++++++++++ xla/mlir_hlo/BUILD | 5 + .../stablehlo_ext/transforms/CMakeLists.txt | 1 + .../transforms/sdy_refine_shapes.cpp | 412 ++++++++++++++++++ .../transforms/sdy_refine_shapes.h | 32 ++ .../transforms/stablehlo_refine_shapes.cpp | 18 +- .../transforms/stablehlo_refine_shapes.h | 32 ++ .../stablehlo_ext/sdy_refine_shapes.mlir | 238 ++++++++++ .../tools/mlir-hlo-opt/mlir-hlo-opt.cc | 2 + xla/service/spmd/shardy/BUILD | 1 + 10 files changed, 969 insertions(+), 1 deletion(-) create mode 100644 xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.cpp create mode 100644 xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.h create mode 100644 xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.h create mode 100644 xla/mlir_hlo/tests/stablehlo_ext/sdy_refine_shapes.mlir diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 0a91a1ea5e668..2e151cb0db7b5 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -322,6 +322,18 @@ diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/ragged_dot.mlir b/stablehl + ]> : tensor<2x11x7xf32> + func.return +} +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir +@@ -36,7 +36,7 @@ + + // ----- + +-// expected-error @+1 {{number of refinements must match number of function operands 6 vs 1}} ++// expected-error @+1 {{number of refinements must match number of op operands 6 vs 1}} + func.func @refine_arguments_invalid_arg_num_mismatch(%arg0: tensor) { + return + } diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp --- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -593,4 +605,221 @@ diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stable } } // namespace +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp b/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp +@@ -72,78 +72,6 @@ + Type type = mlir::parseType(shape, context); + if (!type) return module->emitOpError("Invalid type string: ") << shape; + refinedTypes.push_back(type); +- } +- return success(); +-} +- +-LogicalResult refinementError(func::FuncOp func, int64_t idx, Type argType, +- Type refinedType, StringRef msg) { +- return func.emitOpError() +- << "invalid refinement for argument " << idx << ", refinement " << msg +- << " in " << mlir::debugString(argType) << " -> " +- << mlir::debugString(refinedType); +-} +- +-// Validates refinement types: +-// - A type refinement must be specified for each operand +-// - Refinement types that match operand types are skipped +-// - Refinement types that do not match operands must be refining tensors +-// - Refined tensor types must be ranked, operand type can be unranked +-// - Refined tensor types must match operand type for all static dimensions +-// +-LogicalResult validateRefinedTypes(func::FuncOp func, TypeRange refinedTypes) { +- // Validate refined shapes +- if (func.getNumArguments() != refinedTypes.size()) { +- return func.emitOpError( +- "number of refinements must match number of function operands ") +- << refinedTypes.size() << " vs " << func.getNumArguments(); +- } +- +- // Validate that refinements are valid +- auto argTypes = func.getArgumentTypes(); +- for (int64_t i = 0; i < func.getNumArguments(); ++i) { +- Type type = argTypes[i]; +- Type refinedType = refinedTypes[i]; +- +- // Always allow skipping refinement +- if (type == refinedType) continue; +- +- // If mismatched, must be tensor types +- auto tensorType = dyn_cast(type); +- auto refinedTensorType = dyn_cast(refinedType); +- if (!tensorType || !refinedTensorType) { +- return refinementError(func, i, type, refinedType, "must be a tensor"); +- } +- +- // Check that element types match +- if (tensorType.getElementType() != refinedTensorType.getElementType()) { +- return refinementError(func, i, type, refinedType, +- "element types must match"); +- } +- +- // Refined rank cannot be unranked if mismatch +- if (isa(refinedType)) { +- return refinementError(func, i, type, refinedType, "must be ranked"); +- } +- +- // Unranked operands can be refined to anything +- if (!tensorType.hasRank()) continue; +- +- // Validate ranks match if ranked (must allow unranked tensorType) +- if (tensorType.getRank() != refinedTensorType.getRank()) { +- return refinementError(func, i, type, refinedType, +- "rank must match operand rank"); +- } +- +- // Validate static dimension sizes match +- for (auto [dimSize, refinedDimSize] : +- llvm::zip(tensorType.getShape(), refinedTensorType.getShape())) { +- if (!ShapedType::isDynamic(dimSize) && dimSize != refinedDimSize) { +- return refinementError( +- func, i, type, refinedType, +- "dimension sizes must match for static dimensions"); +- } +- } + } + return success(); + } +@@ -219,9 +147,74 @@ + + } // namespace + ++LogicalResult refinementError(Operation* op, int64_t idx, Type argType, ++ Type refinedType, StringRef msg) { ++ return op->emitOpError() ++ << "invalid refinement for argument " << idx << ", refinement " << msg ++ << " in " << mlir::debugString(argType) << " -> " ++ << mlir::debugString(refinedType); ++} ++ ++LogicalResult validateRefinedTypes(Operation* op, TypeRange argTypes, TypeRange refinedTypes) { ++ // Validate refined shapes ++ if (argTypes.size() != refinedTypes.size()) { ++ return op->emitOpError( ++ "number of refinements must match number of op operands ") ++ << refinedTypes.size() << " vs " << argTypes.size(); ++ } ++ ++ // Validate that refinements are valid ++ for (int64_t i = 0; i < argTypes.size(); ++i) { ++ Type type = argTypes[i]; ++ Type refinedType = refinedTypes[i]; ++ ++ // Always allow skipping refinement ++ if (type == refinedType) continue; ++ ++ // If mismatched, must be tensor types ++ auto tensorType = dyn_cast(type); ++ auto refinedTensorType = dyn_cast(refinedType); ++ if (!tensorType || !refinedTensorType) { ++ return refinementError(op, i, type, refinedType, "must be a tensor"); ++ } ++ ++ // Check that element types match ++ if (tensorType.getElementType() != refinedTensorType.getElementType()) { ++ return refinementError(op, i, type, refinedType, ++ "element types must match"); ++ } ++ ++ // Refined rank cannot be unranked if mismatch ++ if (isa(refinedType)) { ++ return refinementError(op, i, type, refinedType, "must be ranked"); ++ } ++ ++ // Unranked operands can be refined to anything ++ if (!tensorType.hasRank()) continue; ++ ++ // Validate ranks match if ranked (must allow unranked tensorType) ++ if (tensorType.getRank() != refinedTensorType.getRank()) { ++ return refinementError(op, i, type, refinedType, ++ "rank must match operand rank"); ++ } ++ ++ // Validate static dimension sizes match ++ for (auto [dimSize, refinedDimSize] : ++ llvm::zip(tensorType.getShape(), refinedTensorType.getShape())) { ++ if (!ShapedType::isDynamic(dimSize) && dimSize != refinedDimSize) { ++ return refinementError( ++ op, i, type, refinedType, ++ "dimension sizes must match for static dimensions"); ++ } ++ } ++ } ++ return success(); ++} ++ + LogicalResult refineArguments(func::FuncOp func, TypeRange refinedTypes) { + // Verify that refinements are valid +- if (failed(validateRefinedTypes(func, refinedTypes))) return failure(); ++ if (failed(validateRefinedTypes(func, func.getArgumentTypes(), refinedTypes))) ++ return failure(); + + // Wrap refined operands in operand wrapper to keep IR valid for refinement + wrapRefinedOperands(func, refinedTypes); +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -109,7 +109,13 @@ + // their operands and results. Any operand type in these ops can change + // within what's supported by `inferMostSpecificType` without breaking + // verification of the op. +- if (isa(user->getDialect())) ++ if (isa( ++ user->getDialect())) ++ continue; ++ // TODO(bartchr): Consider if the dialect allow-listing approach is too ++ // strict. In the meantime, allow some shape interop with the shardy ++ // dialect. ++ if (user->getDialect()->getNamespace() == "sdy") + continue; + + // Simply changing operand type of `func.return` won't work because +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h +@@ -16,18 +16,37 @@ + #ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H + #define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H + ++#include ++ ++#include "llvm/ADT/SmallVector.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/IR/Operation.h" + #include "mlir/IR/PatternMatch.h" ++#include "mlir/IR/TypeRange.h" + #include "mlir/IR/Types.h" + #include "mlir/IR/Value.h" ++#include "mlir/IR/ValueRange.h" + #include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LLVM.h" + #include "mlir/Support/LogicalResult.h" + #include "stablehlo/dialect/Base.h" + + namespace mlir { + namespace stablehlo { ++ ++// Emits an error message for invalid refinement. ++LogicalResult refinementError(Operation* op, int64_t idx, Type argType, ++ Type refinedType, StringRef msg); ++ ++// Validates refinement types: ++// - A type refinement must be specified for each operand ++// - Refinement types that match operand types are skipped ++// - Refinement types that do not match operands must be refining tensors ++// - Refined tensor types must be ranked, operand type can be unranked ++// - Refined tensor types must match operand type for all static dimensions ++LogicalResult validateRefinedTypes(Operation* op, TypeRange argTypes, ++ TypeRange refinedTypes); + + // Gets a FuncOp that --stablehlo-refine-shapes will run on. + // Returns a nullptr and emits appropriate errors if such a function cannot diff --git a/xla/mlir_hlo/BUILD b/xla/mlir_hlo/BUILD index 537a645c9e0b4..58bfb6a3dfd12 100644 --- a/xla/mlir_hlo/BUILD +++ b/xla/mlir_hlo/BUILD @@ -979,6 +979,7 @@ cc_binary( "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@shardy//shardy/dialect/sdy/ir:dialect", "@stablehlo//:register", ], ) @@ -1080,11 +1081,14 @@ cc_library( name = "stablehlo_extension_passes", srcs = [ "stablehlo_ext/transforms/chlo_recompose_ops.cpp", + "stablehlo_ext/transforms/sdy_refine_shapes.cpp", "stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp", "stablehlo_ext/transforms/stablehlo_refine_shapes.cpp", ], hdrs = [ "stablehlo_ext/transforms/passes.h", + "stablehlo_ext/transforms/sdy_refine_shapes.h", + "stablehlo_ext/transforms/stablehlo_refine_shapes.h", ], compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", @@ -1100,6 +1104,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@shardy//shardy/dialect/sdy/ir:dialect", "@stablehlo//:base", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/CMakeLists.txt b/xla/mlir_hlo/stablehlo_ext/transforms/CMakeLists.txt index e7cfe812e7a6d..5d8e211740bee 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/CMakeLists.txt +++ b/xla/mlir_hlo/stablehlo_ext/transforms/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_dialect_library(StablehloExtensionPasses chlo_recompose_ops.cpp stablehlo_canonicalize_dynamism.cpp stablehlo_refine_shapes.cpp + sdy_refine_shapes.cpp DEPENDS StablehloExtensionPassesIncGen diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.cpp new file mode 100644 index 0000000000000..8bedd719d79dd --- /dev/null +++ b/xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.cpp @@ -0,0 +1,412 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo_ext/transforms/sdy_refine_shapes.h" + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/Base.h" +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/StablehloRefineShapes.h" +#include "stablehlo_ext/transforms/stablehlo_refine_shapes.h" + +namespace mlir { +namespace stablehlo_ext { + +namespace { + +template +void refineBlockArguments(OpTy regionOp, TypeRange refinedTypes) { + Region& body = regionOp.getBody(); + OpBuilder builder(body); + for (int64_t i = 0; i < body.getNumArguments(); ++i) { + auto arg = body.getArgument(i); + arg.setType(refinedTypes[i]); + } +} + +// Refines the values using the given types. +// +// This is similar to `stablehlo::refineValues`, but the problem is that +// `hlo::inferMostSpecificType` doesn't account for the block argument types +// differing from the operand types due to the body having the local types. +// So to figure out the more specific type, we transform the refinement of +// the operand to the local refinement. +// +// For example: +// +// ``` +// %0 = sdy.manual_computation(%0) +// in_shardings=[<@mesh, [{"x"}, {}]>] +// out_shardings=[<@mesh, [{"x"}, {}]>] +// manual_axes={"x"} (%arg1: tensor<2x?xf32>) { +// ... +// } : (tensor<4x?xf32>) -> tensor<4x?xf32> +// ``` +// +// The global and local types differ for the known static dimension of the +// operand, so we need to convert the global refinement to the local refinement +// to figure out the more specific type. +LogicalResult refineValues( + PatternRewriter& rewriter, sdy::ManualComputationOp manualComputation, + ArrayRef blockArguments, TypeRange types, + sdy::MeshAttr mesh) { + if (blockArguments.size() != types.size()) { + return rewriter.notifyMatchFailure( + manualComputation, [&](Diagnostic& diag) { + diag << "refineValues failed for " << types << ": expected " + << blockArguments.size() << " types, got " << types.size(); + }); + } + + // Check whether `types` contain any new information with respect to + // existing return types. Even if just a single dimension size out of an + // entire tensor type got updated, using `inferMostSpecificType` ensures + // that we don't miss that. + bool needsRefinement = false; + SmallVector refinedTypes; + for (auto it : llvm::zip(blockArguments, types)) { + // Cannot use structured bindings to simplify this because capturing + // structured bindings in a lambda is a C++ 20 extension. + BlockArgument blockArg = std::get<0>(it); + Type blockArgType = blockArg.getType(); + auto refinement = cast(std::get<1>(it)); + // inferMostSpecificType cannot account for the fact that the operand and + // block arg types differ for their known static dimensions due to the body + // having the local types. So to figure out the more specific type, + // transform the refinement of the operand to the local refinement. + sdy::TensorShardingAttr inSharding = eraseFreeAxes( + manualComputation.getInSharding(blockArg.getArgNumber()), + manualComputation.getManualAxes()); + auto refinedType = hlo::inferMostSpecificType( + /*location=*/{}, {blockArgType, inSharding.getLocalTensorType( + refinement, mesh)}); + if (failed(refinedType)) { + return rewriter.notifyMatchFailure(manualComputation, + [&](Diagnostic& diag) { + diag << "inferMostSpecificType failed for " << blockArgType << " and " + << refinement; + }); + } + refinedTypes.push_back(*refinedType); + needsRefinement |= (blockArgType != *refinedType); + } + if (!needsRefinement) + return rewriter.notifyMatchFailure( + manualComputation, "doesn't need refinement"); + + for (auto it : llvm::zip(blockArguments, refinedTypes)) { + // Cannot use structured bindings to simplify this because capturing + // structured bindings in a lambda is a C++ 20 extension. + auto value = std::get<0>(it); + auto refinedType = std::get<1>(it); + if (value.getType() == refinedType) continue; + + // Check whether the users of this value are ready for the type of the + // value to be refined. + for (Operation* user : value.getUsers()) { + // CHLO and StableHLO ops are designed to support type refinements of + // their operands and results. Any operand type in these ops can change + // within what's supported by `inferMostSpecificType` without breaking + // verification of the op. + if (isa( + user->getDialect())) + continue; + + // Simply changing operand type of `func.return` won't work because + // that won't update the FunctionType of the enclosing `func.func`. + if (isa(user)) continue; + if (isa(user)) continue; + + // Unlike in TensorFlow's type inference pass, here we work only with + // allowlisted ops to focus our support on well-defined semantics of + // StableHLO programs. + return rewriter.notifyMatchFailure( + manualComputation, [&](Diagnostic& diag) { + diag << "unsupported refinement: tried to refine " << value.getType() + << " to " << refinedType << " for user " << user; + }); + } + + // Happy path: simply call setType here because most of our users are + // fine with that. + auto unrefinedType = value.getType(); + value.setType(refinedType); + + // Special case: for `func.return`, guard the refinement with a cast + // and leave propagation of the refined return type to a dedicated pattern. + auto isFuncReturn = [](OpOperand& use) -> bool { + return isa(use.getOwner()); + }; + if (llvm::none_of(value.getUses(), isFuncReturn)) continue; + rewriter.setInsertionPointAfter(manualComputation); + auto castToUnrefinedType = rewriter.create( + manualComputation->getLoc(), unrefinedType, value); + value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); + } + + return success(); +} + +LogicalResult refineArguments(sdy::ManualComputationOp manualComputation, + TypeRange refinedTypes, + sdy::MeshAttr mesh, + PatternRewriter& rewriter) { + // Verify that refinements are valid + if (failed(stablehlo::validateRefinedTypes( + manualComputation, manualComputation.getBody().getArgumentTypes(), + refinedTypes))) + return failure(); + + if (failed(refineValues(rewriter, manualComputation, + manualComputation.getBody().getArguments(), + manualComputation.getOperandTypes(), mesh))) { + return failure(); + } + + // Actually update block argument types. + refineBlockArguments(manualComputation, refinedTypes); + + return success(); +} + +LogicalResult refineArguments(sdy::NamedComputationOp namedComputation, + TypeRange refinedTypes, + PatternRewriter& rewriter) { + // Verify that refinements are valid + if (failed(stablehlo::validateRefinedTypes( + namedComputation, namedComputation.getBody().getArgumentTypes(), + refinedTypes))) + return failure(); + + if (failed(stablehlo::refineValues(rewriter, namedComputation, + namedComputation.getBody().getArguments(), + namedComputation.getOperandTypes()))) { + return failure(); + } + + // Actually update block argument types. + refineBlockArguments(namedComputation, refinedTypes); + + return success(); +} + + +LogicalResult refineManualComputationBody( + sdy::ManualComputationOp manualComputation, PatternRewriter& rewriter); + +struct RefineManualComputationOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(sdy::ManualComputationOp op, + PatternRewriter& rewriter) const override { + return refineManualComputationBody(op, rewriter); + } +}; + +// Applies shape refinement patterns to `ManualComputationOp`. +// NOTE: This will only work when the program is inlined. Refining `CallOp`s +// nested in `ManualComputationOp`s would require keeping track of scope, which +// requires more refactoring of the base pass. Likely requiring exposing +// `RefinementState` and a method for invoking all patterns available on +// the `ManualComputationOp` body. +template +LogicalResult applyShapeRefinementPatterns(OpTy regionOp) { + MLIRContext* context = regionOp.getContext(); + RewritePatternSet patterns(context); + GreedyRewriteConfig config; + + // The algorithm behind this pass consists of a single traversal of the + // function. This is sufficient because we only support one function per + // program at the moment. + // TODO(#1048): Find out why .maxIterations = 1 no longer works. + // There have been recent refactors to applyPatternsGreedily + // upstream, and that might be the reason. + config.useTopDownTraversal = true; + config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + config.maxIterations = 2; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; + + populateStablehloExtRefineShapesPatterns(&patterns, context); + patterns.add(context); + + // The folding patterns implement partial evaluation of shape computations + // which is a critical part of implementing type refinement for ops like + // dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape + // depends on the value of their shape operands. + stablehlo::populateStablehloShapeFolderPatterns(&patterns, context); + + if (failed(applyPatternsGreedily(regionOp, std::move(patterns), config))) + regionOp.emitError("Failed to converge StablehloRefineShapes in ") + << config.maxIterations << " iterations"; + + return success(); +} + +// Manual computations need to be refined in-order like function calls +// since the output shape depends on the shape of the return in the +// ops body region, with some transformation based on the mesh, out shardings, +// and manual axes. +// +// For example, if the result is `tensor`, the op is manual on axis +// `x=2`, and the resut has sharding `<@mesh, [{"x"}]>`, then if the local type +// in the body region is `tensor<4xf32>`, the global type is `tensor<8xf32>`. +// And so we need to update the global result type of the manual computation +// to be `tensor<8xf32>` to reflect the actual shape of the result. +LogicalResult refineManualComputationBody( + sdy::ManualComputationOp manualComputation, PatternRewriter& rewriter) { + rewriter.setInsertionPointToStart(&manualComputation.getRegion().front()); + + SymbolTable symbolTable(manualComputation->getParentOfType()); + ArrayRef manualAxes = manualComputation.getManualAxes(); + sdy::MeshAttr mesh = sdy::getCommonMesh( + manualComputation.getInShardings().getShardings(), + manualComputation.getOutShardings().getShardings(), symbolTable); + + // Convert the global types to local types using the sharding consisting only + // of manual axes. + SmallVector localBlockArgTypes; + localBlockArgTypes.reserve(manualComputation.getNumOperands()); + for (auto [arg, globalType, inSharding] : + llvm::zip_equal(manualComputation.getBody().getArguments(), + manualComputation->getOperandTypes(), + manualComputation.getInShardings().getShardings())) { + localBlockArgTypes.push_back( + sdy::eraseFreeAxes(inSharding, manualAxes) + .getLocalTensorType(cast(globalType), mesh)); + } + + if (failed(refineArguments(manualComputation, localBlockArgTypes, + mesh, rewriter))) + return failure(); + + // Now iterate into the function body and apply refinement patterns. + if (failed(applyShapeRefinementPatterns(manualComputation))) return failure(); + + // Convert the local types to global types using the sharding consisting only + // of manual axes. + SmallVector globalResultTypes; + globalResultTypes.reserve(manualComputation.getNumResults()); + for (auto [localType, sharding] : + llvm::zip_equal(sdy::getBodyTerminatorOpOperandTypes(manualComputation), + manualComputation.getOutShardings().getShardings())) { + globalResultTypes.push_back( + sdy::eraseFreeAxes(sharding, manualAxes) + .getGlobalTensorType(cast(localType), mesh)); + } + + return stablehlo::refineReturnTypes(rewriter, manualComputation, + globalResultTypes); +} + +LogicalResult refineNamedComputationOpPattern( + sdy::NamedComputationOp namedComputation, PatternRewriter& rewriter); + +struct RefineNamedComputationOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(sdy::NamedComputationOp op, + PatternRewriter& rewriter) const override { + return refineNamedComputationOpPattern(op, rewriter); + } +}; + +LogicalResult refineNamedComputationOpPattern( + sdy::NamedComputationOp namedComputation, PatternRewriter& rewriter) { + rewriter.setInsertionPointToStart(&namedComputation.getRegion().front()); + + SymbolTable symbolTable(namedComputation->getParentOfType()); + + if (failed(refineArguments(namedComputation, + namedComputation.getOperandTypes(), rewriter))) + return failure(); + + // Now iterate into the function body and apply refinement patterns. + if (failed(applyShapeRefinementPatterns(namedComputation))) return failure(); + + // TODO(bartchr): Should be able to call `getBodyTerminatorOpOperandTypes` + // but getting a refined template compilation error. + return stablehlo::refineReturnTypes( + rewriter, namedComputation, + llvm::to_vector(namedComputation.getBody().front().getTerminator() + ->getOperandTypes())); +} + +struct RefineInferTypeOpInterfacePattern + : public OpInterfaceRewritePattern { + explicit RefineInferTypeOpInterfacePattern(MLIRContext* context) + : OpInterfaceRewritePattern(context, /*benefit=*/0) {} + LogicalResult matchAndRewrite(InferTypeOpInterface op, + PatternRewriter& rewriter) const override { + // Unlike in TensorFlow's type inference pass, here we work only with + // allowlisted ops to focus our support on well-defined semantics of + // StableHLO programs. + if (!isa( + op->getDialect())) + return rewriter.notifyMatchFailure(op, "unsupported dialect"); + + // For the ops that implement InferTypeOpInterface, we reinfer their return + // types and see what happens. + // Operands of these ops might have been refined elsewhere (e.g. someone + // might have updated argument types of a function) or earlier during this + // pass, and this might enable refinement opportunities downstream. + SmallVector inferredReturnTypes; + if (failed(op.inferReturnTypes(getContext(), /*location=*/{}, + op->getOperands(), op->getAttrDictionary(), + op->getPropertiesStorage(), op->getRegions(), + inferredReturnTypes))) + return rewriter.notifyMatchFailure(op, "inferReturnTypes failed"); + return stablehlo::refineReturnTypes(rewriter, op, inferredReturnTypes); + } +}; + +} // namespace + +/// Patterns for refining shapes of Shardy ops. +void populateSdyShapeRefinementPatterns(RewritePatternSet* patterns, + MLIRContext* context) { + patterns->add(context); + patterns->add(context); + patterns->add(context); +} + +} // namespace stablehlo_ext +} // namespace mlir diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.h b/xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.h new file mode 100644 index 0000000000000..12a4b35fa362c --- /dev/null +++ b/xla/mlir_hlo/stablehlo_ext/transforms/sdy_refine_shapes.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_MLIR_HLO_STABLEHLO_EXT_TRANSFORMS_SDY_REFINE_SHAPES_H_ +#define XLA_MLIR_HLO_STABLEHLO_EXT_TRANSFORMS_SDY_REFINE_SHAPES_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace stablehlo_ext { + +/// Populates extension patterns for refining shapes of Shardy ops. +void populateSdyShapeRefinementPatterns(RewritePatternSet* patterns, + MLIRContext* context); + +} // namespace stablehlo_ext +} // namespace mlir + +#endif // XLA_MLIR_HLO_STABLEHLO_EXT_TRANSFORMS_SDY_REFINE_SHAPES_H_ diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp index 37effdeadd65a..7d75ad975a23a 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp +++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp @@ -12,16 +12,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "stablehlo_ext/transforms/stablehlo_refine_shapes.h" + #include #include #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/TypeInference.h" @@ -30,6 +33,7 @@ limitations under the License. #include "stablehlo_ext/IR/base.h" #include "stablehlo_ext/IR/stablehlo_ops.h" #include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc +#include "stablehlo_ext/transforms/sdy_refine_shapes.h" namespace mlir { namespace stablehlo_ext { @@ -149,6 +153,7 @@ struct StablehloRefineShapesPass patterns->add(context); patterns->add(context); patterns->add(context); + populateSdyShapeRefinementPatterns(patterns, context); }; if (failed(stablehlo::refineEntryFunction(*context, func, @@ -158,5 +163,16 @@ struct StablehloRefineShapesPass }; } // namespace + +void populateStablehloExtRefineShapesPatterns(RewritePatternSet *patterns, + MLIRContext *context) { + stablehlo::populateStablehloRefineShapesPatterns(patterns, context); + stablehlo::populateStablehloShapeFolderPatterns(patterns, context); + patterns->add(context); + patterns->add(context); + patterns->add(context); + populateSdyShapeRefinementPatterns(patterns, context); +} + } // namespace stablehlo_ext } // namespace mlir diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.h b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.h new file mode 100644 index 0000000000000..8d2196bafff68 --- /dev/null +++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_MLIR_HLO_STABLEHLO_EXT_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H_ +#define XLA_MLIR_HLO_STABLEHLO_EXT_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace stablehlo_ext { + +// Populates extension patterns for refining shapes of StableHLO and Shardy ops. +void populateStablehloExtRefineShapesPatterns(RewritePatternSet *patterns, + MLIRContext *context); + +} // namespace stablehlo_ext +} // namespace mlir + +#endif // XLA_MLIR_HLO_STABLEHLO_EXT_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H_ diff --git a/xla/mlir_hlo/tests/stablehlo_ext/sdy_refine_shapes.mlir b/xla/mlir_hlo/tests/stablehlo_ext/sdy_refine_shapes.mlir new file mode 100644 index 0000000000000..f9fc6014c47b2 --- /dev/null +++ b/xla/mlir_hlo/tests/stablehlo_ext/sdy_refine_shapes.mlir @@ -0,0 +1,238 @@ +// RUN: mlir-hlo-opt %s -stablehlo-ext-refine-shapes --split-input-file 2>&1 | FileCheck %s + +// Only the operand is manual. + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<16x32xf32>) -> tensor<8x32xf32> +func.func @main(%arg0: tensor<16x32xf32>) -> tensor { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> + // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%0) + // CHECK-SAME: in_shardings=[<@mesh, [{"a", ?}, {?}]>] + // CHECK-SAME: out_shardings=[<@mesh, [{?}, {?}], replicated={"a"}>] + // CHECK-SAME: manual_axes={"a"} (%arg1: tensor<8x32xf32>) { + // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %arg1, %arg1 : tensor<8x32xf32> + // CHECK-NEXT: sdy.return %[[ADD_2]] : tensor<8x32xf32> + // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: return %[[MC]] : tensor<8x32xf32> + %0 = stablehlo.add %arg0, %arg0 : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor + %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{"a", ?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}], replicated={"a"}>] manual_axes={"a"} (%arg1: tensor) { + %2 = stablehlo.add %arg1, %arg1 : tensor + sdy.return %2 : tensor + } : (tensor) -> tensor + return %1: tensor +} + +// ----- + +// Only the result is manual. + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<16x32xf32>) -> tensor<32x32xf32> +func.func @main(%arg0: tensor<16x32xf32>) -> tensor { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> + // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%0) + // CHECK-SAME: in_shardings=[<@mesh, [{?}, {?}], replicated={"a"}>] + // CHECK-SAME: out_shardings=[<@mesh, [{"a", ?}, {?}]>] + // CHECK-SAME: manual_axes={"a"} (%arg1: tensor<16x32xf32>) { + // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %arg1, %arg1 : tensor<16x32xf32> + // CHECK-NEXT: sdy.return %[[ADD_2]] : tensor<16x32xf32> + // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<32x32xf32> + // CHECK-NEXT: return %[[MC]] : tensor<32x32xf32> + %0 = stablehlo.add %arg0, %arg0 : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor + %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{?}, {?}], replicated={"a"}>] out_shardings=[<@mesh, [{"a", ?}, {?}]>] manual_axes={"a"} (%arg1: tensor) { + %2 = stablehlo.add %arg1, %arg1 : tensor + sdy.return %2 : tensor + } : (tensor) -> tensor + return %1: tensor +} + +// ----- + +// Both operand and result are manual. + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> +func.func @main(%arg0: tensor<16x32xf32>) -> tensor { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> + // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%0) + // CHECK-SAME: in_shardings=[<@mesh, [{"a", ?}, {?}]>] + // CHECK-SAME: out_shardings=[<@mesh, [{"a", ?}, {?}]>] + // CHECK-SAME: manual_axes={"a"} (%arg1: tensor<8x32xf32>) { + // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %arg1, %arg1 : tensor<8x32xf32> + // CHECK-NEXT: sdy.return %[[ADD_2]] : tensor<8x32xf32> + // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> + // CHECK-NEXT: return %[[MC]] : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor + %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{"a", ?}, {?}]>] out_shardings=[<@mesh, [{"a", ?}, {?}]>] manual_axes={"a"} (%arg1: tensor) { + %2 = stablehlo.add %arg1, %arg1 : tensor + sdy.return %2 : tensor + } : (tensor) -> tensor + return %1: tensor +} + +// ----- + +// The dimension being refined is not the one which is manually sharded. + +sdy.mesh @mesh = <["x"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> +func.func @main(%arg0: tensor<4x4xf32>) -> tensor<4x?xf32> { + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0 : tensor<4x4xf32> + // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%[[ABS]]) + // CHECK-SAME: in_shardings=[<@mesh, [{"x"}, {}]>] + // CHECK-SAME: out_shardings=[<@mesh, [{"x"}, {}]>] + // CHECK-SAME: manual_axes={"x"} (%arg1: tensor<2x4xf32>) { + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %arg1 : tensor<2x4xf32> + // CHECK-NEXT: sdy.return %[[ADD]] : tensor<2x4xf32> + // CHECK-NEXT: } : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: return %[[MC]] : tensor<4x4xf32> + %9 = stablehlo.abs %arg0 : (tensor<4x4xf32>) -> tensor<4x?xf32> + %0 = sdy.manual_computation(%9) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<2x?xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<2x?xf32> + sdy.return %1 : tensor<2x?xf32> + } : (tensor<4x?xf32>) -> tensor<4x?xf32> + return %0 : tensor<4x?xf32> +} + +// ----- + +// Body of named computation has all SDY operations. + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> +func.func @main(%arg0: tensor<16x32xf32>) -> tensor { + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0 : tensor<16x32xf32> + // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%[[ABS]]) (%arg1: tensor<16x32xf32>) { + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{"b"}, {?}]> : tensor<16x32xf32> + // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[SC]] <@mesh, [{"b", "a"}, {?}]> : tensor<16x32xf32> + // CHECK-NEXT: sdy.sharding_group %[[RESHARD]] group_id=0 : tensor<16x32xf32> + // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<16x32xf32> + // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> + // CHECK-NEXT: return %[[NC]] : tensor<16x32xf32> + %0 = stablehlo.abs %arg0 : (tensor<16x32xf32>) -> tensor + %1 = sdy.named_computation<"foo">(%0) (%arg1: tensor) { + %2 = sdy.sharding_constraint %arg1 <@mesh, [{"b"}, {?}]> : tensor + %3 = sdy.reshard %2 <@mesh, [{"b", "a"}, {?}]> : tensor + sdy.sharding_group %3 group_id=0 : tensor + sdy.return %3 : tensor + } : (tensor) -> tensor + return %1: tensor +} + +// ----- + +// Body of manual computation has all SDY operations. + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> +func.func @main(%arg0: tensor<16x32xf32>) -> tensor { + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0 : tensor<16x32xf32> + // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%0) + // CHECK-SAME: in_shardings=[<@mesh, [{?}, {?}]>] + // CHECK-SAME: out_shardings=[<@mesh, [{?}, {?}]>] + // CHECK-SAME: manual_axes={} (%arg1: tensor<16x32xf32>) { + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{"b"}, {?}]> : tensor<16x32xf32> + // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[SC]] <@mesh, [{"b", "a"}, {?}]> : tensor<16x32xf32> + // CHECK-NEXT: sdy.sharding_group %[[RESHARD]] group_id=0 : tensor<16x32xf32> + // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<16x32xf32> + // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> + // CHECK-NEXT: return %[[MC]] : tensor<16x32xf32> + %0 = stablehlo.abs %arg0 : (tensor<16x32xf32>) -> tensor + %1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor) { + %2 = sdy.sharding_constraint %arg1 <@mesh, [{"b"}, {?}]> : tensor + %3 = sdy.reshard %2 <@mesh, [{"b", "a"}, {?}]> : tensor + sdy.sharding_group %3 group_id=0 : tensor + sdy.return %3 : tensor + } : (tensor) -> tensor + return %1: tensor +} + +// ----- + +// Body of the manual computation has a call. +// TODO(b/385323320): the function is not being fully refined due to the call. + +sdy.mesh @mesh = <["a"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<4xf32>) -> tensor +func.func @main(%arg0: tensor<4xf32>) -> (tensor) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant dense<4> : tensor + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%[[CONVERT]], %[[C]]) + // CHECK-SAME: in_shardings=[<@mesh, [{"a"}]>, <@mesh, []>] + // CHECK-SAME: out_shardings=[<@mesh, [{?}], replicated={"a"}>] + // CHECK-SAME: manual_axes={"a"} (%arg1: tensor<2xf32>, %arg2: tensor) { + // CHECK-NEXT: %[[CALL:.*]] = func.call @refine_call_callee(%arg2, %arg1) : (tensor, tensor<2xf32>) -> tensor + // CHECK-NEXT: sdy.return %[[CALL:.*]] : tensor + // CHECK-NEXT: } : (tensor<4xf32>, tensor) -> tensor + // CHECK-NEXT: return %[[MC]] : tensor + %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor + %1 = stablehlo.constant dense<4> : tensor + %2 = sdy.manual_computation(%0, %1) in_shardings=[<@mesh, [{"a"}]>, <@mesh, []>] out_shardings=[<@mesh, [{?}], replicated={"a"}>] manual_axes={"a"} (%arg1: tensor, %arg2: tensor) { + %3 = func.call @refine_call_callee(%arg2, %arg1) : (tensor, tensor) -> tensor + sdy.return %3 : tensor + } : (tensor, tensor) -> tensor + return %2: tensor +} + +// CHECK-LABEL: func @refine_call_callee +// CHECK-SAME: (%arg0: tensor, %arg1: tensor<2xf32>) -> tensor +func.func @refine_call_callee(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + // CHECK-NEXT: %[[IOTA:.*]] = stablehlo.dynamic_iota %[[RESHAPE]], dim = 0 : (tensor<1xi32>) -> tensor + // CHECK-NEXT: return %[[IOTA]] : tensor + %0 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi32>) -> tensor + return %1 : tensor +} + + +// ----- + +// Body of the named computation has a call. +// TODO(b/385323320): the function is not being fully refined due to the call. + +sdy.mesh @mesh = <["a"=2]> + +// CHECK-LABEL: func @main +// CHECK-SAME: (%arg0: tensor<4xf32>) -> tensor +func.func @main(%arg0: tensor<4xf32>) -> (tensor) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant dense<4> : tensor + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%[[CONVERT]], %[[C]]) (%arg1: tensor<4xf32>, %arg2: tensor) { + // CHECK-NEXT: %[[CALL:.*]] = func.call @refine_call_callee(%arg2, %arg1) : (tensor, tensor<4xf32>) -> tensor + // CHECK-NEXT: sdy.return %[[CALL:.*]] : tensor + // CHECK-NEXT: } : (tensor<4xf32>, tensor) -> tensor + // CHECK-NEXT: return %[[NC]] : tensor + %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor + %1 = stablehlo.constant dense<4> : tensor + %2 = sdy.named_computation<"foo">(%0, %1) (%arg1: tensor, %arg2: tensor) { + %3 = func.call @refine_call_callee(%arg2, %arg1) : (tensor, tensor) -> tensor + sdy.return %3 : tensor + } : (tensor, tensor) -> tensor + return %2: tensor +} + +// CHECK-LABEL: func @refine_call_callee +// CHECK-SAME: (%arg0: tensor, %arg1: tensor<4xf32>) -> tensor +func.func @refine_call_callee(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + // CHECK-NEXT: %[[IOTA:.*]] = stablehlo.dynamic_iota %[[RESHAPE]], dim = 0 : (tensor<1xi32>) -> tensor + // CHECK-NEXT: return %[[IOTA]] : tensor + %0 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi32>) -> tensor + return %1 : tensor +} diff --git a/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc index b447485075dc3..f018cbdcaf772 100644 --- a/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc +++ b/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "shardy/dialect/sdy/ir/dialect.h" #include "stablehlo/dialect/Register.h" #include "stablehlo_ext/transforms/passes.h" #include "transforms/gpu_passes.h" @@ -40,5 +41,6 @@ int main(int argc, char** argv) { registerAllExtensions(registry); mhlo::registerAllMhloDialects(registry); stablehlo::registerAllDialects(registry); + registry.insert(); return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); } diff --git a/xla/service/spmd/shardy/BUILD b/xla/service/spmd/shardy/BUILD index 5ed5456e3b504..658bd9dd6472d 100644 --- a/xla/service/spmd/shardy/BUILD +++ b/xla/service/spmd/shardy/BUILD @@ -121,6 +121,7 @@ xla_cc_binary( deps = [ "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", + "//xla/mlir_hlo:stablehlo_extension_passes", "//xla/service/spmd/shardy/mhlo_round_trip:export_callback_custom_calls", "//xla/service/spmd/shardy/mhlo_round_trip:export_ops", "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings",