Skip to content

Commit

Permalink
[ViewOpGraph] Improve GraphViz output (#125509)
Browse files Browse the repository at this point in the history
This patch improves the GraphViz output of ViewOpGraph
(--view-op-graph).

- Switch to rectangular record-based nodes, inspired by a similar
visualization in [Glow](https://github.com/pytorch/glow). Rectangles
make more efficient use of space when printing text.
- Add input and output ports for each operand and result, and remove
edge labels.
- Switch to a muted color palette to reduce eye strain.
  • Loading branch information
ehein6 authored Feb 7, 2025
1 parent 1611059 commit 1f67070
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 109 deletions.
176 changes: 137 additions & 39 deletions mlir/lib/Transforms/ViewOpGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
Expand All @@ -29,7 +30,7 @@ using namespace mlir;

static const StringRef kLineStyleControlFlow = "dashed";
static const StringRef kLineStyleDataFlow = "solid";
static const StringRef kShapeNode = "ellipse";
static const StringRef kShapeNode = "Mrecord";
static const StringRef kShapeNone = "plain";

/// Return the size limits for eliding large attributes.
Expand All @@ -49,16 +50,25 @@ static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
return buf;
}

/// Escape special characters such as '\n' and quotation marks.
static std::string escapeString(std::string str) {
return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
}

/// Put quotation marks around a given string.
static std::string quoteString(const std::string &str) {
return "\"" + str + "\"";
}

/// For Graphviz record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
std::string escapeLabelString(const std::string &str) {
std::string buf;
llvm::raw_string_ostream os(buf);
for (char c : str) {
if (llvm::is_contained({'{', '|', '<', '}', '>', '\n', '"'}, c))
os << '\\';
os << c;
}
return buf;
}

using AttributeMap = std::map<std::string, std::string>;

namespace {
Expand All @@ -79,6 +89,12 @@ struct Node {
std::optional<int> clusterId;
};

struct DataFlowEdge {
Value value;
Node node;
std::string port;
};

/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
/// about the Graphviz DOT language.
Expand Down Expand Up @@ -107,7 +123,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
private:
/// Generate a color mapping that will color every operation with the same
/// name the same way. It'll interpolate the hue in the HSV color-space,
/// attempting to keep the contrast suitable for black text.
/// using muted colors that provide good contrast for black text.
template <typename T>
void initColorMapping(T &irEntity) {
backgroundColors.clear();
Expand All @@ -120,17 +136,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
});
for (auto indexedOps : llvm::enumerate(ops)) {
double hue = ((double)indexedOps.index()) / ops.size();
// Use lower saturation (0.3) and higher value (0.95) for better
// readability
backgroundColors[indexedOps.value()->getName()].second =
std::to_string(hue) + " 1.0 1.0";
std::to_string(hue) + " 0.3 0.95";
}
}

