#sdy Support StableHLO refining Shardy ops with polymorphic shapes
PiperOrigin-RevId: 716261489
bartchr808 authored and Google-ML-Automation committed Jan 16, 2025
1 parent 1e94534 commit 79ffbf1
Showing 10 changed files with 969 additions and 1 deletion.
229 changes: 229 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -322,6 +322,18 @@ diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/ragged_dot.mlir b/stablehl
+ ]> : tensor<2x11x7xf32>
+ func.return
+}
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir
@@ -36,7 +36,7 @@

// -----

-// expected-error @+1 {{number of refinements must match number of function operands 6 vs 1}}
+// expected-error @+1 {{number of refinements must match number of op operands 6 vs 1}}
func.func @refine_arguments_invalid_arg_num_mismatch(%arg0: tensor<f32>) {
return
}
diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
+++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
@@ -593,4 +605,221 @@ diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stable
}
} // namespace

diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp b/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
@@ -72,78 +72,6 @@
Type type = mlir::parseType(shape, context);
if (!type) return module->emitOpError("Invalid type string: ") << shape;
refinedTypes.push_back(type);
- }
- return success();
-}
-
-LogicalResult refinementError(func::FuncOp func, int64_t idx, Type argType,
- Type refinedType, StringRef msg) {
- return func.emitOpError()
- << "invalid refinement for argument " << idx << ", refinement " << msg
- << " in " << mlir::debugString(argType) << " -> "
- << mlir::debugString(refinedType);
-}
-
-// Validates refinement types:
-// - A type refinement must be specified for each operand
-// - Refinement types that match operand types are skipped
-// - Refinement types that do not match operands must be refining tensors
-// - Refined tensor types must be ranked, operand type can be unranked
-// - Refined tensor types must match operand type for all static dimensions
-//
-LogicalResult validateRefinedTypes(func::FuncOp func, TypeRange refinedTypes) {
- // Validate refined shapes
- if (func.getNumArguments() != refinedTypes.size()) {
- return func.emitOpError(
- "number of refinements must match number of function operands ")
- << refinedTypes.size() << " vs " << func.getNumArguments();
- }
-
- // Validate that refinements are valid
- auto argTypes = func.getArgumentTypes();
- for (int64_t i = 0; i < func.getNumArguments(); ++i) {
- Type type = argTypes[i];
- Type refinedType = refinedTypes[i];
-
- // Always allow skipping refinement
- if (type == refinedType) continue;
-
- // If mismatched, must be tensor types
- auto tensorType = dyn_cast<TensorType>(type);
- auto refinedTensorType = dyn_cast<TensorType>(refinedType);
- if (!tensorType || !refinedTensorType) {
- return refinementError(func, i, type, refinedType, "must be a tensor");
- }
-
- // Check that element types match
- if (tensorType.getElementType() != refinedTensorType.getElementType()) {
- return refinementError(func, i, type, refinedType,
- "element types must match");
- }
-
- // Refined rank cannot be unranked if mismatch
- if (isa<UnrankedTensorType>(refinedType)) {
- return refinementError(func, i, type, refinedType, "must be ranked");
- }
-
- // Unranked operands can be refined to anything
- if (!tensorType.hasRank()) continue;
-
- // Validate ranks match if ranked (must allow unranked tensorType)
- if (tensorType.getRank() != refinedTensorType.getRank()) {
- return refinementError(func, i, type, refinedType,
- "rank must match operand rank");
- }
-
- // Validate static dimension sizes match
- for (auto [dimSize, refinedDimSize] :
- llvm::zip(tensorType.getShape(), refinedTensorType.getShape())) {
- if (!ShapedType::isDynamic(dimSize) && dimSize != refinedDimSize) {
- return refinementError(
- func, i, type, refinedType,
- "dimension sizes must match for static dimensions");
- }
- }
}
return success();
}
@@ -219,9 +147,74 @@

} // namespace

+LogicalResult refinementError(Operation* op, int64_t idx, Type argType,
+ Type refinedType, StringRef msg) {
+ return op->emitOpError()
+ << "invalid refinement for argument " << idx << ", refinement " << msg
+ << " in " << mlir::debugString(argType) << " -> "
+ << mlir::debugString(refinedType);
+}
+
+LogicalResult validateRefinedTypes(Operation* op, TypeRange argTypes, TypeRange refinedTypes) {
+ // Validate refined shapes
+ if (argTypes.size() != refinedTypes.size()) {
+ return op->emitOpError(
+ "number of refinements must match number of op operands ")
+ << refinedTypes.size() << " vs " << argTypes.size();
+ }
+
+ // Validate that refinements are valid
+ for (int64_t i = 0; i < argTypes.size(); ++i) {
+ Type type = argTypes[i];
+ Type refinedType = refinedTypes[i];
+
+ // Always allow skipping refinement
+ if (type == refinedType) continue;
+
+ // If mismatched, must be tensor types
+ auto tensorType = dyn_cast<TensorType>(type);
+ auto refinedTensorType = dyn_cast<TensorType>(refinedType);
+ if (!tensorType || !refinedTensorType) {
+ return refinementError(op, i, type, refinedType, "must be a tensor");
+ }
+
+ // Check that element types match
+ if (tensorType.getElementType() != refinedTensorType.getElementType()) {
+ return refinementError(op, i, type, refinedType,
+ "element types must match");
+ }
+
+ // Refined rank cannot be unranked if mismatch
+ if (isa<UnrankedTensorType>(refinedType)) {
+ return refinementError(op, i, type, refinedType, "must be ranked");
+ }
+
+ // Unranked operands can be refined to anything
+ if (!tensorType.hasRank()) continue;
+
+ // Validate ranks match if ranked (must allow unranked tensorType)
+ if (tensorType.getRank() != refinedTensorType.getRank()) {
+ return refinementError(op, i, type, refinedType,
+ "rank must match operand rank");
+ }
+
+ // Validate static dimension sizes match
+ for (auto [dimSize, refinedDimSize] :
+ llvm::zip(tensorType.getShape(), refinedTensorType.getShape())) {
+ if (!ShapedType::isDynamic(dimSize) && dimSize != refinedDimSize) {
+ return refinementError(
+ op, i, type, refinedType,
+ "dimension sizes must match for static dimensions");
+ }
+ }
+ }
+ return success();
+}
+
LogicalResult refineArguments(func::FuncOp func, TypeRange refinedTypes) {
// Verify that refinements are valid
- if (failed(validateRefinedTypes(func, refinedTypes))) return failure();
+ if (failed(validateRefinedTypes(func, func.getArgumentTypes(), refinedTypes)))
+ return failure();

// Wrap refined operands in operand wrapper to keep IR valid for refinement
wrapRefinedOperands(func, refinedTypes);
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -109,7 +109,13 @@
// their operands and results. Any operand type in these ops can change
// within what's supported by `inferMostSpecificType` without breaking
// verification of the op.
- if (isa<chlo::ChloDialect, StablehloDialect>(user->getDialect()))
+ if (isa<chlo::ChloDialect, StablehloDialect>(
+ user->getDialect()))
+ continue;
+ // TODO(bartchr): Consider if the dialect allow-listing approach is too
+ // strict. In the meantime, allow some shape interop with the shardy
+ // dialect.
+ if (user->getDialect()->getNamespace() == "sdy")
continue;

// Simply changing operand type of `func.return` won't work because
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h
@@ -16,18 +16,37 @@
#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H
#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H

+#include <cstdint>
+
+#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/Base.h"

namespace mlir {
namespace stablehlo {
+
+// Emits an error message for invalid refinement.
+LogicalResult refinementError(Operation* op, int64_t idx, Type argType,
+ Type refinedType, StringRef msg);
+
+// Validates refinement types:
+// - A type refinement must be specified for each operand
+// - Refinement types that match operand types are skipped
+// - Refinement types that do not match operands must be refining tensors
+// - Refined tensor types must be ranked, operand type can be unranked
+// - Refined tensor types must match operand type for all static dimensions
+LogicalResult validateRefinedTypes(Operation* op, TypeRange argTypes,
+ TypeRange refinedTypes);

// Gets a FuncOp that --stablehlo-refine-shapes will run on.
// Returns a nullptr and emits appropriate errors if such a function cannot

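For context on the helpers promoted into StablehloRefineShapes.h above, here is a minimal sketch of how a downstream refinement pattern, for example one handling an sdy-style op, could reuse validateRefinedTypes before changing any types. The function name refineOperandsOfOp and the include path are assumptions for illustration only; they are not part of this commit.

// Hypothetical sketch reusing the validateRefinedTypes helper exported by
// this patch. refineOperandsOfOp and the include path are assumed names.
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/transforms/StablehloRefineShapes.h"

namespace mlir {
namespace stablehlo {

LogicalResult refineOperandsOfOp(Operation* op, TypeRange refinedTypes) {
  // Rejects mismatched operand counts, non-tensor refinements, element-type
  // changes, unranked refinements, rank mismatches, and changed static
  // dimensions, mirroring the rules documented on validateRefinedTypes.
  if (failed(validateRefinedTypes(op, op->getOperandTypes(), refinedTypes)))
    return failure();
  // A real pass would now rewrite the operand/result types, e.g. by wrapping
  // operands the way refineArguments does for func arguments.
  return success();
}

}  // namespace stablehlo
}  // namespace mlir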
5 changes: 5 additions & 0 deletions xla/mlir_hlo/BUILD
@@ -979,6 +979,7 @@ cc_binary(
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:register",
],
)
@@ -1080,11 +1081,14 @@ cc_library(
name = "stablehlo_extension_passes",
srcs = [
"stablehlo_ext/transforms/chlo_recompose_ops.cpp",
"stablehlo_ext/transforms/sdy_refine_shapes.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
"stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
],
hdrs = [
"stablehlo_ext/transforms/passes.h",
"stablehlo_ext/transforms/sdy_refine_shapes.h",
"stablehlo_ext/transforms/stablehlo_refine_shapes.h",
],
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
@@ -1100,6 +1104,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:base",
"@stablehlo//:chlo_ops",
"@stablehlo//:stablehlo_ops",
1 change: 1 addition & 0 deletions xla/mlir_hlo/stablehlo_ext/transforms/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_dialect_library(StablehloExtensionPasses
chlo_recompose_ops.cpp
stablehlo_canonicalize_dynamism.cpp
stablehlo_refine_shapes.cpp
sdy_refine_shapes.cpp

DEPENDS
StablehloExtensionPassesIncGen