Skip to content

Commit

Permalink
Share StableHLO/MHLO pretty printers for ConstantOp
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625273473
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Apr 16, 2024
1 parent 38f5ed1 commit 5c3f217
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 39 deletions.
123 changes: 123 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,80 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists
add_subdirectory(integrations)
add_subdirectory(reference)
add_subdirectory(tests)
diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp
--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp
+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp
@@ -15,6 +15,7 @@

#include "stablehlo/dialect/AssemblyFormat.h"

+#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
@@ -130,6 +131,42 @@
for (Type& t : opTypes) typePtrs.push_back(&t);

return detail::parseSameOperandsAndResultTypeImpl(parser, typePtrs, result);
+}
+
+void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value) {
+ assert(op->getNumResults() == 1);
+ // If not all types are the same, use generic form.
+ if (value.getType() != op->getResultTypes().front()) {
+ p.printGenericOp(op, /*printOpName=*/false);
+ return;
+ }
+
+ p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
+ p << ' ';
+ p.printStrippedAttrOrType(value);
+}
+
+ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) {
+ // Parse the generic form.
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (parser.parseRParen()) return failure();
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
+ parser.parseArrow())
+ return failure();
+ Type resultTy;
+ if (parser.parseType(resultTy)) return failure();
+ result.addTypes(resultTy);
+ return success();
+ }
+
+ ElementsAttr valueAttr;
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value",
+ result.attributes))
+ return failure();
+ result.addTypes(valueAttr.getType());
+ return success();
}

void printTupleOpType(OpAsmPrinter& p, Operation*, TypeRange operands,
diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h
--- stablehlo/stablehlo/dialect/AssemblyFormat.h
+++ stablehlo/stablehlo/dialect/AssemblyFormat.h
@@ -101,6 +101,16 @@
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands,
SmallVectorImpl<Type>& opTypes, Type& result);

+// Print a `constant` op.
+//
+// op ::= attr-dict $value
+//
+// When the `value` and `output` have different type, it just uses the default
+// operator assembly format as a fallback.
+void printConstantOp(OpAsmPrinter& p, Operation* op, ElementsAttr value);
+
+ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result);
+
// TuplesOp - only print result type. Operand type is trivially inferrable.
//
// Inferring operand types from tuple type:
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
--- stablehlo/stablehlo/dialect/StablehloOps.cpp
+++ stablehlo/stablehlo/dialect/StablehloOps.cpp
Expand All @@ -194,6 +268,55 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
assert(adaptor.getOperands().empty() && "constant has no operands");

@@ -311,44 +321,11 @@
}

ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
- // Parse the generic form.
- if (succeeded(parser.parseOptionalLParen())) {
- if (parser.parseRParen()) return failure();
- if (parser.parseOptionalAttrDict(result.attributes)) return failure();
- if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
- parser.parseArrow())
- return failure();
- Type resultTy;
- if (parser.parseType(resultTy)) return failure();
- result.addTypes(resultTy);
- return success();
- }
-
- ElementsAttr valueAttr;
- if (parser.parseOptionalAttrDict(result.attributes)) return failure();
- if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value",
- result.attributes))
- return failure();
- result.addTypes(valueAttr.getType());
- return success();
-}
-
-/// Print a `constant` op.
-///
-/// op ::= attr-dict $value
-///
-/// When the `value` and `output` have different type, it just uses the default
-/// operator assembly format as a fallback.
+ return hlo::parseConstantOp(parser, result);
+}
+
void ConstantOp::print(::mlir::OpAsmPrinter& p) {
- // If not all types are the same, use generic form.
- if (getValue().getType() != getType()) {
- p.printGenericOp(getOperation(), /*printOpName=*/false);
- return;
- }
-
- p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
- p << ' ';
- p.printStrippedAttrOrType(getValueAttr());
+ hlo::printConstantOp(p, getOperation(), getValue());
}

//===----------------------------------------------------------------------===//
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td
--- stablehlo/stablehlo/dialect/StablehloOps.td
+++ stablehlo/stablehlo/dialect/StablehloOps.td
Expand Down
41 changes: 2 additions & 39 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,48 +650,11 @@ bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}

ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
// Parse the generic form.
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseRParen()) return failure();
if (parser.parseOptionalAttrDict(result.attributes)) return failure();
if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
parser.parseArrow())
return failure();
Type resultTy;
if (parser.parseType(resultTy)) {
return failure();
}
result.addTypes(resultTy);
return success();
}

ElementsAttr valueAttr;
if (parser.parseOptionalAttrDict(result.attributes)) return failure();

if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value",
result.attributes)) {
return failure();
}
result.addTypes(valueAttr.getType());
return success();
return hlo::parseConstantOp(parser, result);
}

/// Print a `constant` op.
///
/// op ::= attr-dict $value
///
/// When the `value` and `output` have different type, it just uses the default
/// operator assembly format as a fallback.
void ConstantOp::print(::mlir::OpAsmPrinter& p) {
// If not all types are the same, use generic form.
if (getValue().getType() != getType()) {
p.printGenericOp(getOperation(), /*printOpName=*/false);
return;
}

p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
p << ' ';
p.printStrippedAttrOrType(getValueAttr());
hlo::printConstantOp(p, getOperation(), getValue());
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 5c3f217

Please sign in to comment.