/// Emit all edges. This function should be called after all nodes have been
/// emitted.
void emitAllEdgeStmts() {
if (printDataFlowEdges) {
for (const auto &[value, node, label] : dataFlowEdges) {
emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
for (const auto &e : dataFlowEdges) {
emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
}
}

Expand All @@ -147,8 +165,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
os.indent();
// Emit invisible anchor node from/to which arrows can be drawn.
Node anchorNode = emitNodeStmt(" ", kShapeNone);
os << attrStmt("label", quoteString(escapeString(std::move(label))))
<< ";\n";
os << attrStmt("label", quoteString(label)) << ";\n";
builder();
os.unindent();
os << "}\n";
Expand Down Expand Up @@ -176,16 +193,17 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {

// Always emit splat attributes.
if (isa<SplatElementsAttr>(attr)) {
attr.print(os);
os << escapeLabelString(
strFromOs([&](raw_ostream &os) { attr.print(os); }));
return;
}

// Elide "big" elements attributes.
auto elements = dyn_cast<ElementsAttr>(attr);
if (elements && elements.getNumElements() > largeAttrLimit) {
os << std::string(elements.getShapedType().getRank(), '[') << "..."
<< std::string(elements.getShapedType().getRank(), ']') << " : "
<< elements.getType();
<< std::string(elements.getShapedType().getRank(), ']') << " : ";
emitMlirType(os, elements.getType());
return;
}

Expand All @@ -199,27 +217,43 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
std::string buf;
llvm::raw_string_ostream ss(buf);
attr.print(ss);
os << truncateString(buf);
os << escapeLabelString(truncateString(buf));
}

// Print a truncated and escaped MLIR type to `os`.
void emitMlirType(raw_ostream &os, Type type) {
std::string buf;
llvm::raw_string_ostream ss(buf);
type.print(ss);
os << escapeLabelString(truncateString(buf));
}

// Print a truncated and escaped MLIR operand to `os`.
void emitMlirOperand(raw_ostream &os, Value operand) {
operand.printAsOperand(os, OpPrintingFlags());
}

/// Append an edge to the list of edges.
/// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
AttributeMap attrs;
attrs["style"] = style.str();
// Do not label edges that start/end at a cluster boundary. Such edges are
// clipped at the boundary, but labels are not. This can lead to labels
// floating around without any edge next to them.
if (!n1.clusterId && !n2.clusterId)
attrs["label"] = quoteString(escapeString(std::move(label)));
// Use `ltail` and `lhead` to draw edges between clusters.
if (n1.clusterId)
attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
if (n2.clusterId)
attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);

edges.push_back(strFromOs([&](raw_ostream &os) {
os << llvm::format("v%i -> v%i ", n1.id, n2.id);
os << "v" << n1.id;
if (!port.empty() && !n1.clusterId)
// Attach edge to south compass point of the result
os << ":res" << port << ":s";
os << " -> ";
os << "v" << n2.id;
if (!port.empty() && !n2.clusterId)
// Attach edge to north compass point of the operand
os << ":arg" << port << ":n";
emitAttrList(os, attrs);
}));
}
Expand All @@ -240,20 +274,30 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
StringRef background = "") {
int nodeId = ++counter;
AttributeMap attrs;
attrs["label"] = quoteString(escapeString(std::move(label)));
attrs["label"] = quoteString(label);
attrs["shape"] = shape.str();
if (!background.empty()) {
attrs["style"] = "filled";
attrs["fillcolor"] = ("\"" + background + "\"").str();
attrs["fillcolor"] = quoteString(background.str());
}
os << llvm::format("v%i ", nodeId);
emitAttrList(os, attrs);
os << ";\n";
return Node(nodeId);
}

