Skip to content

Commit

Permalink
further wildcard cleanups (pytorch#16041)
Browse files Browse the repository at this point in the history
Summary:
Some cleanup to wildcard handling, including one bugfix: previously, we were not considering writes to the wildcard set as part of the potential write set for nodes.
Pull Request resolved: pytorch#16041

Differential Revision: D13705738

Pulled By: suo

fbshipit-source-id: acb8ccbaa70fe47445577ddf24a69f84630de411
  • Loading branch information
suo authored and facebook-github-bot committed Jan 17, 2019
1 parent 962f3f4 commit 431a34f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 24 deletions.
11 changes: 11 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9440,6 +9440,17 @@ def fn(lst):

self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))

def test_view_write(self):
def fn(x, y):
l = []
l.append(x)
x_view = l[0]
a = x + x
x_view.add_(y)
b = x + x
return a == b
self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))


class MnistNet(nn.Module):
def __init__(self):
Expand Down
85 changes: 61 additions & 24 deletions torch/csrc/jit/passes/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,6 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
aliasToValue_[aliasSet].insert(value);
}
}
// - Set of all nodes with a wildcard
buildWildcardIndex(graph_->block());
}

void AliasDb::buildWildcardIndex(const Block* b) {
for (const auto node : b->nodes()) {
for (const auto block : node->blocks()) {
buildWildcardIndex(block);
}

if (hasWildcardImpl(node)) {
wildcardNodes_.insert(node);
}
}
}

bool AliasDb::hasWildcard(const Node* n) const {
Expand Down Expand Up @@ -111,7 +97,7 @@ bool AliasDb::hasWritersBefore(const Node* n) const {
}
const auto writers = getWriters(n);
return std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
return writer->isBefore(n);
return isBeforeSameGraph(writer, n);
});
}

Expand Down Expand Up @@ -185,6 +171,15 @@ std::unordered_set<Node*> AliasDb::getWriters(const Node* n) const {
}
}
}

// A write to the wildcard set should be considered a write to `n`
if (aliasToWrites_.count(AliasInfo::wildcardSet())) {
const auto& wildcardWriters = aliasToWrites_.at(AliasInfo::wildcardSet());
for (auto writer : wildcardWriters) {
writers.insert(writer);
}
}

return writers;
}

Expand Down Expand Up @@ -447,6 +442,10 @@ void AliasDb::analyze(Node* node) {

addAlias(actual, outputAlias);
}
// Keep the wildcard index up to date.
if (hasWildcardImpl(node)) {
wildcardNodes_.insert(node);
}
}

void AliasDb::analyzeIf(Node* node) {
Expand Down Expand Up @@ -508,7 +507,11 @@ void AliasDb::analyzeLoop(Node* node) {
}

void AliasDb::analyzeSubgraph(Node* node) {
const auto subgraphBlock = node->g(attr::Subgraph)->block();
const auto subgraph = node->g(attr::Subgraph).get();

subgraphToOwner_.insert({subgraph, node});

const auto subgraphBlock = subgraph->block();
mapAliases(subgraphBlock->inputs(), node->inputs());

analyze(subgraphBlock);
Expand Down Expand Up @@ -789,6 +792,7 @@ class AliasDb::WorkingSet {
// outside), then return nullptr. Since we can only reorder nodes within a
// block, `target` would be irrelevant.
static Node* findSameBlock(Node* target, Node* n) {
JIT_ASSERT(target->owningGraph() == n->owningGraph());
if (target->owningBlock() == n->owningBlock()) {
return target;
} else {
Expand Down Expand Up @@ -927,20 +931,53 @@ void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
}
}

c10::optional<const Node*> AliasDb::getLastWildcard() const {
auto it = std::max_element(
wildcardNodes_.cbegin(),
wildcardNodes_.cend(),
[this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); });
if (it != wildcardNodes_.end()) {
return *it;
} else {
return c10::nullopt;
}
}

bool AliasDb::hasUntrackedEffects(Node* node) const {
bool touchesWildcard = false;
if (!wildcardNodes_.empty()) {
auto lastWildcard = *wildcardNodes_.begin();
for (const auto wildcard : wildcardNodes_) {
if (wildcard->isAfter(lastWildcard)) {
lastWildcard = wildcard;
}
}
if (const auto lastWildcard = getLastWildcard()) {
touchesWildcard = hasWrites(node) &&
(node->isBefore(lastWildcard) || node == lastWildcard);
(isBeforeSameGraph(node, *lastWildcard) || node == *lastWildcard);
}

return writesToInputAlias(node) || touchesWildcard;
}

// Nodes must be in the same graph in order to do `isBefore` or `isAfter`. This
// traverses the subgraph "chain" upward until we find two nodes that share an
// owning graph.
//
// NOTE: this is n^2 in subgraph depth. Right now the maximum depth is like 2,
// but if we ever do huge nested subgraphs we'll need to reconsider this.
bool AliasDb::isBeforeSameGraph(const Node* a, const Node* b) const {
auto lhs = a;
while (true) {
auto rhs = b;
while (true) {
if (lhs->owningGraph() == rhs->owningGraph()) {
return lhs->isBefore(rhs);
}
if (!subgraphToOwner_.count(rhs->owningGraph())) {
break;
}
rhs = subgraphToOwner_.at(rhs->owningGraph());
}
if (!subgraphToOwner_.count(lhs->owningGraph())) {
break;
}
lhs = subgraphToOwner_.at(lhs->owningGraph());
}
JIT_ASSERT(false);
}
} // namespace jit
} // namespace torch
5 changes: 5 additions & 0 deletions torch/csrc/jit/passes/alias_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class AliasDb {

// Does `n` use or write to any wildcard aliases?
bool hasWildcard(const Node* n) const;
// Returns nullopt if there are no wildcard nodes
c10::optional<const Node*> getLastWildcard() const;

// Does `n` write to a value that may alias one of the graph inputs?
bool writesToInputAlias(Node* n) const;
Expand Down Expand Up @@ -113,13 +115,16 @@ class AliasDb {
bool hasWildcardImpl(const Node* n) const;
bool writesTo(Node* n, const Value* v) const;

bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;

std::shared_ptr<Graph> graph_;
Symbol latestSymbol_ = Symbol::fromQualString("alias::0");
std::unordered_map<const Value*, AliasInfo> valueToAlias_;
std::unordered_map<Symbol, std::unordered_set<const Value*>> aliasToValue_;
std::unordered_map<Symbol, std::unordered_set<Node*>> aliasToWrites_;
std::unordered_set<const Node*> wildcardNodes_;
std::unordered_set<Symbol> graphInputAliases_;
std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
};

inline TORCH_API AliasDb AliasAnalysis(std::shared_ptr<Graph> graph) {
Expand Down

0 comments on commit 431a34f

Please sign in to comment.