/// Generate a label for an operation.
std::string getLabel(Operation *op) {
std::string getValuePortName(Value operand) {
// Print value as an operand and omit the leading '%' character.
auto str = strFromOs([&](raw_ostream &os) {
operand.printAsOperand(os, OpPrintingFlags());
});
// Replace % and # with _
std::replace(str.begin(), str.end(), '%', '_');
std::replace(str.begin(), str.end(), '#', '_');
return str;
}

std::string getClusterLabel(Operation *op) {
return strFromOs([&](raw_ostream &os) {
// Print operation name and type.
os << op->getName();
Expand All @@ -267,18 +311,73 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {

// Print attributes.
if (printAttrs) {
os << "\n";
os << "\\l";
for (const NamedAttribute &attr : op->getAttrs()) {
os << escapeLabelString(attr.getName().getValue().str()) << ": ";
emitMlirAttr(os, attr.getValue());
os << "\\l";
}
}
});
}

/// Generate a label for an operation.
std::string getRecordLabel(Operation *op) {
return strFromOs([&](raw_ostream &os) {
os << "{";

// Print operation inputs.
if (op->getNumOperands() > 0) {
os << "{";
auto operandToPort = [&](Value operand) {
os << "<arg" << getValuePortName(operand) << "> ";
emitMlirOperand(os, operand);
};
interleave(op->getOperands(), os, operandToPort, "|");
os << "}|";
}
// Print operation name and type.
os << op->getName() << "\\l";

// Print attributes.
if (printAttrs && !op->getAttrs().empty()) {
// Extra line break to separate attributes from the operation name.
os << "\\l";
for (const NamedAttribute &attr : op->getAttrs()) {
os << '\n' << attr.getName().getValue() << ": ";
os << attr.getName().getValue() << ": ";
emitMlirAttr(os, attr.getValue());
os << "\\l";
}
}

if (op->getNumResults() > 0) {
os << "|{";
auto resultToPort = [&](Value result) {
os << "<res" << getValuePortName(result) << "> ";
emitMlirOperand(os, result);
if (printResultTypes) {
os << " ";
emitMlirType(os, result.getType());
}
};
interleave(op->getResults(), os, resultToPort, "|");
os << "}";
}

os << "}";
});
}

/// Generate a label for a block argument.
std::string getLabel(BlockArgument arg) {
return "arg" + std::to_string(arg.getArgNumber());
return strFromOs([&](raw_ostream &os) {
os << "<res" << getValuePortName(arg) << "> ";
arg.printAsOperand(os, OpPrintingFlags());
if (printResultTypes) {
os << " ";
emitMlirType(os, arg.getType());
}
});
}

/// Process a block. Emit a cluster and one node per block argument and
Expand All @@ -287,14 +386,12 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
emitClusterStmt([&]() {
for (BlockArgument &blockArg : block.getArguments())
valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));

// Emit a node for each operation.
std::optional<Node> prevNode;
for (Operation &op : block) {
Node nextNode = processOperation(&op);
if (printControlFlowEdges && prevNode)
emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
kLineStyleControlFlow);
emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
prevNode = nextNode;
}
});
Expand All @@ -311,18 +408,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
for (Region &region : op->getRegions())
processRegion(region);
},
getLabel(op));
getClusterLabel(op));
} else {
node = emitNodeStmt(getLabel(op), kShapeNode,
node = emitNodeStmt(getRecordLabel(op), kShapeNode,
backgroundColors[op->getName()].second);
}

// Insert data flow edges originating from each operand.
if (printDataFlowEdges) {
unsigned numOperands = op->getNumOperands();
for (unsigned i = 0; i < numOperands; i++)
dataFlowEdges.push_back({op->getOperand(i), node,
numOperands == 1 ? "" : std::to_string(i)});
for (unsigned i = 0; i < numOperands; i++) {
auto operand = op->getOperand(i);
dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
}
}

for (Value result : op->getResults())
Expand Down Expand Up @@ -352,7 +450,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
/// Mapping of SSA values to Graphviz nodes/clusters.
DenseMap<Value, Node> valueToNode;
/// Output for data flow edges is delayed until the end to handle cycles
std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
std::vector<DataFlowEdge> dataFlowEdges;
/// Counter for generating unique node/subgraph identifiers.
int counter = 0;

Expand Down
30 changes: 15 additions & 15 deletions mlir/test/Transforms/print-op-graph-back-edges.mlir
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
// RUN: mlir-opt -view-op-graph %s -o %t 2>&1 | FileCheck -check-prefix=DFG %s

// DFG-LABEL: digraph G {
// DFG: compound = true;
// DFG: subgraph cluster_1 {
// DFG: v2 [label = " ", shape = plain];
// DFG: label = "builtin.module : ()\n";
// DFG: subgraph cluster_3 {
// DFG: v4 [label = " ", shape = plain];
// DFG: label = "";
// DFG: v5 [fillcolor = "0.000000 1.0 1.0", label = "arith.addi : (index)\n\noverflowFlags: #arith.overflow<none...", shape = ellipse, style = filled];
// DFG: v6 [fillcolor = "0.333333 1.0 1.0", label = "arith.constant : (index)\n\nvalue: 0 : index", shape = ellipse, style = filled];
// DFG: v7 [fillcolor = "0.333333 1.0 1.0", label = "arith.constant : (index)\n\nvalue: 1 : index", shape = ellipse, style = filled];
// DFG: }
// DFG: }
// DFG: v6 -> v5 [label = "0", style = solid];
// DFG: v7 -> v5 [label = "1", style = solid];
// DFG: }
// DFG-NEXT: compound = true;
// DFG-NEXT: subgraph cluster_1 {
// DFG-NEXT: v2 [label = " ", shape = plain];
// DFG-NEXT: label = "builtin.module : ()\l";
// DFG-NEXT: subgraph cluster_3 {
// DFG-NEXT: v4 [label = " ", shape = plain];
// DFG-NEXT: label = "";
// DFG-NEXT: v5 [fillcolor = "0.000000 0.3 0.95", label = "{{\{\{}}<arg_c0> %c0|<arg_c1> %c1}|arith.addi\l\loverflowFlags: #arith.overflow\<none...\l|{<res_0> %0 index}}", shape = Mrecord, style = filled];
// DFG-NEXT: v6 [fillcolor = "0.333333 0.3 0.95", label = "{arith.constant\l\lvalue: 0 : index\l|{<res_c0> %c0 index}}", shape = Mrecord, style = filled];
// DFG-NEXT: v7 [fillcolor = "0.333333 0.3 0.95", label = "{arith.constant\l\lvalue: 1 : index\l|{<res_c1> %c1 index}}", shape = Mrecord, style = filled];
// DFG-NEXT: }
// DFG-NEXT: }
// DFG-NEXT: v6:res_c0:s -> v5:arg_c0:n[style = solid];
// DFG-NEXT: v7:res_c1:s -> v5:arg_c1:n[style = solid];
// DFG-NEXT: }

module {
%add = arith.addi %c0, %c1 : index
Expand Down
Loading

0 comments on commit 1f67070

Please sign in to comment.